Skip to content
Snippets Groups Projects
Commit f498de92 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Updated inference procedure of deepnano and guppy to increase their performance

parent 30b35d31
No related branches found
No related tags found
No related merge requests found
......@@ -114,9 +114,18 @@ def parse_paf(paf_path, ground_truth):
y_true = merged_df['species id'].apply(lambda x:
1 if x.find(target_id) == 0
else 0)
# Mapping to a * means mapping to no species
y_pred = merged_df['Target sequence name'].apply(
lambda x: 0 if x == '*' else 1)
if tool == 'deepnano':
residue_matches = np.array(merged_df["Number of residue matches"])
alignment_length = np.array(merged_df["Alignment block length"])
align_perc = residue_matches / alignment_length
y_pred = (align_perc > 0.50).astype(int)
elif tool == 'uncalled':
# Mapping to a * means mapping to no species
y_pred = merged_df['Target sequence name'].apply(
lambda x: 0 if x == '*' else 1)
else:
raise ValueError('Tool not found')
f1 = f1_score(y_true, y_pred)
accuracy = accuracy_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)
......@@ -204,8 +213,7 @@ def calculate_accuracy_from_output(args):
"squigglenet/*/fold?/inference")
# Prepare arguments to parse to the functions in the pool
args_baseless = [[file, ground_truth] for file in baseless_files]
args_uncalled_deepnano = [[file, ground_truth] for file
in chain(deepnano_files, uncalled_files)]
args_uncalled = [[file, ground_truth] for file in uncalled_files]
args_guppy = [[file, ground_truth] for file in guppy_files]
args_squigglenet = [[folder, ground_truth]
for folder in squigglenet_files]
......@@ -213,11 +221,11 @@ def calculate_accuracy_from_output(args):
baseless_results = p.starmap(parse_baseless_output, args_baseless)
squigglenet_results = p.starmap(parse_squigglenet_output,
args_squigglenet)
uncalled_deepnano_results = p.starmap(parse_paf,
args_uncalled_deepnano)
uncalled_results = p.starmap(parse_paf, args_uncalled)
guppy_results = p.starmap(parse_sequencing_summary, args_guppy)
all_records = (uncalled_deepnano_results + guppy_results
+ squigglenet_results + baseless_results)
deepnano_results = [parse_paf(file, ground_truth) for file in deepnano_files]
all_records = (guppy_results + deepnano_results + uncalled_results
+ baseless_results + squigglenet_results)
assert all_records, ('No files found: make sure the input folder '
'contains directories called "uncalled",'
' "guppy", "baseless" and "deepnano"')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment