diff --git a/threshold_search/optimize_thresholds_analysis.sf b/threshold_search/optimize_thresholds_analysis.sf
index 6b147a30388a042739b62e998598149ce9ee6c4e..f25d2f7a52220575ad680467d9346b93def039d6 100644
--- a/threshold_search/optimize_thresholds_analysis.sf
+++ b/threshold_search/optimize_thresholds_analysis.sf
@@ -27,7 +27,8 @@ rule combine:
     output:
         f1_csv=f'{out_dir}f1.csv',
         precision_csv=f'{out_dir}precision.csv',
-        recall_csv=f'{out_dir}recall.csv'
+        recall_csv=f'{out_dir}recall.csv',
+        pr_plot=f'{out_dir}pr_plot.svg'
     run:
         gt_df = pd.read_csv(input.ground_truth_csv).set_index('file_name')
 
@@ -41,27 +42,33 @@ rule combine:
         h_idx, v_idx = np.arange(len(gt_df)), gt_df.species_numeric
         bool_mat = np.zeros((len(gt_df), len(species_list) + 1),dtype=bool)
         bool_mat[h_idx, v_idx] = True
-        gt_bool_df = pd.DataFrame(bool_mat[:, :-1], index=gt_df.index,columns=species_list)
+        gt_bool_df = pd.DataFrame(bool_mat[:, :-1],index=gt_df.index,columns=species_list)
 
         for csv in input.thresholded_csvs:
             lt = splitext(basename(csv))[0]
-            thresholded_df = pd.read_csv(csv, index_col=0)
+            thresholded_df = pd.read_csv(csv,index_col=0)
             thresholded_mat = thresholded_df.to_numpy()
-            gt_df_sub = gt_bool_df.loc[thresholded_df.index, :].to_numpy()
-            pos_kmer_counts = thresholded_mat[gt_df_sub]
-            neg_kmer_counts = thresholded_mat[np.invert(gt_df_sub)]
-            all_kmer_counts = np.concatenate((pos_kmer_counts, neg_kmer_counts))
-            nb_kmers = thresholded_df.max().max()
+            max_k_counts = thresholded_mat.max(axis=1)  # for each read, find which model found max number of kmers
+            nb_kmers = thresholded_mat.max()
+            gt_bool_mat = gt_bool_df.loc[thresholded_df.index, :].to_numpy()
+            for i in range(1,nb_kmers):
+                thd = (thresholded_df > i).to_numpy()
+                tp_df.loc[lt, i] = thd[gt_bool_mat].sum()
+                predPos_df.loc[lt, i] = thd.sum()
+                pos_df = gt_bool_mat.sum()
+
+            thresholded_df.loc[:, 'pred_species'] = thresholded_df.idxmax(axis=1)
+            thresholded_df.loc[:, 'ground_truth'] = gt_df.loc[thresholded_df.index, 'species_short']
+            thresholded_df.loc[:, 'pred_correct'] = thresholded_df.pred_species == thresholded_df.ground_truth
 
-            for i in range(1, nb_kmers):
-                tp_df.loc[lt, i] = np.sum(pos_kmer_counts > i)
-                predPos_df.loc[lt, i] = np.sum(all_kmer_counts > i)
-                pos_df.loc[lt, i] = len(pos_kmer_counts)
+            pos_kmer_counts = max_k_counts[thresholded_df.pred_correct]
+            neg_kmer_counts = max_k_counts[np.invert(thresholded_df.pred_correct)]
+            all_kmer_counts = np.concatenate((pos_kmer_counts, neg_kmer_counts))
 
             plot_df = pd.DataFrame({'kmer_count': all_kmer_counts,
                                     'correct': [True] * len(pos_kmer_counts) + [False] * len(neg_kmer_counts)})
-            fig = plt.Figure(figsize=(10,10), dpi=400)
-            sns.histplot(x='kmer_count', hue='correct', data=plot_df, bins=np.arange(1,plot_df.kmer_count.max() + 1))
+            fig = plt.Figure(figsize=(10, 10),dpi=400)
+            sns.histplot(x='kmer_count',hue='correct',data=plot_df,bins=np.arange(1,plot_df.kmer_count.max() + 1))
             plt.savefig(f'{histograms_dir}hist_logitThreshold{lt}.svg')
             plt.clf(); plt.close(fig)
         precision_df = tp_df / predPos_df
@@ -70,8 +77,24 @@ rule combine:
         precision_df.to_csv(output.precision_csv)
         recall_df.to_csv(output.recall_csv)
         f1_df.to_csv(output.f1_csv)
+        f1_log, f1_k = f1_df[f1_df == f1_df.max().max()].stack().index.tolist()[0]
+        max_coords = (float(recall_df.loc[f1_log, f1_k].mean()), float(precision_df.loc[f1_log, f1_k].mean()))
 
-        # --- precision recall graphs ---
+        re_melted_df = recall_df.reset_index().melt(id_vars='index').rename({'index': 'logit',
+                                                                             'variable': 'k',
+                                                                             'value': 'recall'},axis=1).set_index(['logit', 'k'])
+        pr_melted_df = precision_df.reset_index().melt(id_vars='index').rename({'index': 'logit',
+                                                                                'variable': 'k',
+                                                                                'value': 'precision'},axis=1).set_index(['logit', 'k'])
+        melted_df = pd.concat((re_melted_df, pr_melted_df),axis=1)
+        k_list = list(recall_df.columns)
+        k_half = k_list[len(k_list) // 2]
+        melted_df.reset_index(inplace=True)
+        melted_df.k = melted_df.k.astype(int)
+        sns.lineplot(data=melted_df,x='recall',y='precision',hue='k')
+        plt.plot(max_coords[0],max_coords[1],'ro')
+        plt.text(*max_coords,f'F1_max={round(f1_df.max().max(),3)}')
+        plt.savefig(output.pr_plot, dpi=400)
 
 
 rule apply_thresholds: