From 2d27b6f3d63196271cce1a787d2c0895bf8dc006 Mon Sep 17 00:00:00 2001
From: noord087 <ben.noordijk@wur.nl>
Date: Wed, 4 May 2022 15:19:35 +0200
Subject: [PATCH] Balanced dataset for all species is now used to calculate
 performance. Also increased deepnano accuracy.

---
 .../compare_accuracy_per_read.py | 30 ++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/compare_benchmark_performance/compare_accuracy_per_read.py b/compare_benchmark_performance/compare_accuracy_per_read.py
index 0f2d08d..727c03c 100644
--- a/compare_benchmark_performance/compare_accuracy_per_read.py
+++ b/compare_benchmark_performance/compare_accuracy_per_read.py
@@ -39,6 +39,19 @@ SPECIES_TO_ID = {
 }
 
 
+def make_balanced_df(df):
+    """From dataframe, make sure each species is present 100 times
+
+    :param df: dataframe that contains column called 'species id'
+    :type df: pd.Dataframe
+    :return: dataframe with balanced species
+    """
+    # Remove species that are too rare
+    df = df.groupby('species id').filter(lambda x: len(x) > 100)
+    df = df.groupby('species id').sample(100)
+    return df
+
+
 def get_performance_metrics_from_predictions(y_pred, y_true):
     """Given two arrays of predict and true labels, return accuracy
     score, confusion matrix and f1 score"""
@@ -71,8 +84,8 @@
     df_guppy = pd.read_csv(file_path, sep='\t')
     df_guppy.rename(columns={'read_id': 'read id'}, inplace=True)
     df = df_guppy.merge(ground_truth, on='read id')
-    # df = df.groupby('species id').filter(lambda x: len(x) > 100)
-    # df = df.groupby('species id').sample(100)['file name']
+    df = make_balanced_df(df)
+    # print('Guppy accuracy on df of size', df.shape)
 
     y_pred = df['alignment_identity'].apply(lambda x: 0 if x < 0.95 else 1)
     df['y_pred'] = y_pred
@@ -122,17 +135,17 @@
     uncalled_paf = pd.read_csv(paf_path, sep='\t', names=headers,
                                usecols=range(12))
     df = uncalled_paf.merge(ground_truth, on='read id')
-    # df = df.groupby('species id').filter(lambda x: len(x) > 100)
-    # df = df.groupby('species id').sample(100)['file name']
+    df = make_balanced_df(df)
+    # print(f'{tool} accuracy on df of size', df.shape)
 
     # Set truth to 1 if ground truth species is found
     y_true = df['species id'].apply(lambda x: 1 if x.find(target_id) == 0
                                     else 0)
     if tool == 'deepnano':
-        residue_matches = np.array(df["Number of residue matches"])
+        query_len = np.array(df["Query sequence length"])
         alignment_length = np.array(df["Alignment block length"])
-        align_perc = residue_matches / alignment_length
-        y_pred = (align_perc > 0.50).astype(int)
+        align_perc = alignment_length / query_len
+        y_pred = (align_perc > 0.90).astype(int)
     elif tool == 'uncalled':
         # Mapping to a * means mapping to no species
         y_pred = df['Target sequence name'].apply(
@@ -169,6 +182,8 @@ def parse_squigglenet_output(in_dir, ground_truth):
     df = pd.DataFrame.from_records(predictions,
                                    columns=['read id', 'y pred'])
     merged_df = df.merge(ground_truth, on='read id')
+    merged_df = make_balanced_df(merged_df)
+    # print('Squigglenet accuracy on df of size', merged_df.shape)
     y_true = merged_df['species id'].apply(lambda x: 1 if x.find(target_id) == 0
                                            else 0)
 
@@ -196,6 +211,7 @@ def parse_baseless_output(file, ground_truth, for_pr_plot=False):
     baseless_output = pd.read_csv(file,
                                   names=['file name', 'species_is_found'])
     merged_df = baseless_output.merge(ground_truth, on='file name')
+    # print('Baseless accuracy on df of size', merged_df.shape)
     y_true = merged_df['species id'].apply(lambda x: 1 if x.find(target_id) == 0
                                            else 0)
 
-- 
GitLab