Skip to content
Snippets Groups Projects
Commit 83e37100 authored by Lannoy, Carlos de
Browse files

overhaul threshold analysis, add precision recall plot

parent fcdb00a7
No related branches found
No related tags found
No related merge requests found
...@@ -27,7 +27,8 @@ rule combine: ...@@ -27,7 +27,8 @@ rule combine:
output: output:
f1_csv=f'{out_dir}f1.csv', f1_csv=f'{out_dir}f1.csv',
precision_csv=f'{out_dir}precision.csv', precision_csv=f'{out_dir}precision.csv',
recall_csv=f'{out_dir}recall.csv' recall_csv=f'{out_dir}recall.csv',
pr_plot=f'{out_dir}pr_plot.svg'
run: run:
gt_df = pd.read_csv(input.ground_truth_csv).set_index('file_name') gt_df = pd.read_csv(input.ground_truth_csv).set_index('file_name')
...@@ -41,27 +42,33 @@ rule combine: ...@@ -41,27 +42,33 @@ rule combine:
h_idx, v_idx = np.arange(len(gt_df)), gt_df.species_numeric h_idx, v_idx = np.arange(len(gt_df)), gt_df.species_numeric
bool_mat = np.zeros((len(gt_df), len(species_list) + 1),dtype=bool) bool_mat = np.zeros((len(gt_df), len(species_list) + 1),dtype=bool)
bool_mat[h_idx, v_idx] = True bool_mat[h_idx, v_idx] = True
gt_bool_df = pd.DataFrame(bool_mat[:, :-1], index=gt_df.index,columns=species_list) gt_bool_df = pd.DataFrame(bool_mat[:, :-1],index=gt_df.index,columns=species_list)
for csv in input.thresholded_csvs: for csv in input.thresholded_csvs:
lt = splitext(basename(csv))[0] lt = splitext(basename(csv))[0]
thresholded_df = pd.read_csv(csv, index_col=0) thresholded_df = pd.read_csv(csv,index_col=0)
thresholded_mat = thresholded_df.to_numpy() thresholded_mat = thresholded_df.to_numpy()
gt_df_sub = gt_bool_df.loc[thresholded_df.index, :].to_numpy() max_k_counts = thresholded_mat.max(axis=1) # for each read, find which model found max number of kmers
pos_kmer_counts = thresholded_mat[gt_df_sub] nb_kmers = thresholded_mat.max()
neg_kmer_counts = thresholded_mat[np.invert(gt_df_sub)] gt_bool_mat = gt_bool_df.loc[thresholded_df.index, :].to_numpy()
all_kmer_counts = np.concatenate((pos_kmer_counts, neg_kmer_counts)) for i in range(1,nb_kmers):
nb_kmers = thresholded_df.max().max() thd = (thresholded_df > i).to_numpy()
tp_df.loc[lt, i] = thd[gt_bool_mat].sum()
predPos_df.loc[lt, i] = thd.sum()
pos_df = gt_bool_mat.sum()
thresholded_df.loc[:, 'pred_species'] = thresholded_df.idxmax(axis=1)
thresholded_df.loc[:, 'ground_truth'] = gt_df.loc[thresholded_df.index, 'species_short']
thresholded_df.loc[:, 'pred_correct'] = thresholded_df.pred_species == thresholded_df.ground_truth
for i in range(1, nb_kmers): pos_kmer_counts = max_k_counts[thresholded_df.pred_correct]
tp_df.loc[lt, i] = np.sum(pos_kmer_counts > i) neg_kmer_counts = max_k_counts[np.invert(thresholded_df.pred_correct)]
predPos_df.loc[lt, i] = np.sum(all_kmer_counts > i) all_kmer_counts = np.concatenate((pos_kmer_counts, neg_kmer_counts))
pos_df.loc[lt, i] = len(pos_kmer_counts)
plot_df = pd.DataFrame({'kmer_count': all_kmer_counts, plot_df = pd.DataFrame({'kmer_count': all_kmer_counts,
'correct': [True] * len(pos_kmer_counts) + [False] * len(neg_kmer_counts)}) 'correct': [True] * len(pos_kmer_counts) + [False] * len(neg_kmer_counts)})
fig = plt.Figure(figsize=(10,10), dpi=400) fig = plt.Figure(figsize=(10, 10),dpi=400)
sns.histplot(x='kmer_count', hue='correct', data=plot_df, bins=np.arange(1,plot_df.kmer_count.max() + 1)) sns.histplot(x='kmer_count',hue='correct',data=plot_df,bins=np.arange(1,plot_df.kmer_count.max() + 1))
plt.savefig(f'{histograms_dir}hist_logitThreshold{lt}.svg') plt.savefig(f'{histograms_dir}hist_logitThreshold{lt}.svg')
plt.clf(); plt.close(fig) plt.clf(); plt.close(fig)
precision_df = tp_df / predPos_df precision_df = tp_df / predPos_df
...@@ -70,8 +77,24 @@ rule combine: ...@@ -70,8 +77,24 @@ rule combine:
precision_df.to_csv(output.precision_csv) precision_df.to_csv(output.precision_csv)
recall_df.to_csv(output.recall_csv) recall_df.to_csv(output.recall_csv)
f1_df.to_csv(output.f1_csv) f1_df.to_csv(output.f1_csv)
f1_log, f1_k = f1_df[f1_df == f1_df.max().max()].stack().index.tolist()[0]
max_coords = (float(recall_df.loc[f1_log, f1_k].mean()), float(precision_df.loc[f1_log, f1_k].mean()))
# --- precision recall graphs --- re_melted_df = recall_df.reset_index().melt(id_vars='index').rename({'index': 'logit',
'variable': 'k',
'value': 'recall'},axis=1).set_index(['logit', 'k'])
pr_melted_df = precision_df.reset_index().melt(id_vars='index').rename({'index': 'logit',
'variable': 'k',
'value': 'precision'},axis=1).set_index(['logit', 'k'])
melted_df = pd.concat((re_melted_df, pr_melted_df),axis=1)
k_list = list(recall_df.columns)
k_half = k_list[len(k_list) // 2]
melted_df.reset_index(inplace=True)
melted_df.k = melted_df.k.astype(int)
sns.lineplot(data=melted_df,x='recall',y='precision',hue='k')
plt.plot(max_coords[0],max_coords[1],'ro')
plt.text(*max_coords,f'F1_max={round(f1_df.max().max(),3)}')
plt.savefig(output.pr_plot, dpi=400)
rule apply_thresholds: rule apply_thresholds:
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment