diff --git a/compare_benchmark_performance/compare_accuracy.py b/compare_benchmark_performance/compare_accuracy.py index c9d0978b2ca59bda5086476dffdfbbb6e70983f5..1e63371d644417620b6340f14fdc0347ea6cc5d1 100644 --- a/compare_benchmark_performance/compare_accuracy.py +++ b/compare_benchmark_performance/compare_accuracy.py @@ -114,9 +114,18 @@ def parse_paf(paf_path, ground_truth): y_true = merged_df['species id'].apply(lambda x: 1 if x.find(target_id) == 0 else 0) - # Mapping to a * means mapping to no species - y_pred = merged_df['Target sequence name'].apply( - lambda x: 0 if x == '*' else 1) + if tool == 'deepnano': + residue_matches = np.array(merged_df["Number of residue matches"]) + alignment_length = np.array(merged_df["Alignment block length"]) + align_perc = residue_matches / alignment_length + y_pred = (align_perc > 0.50).astype(int) + elif tool == 'uncalled': + # Mapping to a * means mapping to no species + y_pred = merged_df['Target sequence name'].apply( + lambda x: 0 if x == '*' else 1) + else: + raise ValueError('Tool not found') + f1 = f1_score(y_true, y_pred) accuracy = accuracy_score(y_true, y_pred) cm = confusion_matrix(y_true, y_pred) @@ -204,8 +213,7 @@ def calculate_accuracy_from_output(args): "squigglenet/*/fold?/inference") # Prepare arguments to parse to the functions in the pool args_baseless = [[file, ground_truth] for file in baseless_files] - args_uncalled_deepnano = [[file, ground_truth] for file - in chain(deepnano_files, uncalled_files)] + args_uncalled = [[file, ground_truth] for file in uncalled_files] args_guppy = [[file, ground_truth] for file in guppy_files] args_squigglenet = [[folder, ground_truth] for folder in squigglenet_files] @@ -213,11 +221,11 @@ def calculate_accuracy_from_output(args): baseless_results = p.starmap(parse_baseless_output, args_baseless) squigglenet_results = p.starmap(parse_squigglenet_output, args_squigglenet) - uncalled_deepnano_results = p.starmap(parse_paf, - args_uncalled_deepnano) + uncalled_results = p.starmap(parse_paf, args_uncalled) guppy_results = p.starmap(parse_sequencing_summary, args_guppy) - all_records = (uncalled_deepnano_results + guppy_results - + squigglenet_results + baseless_results) + deepnano_results = [parse_paf(file, ground_truth) for file in deepnano_files] + all_records = (guppy_results + deepnano_results + uncalled_results + + baseless_results + squigglenet_results) assert all_records, ('No files found: make sure the input folder ' 'contains directories called "uncalled",' ' "guppy", "baseless" and "deepnano"')