From f498de920f591b67d333ec821964c9a33ca20096 Mon Sep 17 00:00:00 2001 From: noord087 <ben.noordijk@wur.nl> Date: Wed, 30 Mar 2022 14:56:44 +0200 Subject: [PATCH] Updated inference procedure of deepnano and guppy to increase their performance --- .../compare_accuracy.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/compare_benchmark_performance/compare_accuracy.py b/compare_benchmark_performance/compare_accuracy.py index c9d0978..1e63371 100644 --- a/compare_benchmark_performance/compare_accuracy.py +++ b/compare_benchmark_performance/compare_accuracy.py @@ -114,9 +114,18 @@ def parse_paf(paf_path, ground_truth): y_true = merged_df['species id'].apply(lambda x: 1 if x.find(target_id) == 0 else 0) - # Mapping to a * means mapping to no species - y_pred = merged_df['Target sequence name'].apply( - lambda x: 0 if x == '*' else 1) + if tool == 'deepnano': + residue_matches = np.array(merged_df["Number of residue matches"]) + alignment_length = np.array(merged_df["Alignment block length"]) + align_perc = residue_matches / alignment_length + y_pred = (align_perc > 0.50).astype(int) + elif tool == 'uncalled': + # Mapping to a * means mapping to no species + y_pred = merged_df['Target sequence name'].apply( + lambda x: 0 if x == '*' else 1) + else: + raise ValueError('Tool not found') + f1 = f1_score(y_true, y_pred) accuracy = accuracy_score(y_true, y_pred) cm = confusion_matrix(y_true, y_pred) @@ -204,8 +213,7 @@ def calculate_accuracy_from_output(args): "squigglenet/*/fold?/inference") # Prepare arguments to parse to the functions in the pool args_baseless = [[file, ground_truth] for file in baseless_files] - args_uncalled_deepnano = [[file, ground_truth] for file - in chain(deepnano_files, uncalled_files)] + args_uncalled = [[file, ground_truth] for file in uncalled_files] args_guppy = [[file, ground_truth] for file in guppy_files] args_squigglenet = [[folder, ground_truth] for folder in squigglenet_files] @@ -213,11 +221,11 @@ def calculate_accuracy_from_output(args): baseless_results = p.starmap(parse_baseless_output, args_baseless) squigglenet_results = p.starmap(parse_squigglenet_output, args_squigglenet) - uncalled_deepnano_results = p.starmap(parse_paf, - args_uncalled_deepnano) + uncalled_results = p.starmap(parse_paf, args_uncalled) guppy_results = p.starmap(parse_sequencing_summary, args_guppy) - all_records = (uncalled_deepnano_results + guppy_results - + squigglenet_results + baseless_results) + deepnano_results = [parse_paf(file, ground_truth) for file in deepnano_files] + all_records = (guppy_results + deepnano_results + uncalled_results + + baseless_results + squigglenet_results) assert all_records, ('No files found: make sure the input folder ' 'contains directories called "uncalled",' ' "guppy", "baseless" and "deepnano"') -- GitLab