From 2d27b6f3d63196271cce1a787d2c0895bf8dc006 Mon Sep 17 00:00:00 2001
From: noord087 <ben.noordijk@wur.nl>
Date: Wed, 4 May 2022 15:19:35 +0200
Subject: [PATCH] Balanced dataset for all species is now used to calculate
 performance. Also increased deepnano accuracy.

---
 .../compare_accuracy_per_read.py              | 30 ++++++++++++++-----
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/compare_benchmark_performance/compare_accuracy_per_read.py b/compare_benchmark_performance/compare_accuracy_per_read.py
index 0f2d08d..727c03c 100644
--- a/compare_benchmark_performance/compare_accuracy_per_read.py
+++ b/compare_benchmark_performance/compare_accuracy_per_read.py
@@ -39,6 +39,19 @@ SPECIES_TO_ID = {
 }
 
 
+def make_balanced_df(df):
+    """Return a subset of df holding exactly 100 sampled rows per species
+
+    :param df: dataframe that contains column called 'species id'
+    :type df: pd.DataFrame
+    :return: dataframe with 100 rows for every species that has over 100 rows
+    """
+    # Drop species with <= 100 rows; NOTE(review): should exactly 100 be kept?
+    df = df.groupby('species id').filter(lambda x: len(x) > 100)
+    df = df.groupby('species id').sample(100)
+    return df
+
+
 def get_performance_metrics_from_predictions(y_pred, y_true):
     """Given two arrays of predict and true labels, return accuracy score,
     confusion matrix and f1 score"""
@@ -71,8 +84,8 @@ def parse_sequencing_summary(file_path, ground_truth, return_df=False):
     df_guppy = pd.read_csv(file_path, sep='\t')
     df_guppy.rename(columns={'read_id': 'read id'}, inplace=True)
     df = df_guppy.merge(ground_truth, on='read id')
-    # df = df.groupby('species id').filter(lambda x: len(x) > 100)
-    # df = df.groupby('species id').sample(100)['file name']
+    df = make_balanced_df(df)
+    # print('Guppy accuracy on df of size', df.shape)
 
     y_pred = df['alignment_identity'].apply(lambda x: 0 if x < 0.95 else 1)
     df['y_pred'] = y_pred
@@ -122,17 +135,17 @@ def parse_paf(paf_path, ground_truth):
     uncalled_paf = pd.read_csv(paf_path, sep='\t', names=headers,
                                usecols=range(12))
     df = uncalled_paf.merge(ground_truth, on='read id')
-    # df = df.groupby('species id').filter(lambda x: len(x) > 100)
-    # df = df.groupby('species id').sample(100)['file name']
+    df = make_balanced_df(df)
+    # print(f'{tool} accuracy on df of size', df.shape)
     # Set truth to 1 if ground truth species is found
     y_true = df['species id'].apply(lambda x:
                                            1 if x.find(target_id) == 0
                                            else 0)
     if tool == 'deepnano':
-        residue_matches = np.array(df["Number of residue matches"])
+        query_len = np.array(df["Query sequence length"])
         alignment_length = np.array(df["Alignment block length"])
-        align_perc = residue_matches / alignment_length
-        y_pred = (align_perc > 0.50).astype(int)
+        align_perc = alignment_length / query_len
+        y_pred = (align_perc > 0.90).astype(int)
     elif tool == 'uncalled':
         # Mapping to a * means mapping to no species
         y_pred = df['Target sequence name'].apply(
@@ -169,6 +182,8 @@ def parse_squigglenet_output(in_dir, ground_truth):
     df = pd.DataFrame.from_records(predictions, columns=['read id', 'y pred'])
 
     merged_df = df.merge(ground_truth, on='read id')
+    merged_df = make_balanced_df(merged_df)
+    # print('Squigglenet accuracy on df of size', merged_df.shape)
     y_true = merged_df['species id'].apply(lambda x:
                                            1 if x.find(target_id) == 0
                                            else 0)
@@ -196,6 +211,7 @@ def parse_baseless_output(file, ground_truth, for_pr_plot=False):
 
     baseless_output = pd.read_csv(file, names=['file name', 'species_is_found'])
     merged_df = baseless_output.merge(ground_truth, on='file name')
+    # print('Baseless accuracy on df of size', merged_df.shape)
     y_true = merged_df['species id'].apply(lambda x:
                                            1 if x.find(target_id) == 0
                                            else 0)
-- 
GitLab