From d6ed77cb1c70816b1e1e2fe8ff9eda84c584d576 Mon Sep 17 00:00:00 2001
From: noord087 <ben.noordijk@wur.nl>
Date: Tue, 26 Apr 2022 13:40:44 +0200
Subject: [PATCH] Baseless cutoff at 55% of kmers and some minor refactoring

---
 .../compare_accuracy_per_read.py              | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/compare_benchmark_performance/compare_accuracy_per_read.py b/compare_benchmark_performance/compare_accuracy_per_read.py
index 0b9482f..0f2d08d 100644
--- a/compare_benchmark_performance/compare_accuracy_per_read.py
+++ b/compare_benchmark_performance/compare_accuracy_per_read.py
@@ -71,6 +71,9 @@ def parse_sequencing_summary(file_path, ground_truth, return_df=False):
     df_guppy = pd.read_csv(file_path, sep='\t')
     df_guppy.rename(columns={'read_id': 'read id'}, inplace=True)
     df = df_guppy.merge(ground_truth, on='read id')
+    # df = df.groupby('species id').filter(lambda x: len(x) > 100)
+    # df = df.groupby('species id').sample(100)['file name']
+
     y_pred = df['alignment_identity'].apply(lambda x: 0 if x < 0.95 else 1)
     df['y_pred'] = y_pred
     if return_df:
@@ -118,23 +121,24 @@ def parse_paf(paf_path, ground_truth):
 
     uncalled_paf = pd.read_csv(paf_path, sep='\t', names=headers,
                                usecols=range(12))
-    merged_df = uncalled_paf.merge(ground_truth, on='read id')
+    df = uncalled_paf.merge(ground_truth, on='read id')
+    # df = df.groupby('species id').filter(lambda x: len(x) > 100)
+    # df = df.groupby('species id').sample(100)['file name']
     # Set truth to 1 if ground truth species is found
-    y_true = merged_df['species id'].apply(lambda x:
+    y_true = df['species id'].apply(lambda x:
                                            1 if x.find(target_id) == 0
                                            else 0)
     if tool == 'deepnano':
-        residue_matches = np.array(merged_df["Number of residue matches"])
-        alignment_length = np.array(merged_df["Alignment block length"])
+        residue_matches = np.array(df["Number of residue matches"])
+        alignment_length = np.array(df["Alignment block length"])
         align_perc = residue_matches / alignment_length
         y_pred = (align_perc > 0.50).astype(int)
     elif tool == 'uncalled':
         # Mapping to a * means mapping to no species
-        y_pred = merged_df['Target sequence name'].apply(
+        y_pred = df['Target sequence name'].apply(
             lambda x: 0 if x == '*' else 1)
     else:
         raise ValueError('Tool not found')
-
     accuracy, cm, f1 = get_performance_metrics_from_predictions(y_pred, y_true)
     return tool, target_species, fold, accuracy, f1, cm
 
@@ -195,7 +199,7 @@ def parse_baseless_output(file, ground_truth, for_pr_plot=False):
     y_true = merged_df['species id'].apply(lambda x:
                                            1 if x.find(target_id) == 0
                                            else 0)
-    y_pred = [0 if float(i) < 0.55 else 1 for i in merged_df.species_is_found]
+    y_pred = [1 if float(i) > 0.55 else 0 for i in merged_df.species_is_found]
     if not y_pred:
         # No positive examples
         return None
@@ -281,7 +285,7 @@ def main(args):
                                             if x.find('gingivalis') > 0
                                             else True)
     df = df[is_not_gingivalis]
-    # TODO Temporary; remove A. odontolyliticus fold 1 because it contained no pos examples
+    # Temporary; remove A. odontolyliticus fold 1 because it contained no pos examples
     i_to_remove = df[((df.tool == 'baseless')
                       & (df.species =='actinomyces odontolyticus')
                       & (df.fold == 1))].index
-- 
GitLab