Skip to content
Snippets Groups Projects
Commit 950c8699 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Find optimal cutoff and print PR curves for baseless inference on sample

parent 01251377
No related branches found
No related tags found
No related merge requests found
from pathlib import Path
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, PrecisionRecallDisplay
from inference.InferenceModel import InferenceModel
from inference.run_inference_naive import scan_one_file
from compare_benchmark_performance.compare_accuracy_per_sample import \
create_in_silico_samples
def baseless_find_if_species_in_sample(read_dir, model, reads,
                                       threshold=5, k_mer_fraction=0.55,
                                       stride=125):
    """Scan fast5 reads with a baseless inference model and return the
    fraction of reads classified as containing the target species.

    :param read_dir: directory containing the fast5 files named in `reads`
    :param model: path to a compiled model, forwarded to InferenceModel
    :param reads: sequence of fast5 file names (relative to `read_dir`);
        must support len() for the progress printout
    :param threshold: detection threshold forwarded to InferenceModel
    :param k_mer_fraction: k-mer fraction forwarded to InferenceModel
    :param stride: stride used by scan_one_file when scanning each read
    :return: fraction in [0, 1] of reads where the species was found;
        0.0 when `reads` is empty
    """
    read_dir = Path(read_dir)
    mod = InferenceModel(model, threshold=threshold, kmer_frac=k_mer_fraction)
    all_output = []
    for i, fast5_file in enumerate(reads):
        if i % 100 == 0:
            # Crude progress indicator: fraction of files processed so far
            print(i / len(reads))
        all_output.append(scan_one_file(read_dir / fast5_file, mod, stride))
    df = pd.DataFrame.from_records(all_output,
                                   columns=['file_name', 'species_is_found'])
    if not len(df):
        # Guard against ZeroDivisionError on an empty read list
        return 0.0
    return sum(df.species_is_found) / len(df)
def main(args):
    """Estimate the target-species read fraction in repeated in-silico
    samples, save a precision-recall plot, and return the highest decision
    threshold whose precision stays above 0.99.

    Builds 10 positive samples (nr_pos_reads target reads present) and 10
    negative samples (no target reads), scores each with
    baseless_find_if_species_in_sample, then derives the PR curve over the
    resulting fractions.

    :param args: parsed command line arguments (model_path, ground_truth,
        reads_path, input_fold_folders, out_path)
    :return: scalar decision threshold on the positive-read fraction
    """
    target = 'escherichia coli'
    ground_truth = pd.read_csv(args.ground_truth)
    # NOTE: full-scale run used sample_size = 5000, nr_pos_reads = 50
    sample_size = 500
    nr_pos_reads = 5
    baseless_predictions = []
    ground_truth_label = []
    for i in range(10):
        print(f'Repetition {i}')
        # One sample that does contain the target species...
        print('pos example')
        reads_to_inspect = create_in_silico_samples(
            ground_truth, args.input_fold_folders, 0,
            True, target, sample_size, nr_pos_reads)
        baseless_predictions.append(baseless_find_if_species_in_sample(
            args.reads_path, args.model_path, reads_to_inspect))
        ground_truth_label.append(1)
        # ...and one that does not (zero positive reads)
        print('neg example')
        reads_to_inspect = create_in_silico_samples(
            ground_truth, args.input_fold_folders, 0,
            True, target, sample_size, nr_pos_examples=0)
        baseless_predictions.append(baseless_find_if_species_in_sample(
            args.reads_path, args.model_path, reads_to_inspect))
        ground_truth_label.append(0)
    precision, recall, thresholds = precision_recall_curve(
        ground_truth_label, baseless_predictions)
    print(precision, recall, thresholds, sep='\n\n')
    PrecisionRecallDisplay.from_predictions(ground_truth_label,
                                            baseless_predictions)
    plt.savefig(args.out_path / 'pr_plot.png')
    low_precision = np.flatnonzero(precision <= 0.99)
    if low_precision.size == 0:
        # Precision never drops below target: every threshold works, so
        # return the most permissive (last) one
        return thresholds[-1]
    # Index just before precision first drops to <= 0.99. Cast to int so a
    # scalar (not a length-1 ndarray) is returned, and clamp at 0 so a
    # leading low-precision point cannot wrap around to thresholds[-1].
    best_index = max(int(low_precision[0]) - 1, 0)
    return thresholds[best_index]
if __name__ == '__main__':
    # All five CLI options are mandatory Path arguments, so build them
    # from a (flag, help text) spec instead of repeating the boilerplate.
    parser = argparse.ArgumentParser(description="""Plot accuracy of
    benchmarked algorithms. Provide only performance CSV and out-dir
    or all arguments except for input-performance-csv""")
    cli_options = [
        ('--model-path', 'Path to compiled model'),
        ('--ground-truth', 'Path to csv with ground truth labels. '
                           'It is output by set_ground_truths_of_reads.py'),
        ('--reads-path', 'Path to directory that contains fast5 reads'),
        ('--input-fold-folders', 'Path to folders of input folds'),
        ('--out-path', 'File path to output txt'),
    ]
    for flag, help_text in cli_options:
        parser.add_argument(flag, help=help_text, required=True, type=Path)
    args = parser.parse_args()
    print(main(args))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment