From ba83bc1131f70f2e4b1d7cea948ddd538e86f52f Mon Sep 17 00:00:00 2001
From: Carlos de Lannoy <cvdelannoy@gmail.com>
Date: Wed, 2 Feb 2022 21:29:37 +0100
Subject: [PATCH] misc

---
 inference/analyse_abundance_results.py        |  3 +--
 inference/compile_model.py                    | 12 +++++++---
 tools/plot_nn_performance.py                  |  2 ++
 .../quick_benchmark_abundance_estimation.py   | 23 +++++++++++--------
 .../quick_benchmark_abundance_estimation.sf   |  1 -
 5 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/inference/analyse_abundance_results.py b/inference/analyse_abundance_results.py
index 9b5ed72..5c10cfd 100644
--- a/inference/analyse_abundance_results.py
+++ b/inference/analyse_abundance_results.py
@@ -47,8 +47,7 @@ def analyse_abundance_results(freq_tables_dict, analysis_sample_dict, all_specie
     for species in analysis_sample_dict:
         for sa in analysis_sample_dict[species]:
             for rep in analysis_sample_dict[species][sa]:
-                ae_fn = analysis_sample_dict[species][sa][rep]
-                ae_df = pd.read_csv(ae_fn, index_col=0)
+                ae_df = pd.read_csv(analysis_sample_dict[species][sa][rep], index_col=0)
                 kmers = list(ae_df.columns)
                 ae_cumsum_df = ae_df.cumsum()
                 ts_df = pd.DataFrame({'species': species, 'sample_id': sa, 'rep': rep, 'rank_score': None}, index=ae_cumsum_df.index)
diff --git a/inference/compile_model.py b/inference/compile_model.py
index c23b164..6224203 100644
--- a/inference/compile_model.py
+++ b/inference/compile_model.py
@@ -157,13 +157,18 @@ def train_on_the_fly(kmer_list, available_mod_dict, args):
 
 def filter_accuracy(kmer_dict, acc_threshold):
     out_dict = {}
+    discard_list = []
     for kmd in kmer_dict:
         perf_fn = kmer_dict[kmd] + '/performance.pkl'
         if not os.path.isfile(perf_fn): continue
         with open(perf_fn, 'rb') as fh: perf_dict = pickle.load(fh)
-        if perf_dict['val_binary_accuracy'][-1] > acc_threshold:
+        # metric = 2 / ( perf_dict['val_precision'][-1] ** -1 + perf_dict['val_recall'][-1] ** -1)  # F1
+        metric = perf_dict['val_binary_accuracy'][-1]  # plain accuracy
+        if metric > acc_threshold:
             out_dict[kmd] = kmer_dict[kmd]
-    return out_dict
+        else:
+            discard_list.append(kmd)
+    return out_dict, discard_list
 
 
 def main(args):
@@ -205,7 +210,8 @@ def main(args):
         raise ValueError('Either provide --nn-directory or --target-16S')
 
     if args.accuracy_threshold:
-        target_kmer_dict = filter_accuracy(target_kmer_dict, args.accuracy_threshold)
+        target_kmer_dict, discard_list = filter_accuracy(target_kmer_dict, args.accuracy_threshold)
+        print(f'discarded {len(discard_list)} k-mers because accuracy is lower than threshold: {discard_list}')
 
     if args.model_type == 'binary':
         mod = compile_model(target_kmer_dict,
diff --git a/tools/plot_nn_performance.py b/tools/plot_nn_performance.py
index dcf7418..0028dca 100755
--- a/tools/plot_nn_performance.py
+++ b/tools/plot_nn_performance.py
@@ -39,6 +39,8 @@ def main():
     # accuracy
     sns.violinplot(y='accuracy', data=performance_df, color="0.8", ax=ax_acc)
     sns.stripplot(y='accuracy', data=performance_df, ax=ax_acc)
+    for km, tup in performance_df.iterrows():
+        ax_acc.text(x=0.1, y=tup.accuracy, s=km, fontsize=5)
     ll = min(performance_df.precision.min(), performance_df.recall.min()) - 0.01
     ax_pr.set_ylim(ll,1); ax_pr.set_xlim(ll,1)
     ax_pr.set_aspect('equal')
diff --git a/validation/quick_benchmark_abundance_estimation.py b/validation/quick_benchmark_abundance_estimation.py
index 60e0807..4e7f09e 100644
--- a/validation/quick_benchmark_abundance_estimation.py
+++ b/validation/quick_benchmark_abundance_estimation.py
@@ -32,6 +32,8 @@ parser.add_argument('--nb-repeats', type=int, default=5,
                     help='Number of repeated samplings to perform from given test reads [default: 5]')
 parser.add_argument('--training-read-dir', type=str, required=True,
                     help='directory containing resquiggled training reads')
+parser.add_argument('--kmer-mod-dir', type=str, required=False,
+                    help='Directory containing pretrained nns')
 #--- output ---
 parser.add_argument('--out-dir', type=str, required=True)
 #--- params ---
@@ -51,9 +53,8 @@ out_dir = parse_output_path(args.out_dir, clean=True)
 test_read_dir = parse_output_path(out_dir + 'test_reads')
 genomes_dir = parse_output_path(out_dir + 'genomes')
 
-
 species_dict = {}
-test_species_list = []
+sample_id_list = []
 sample_dict = {}
 mp_list = []
 with open(args.index_csv, 'r') as fh:
@@ -63,18 +64,17 @@ with open(args.index_csv, 'r') as fh:
             assert genome_fn == species_dict[species]  # two samples for same species cannot list different genomes
         else:
             species_dict[species] = genome_fn
-
         if test_dir == 'None': continue
         sample_dict[species] = sample_dict.get(species, []) + [sample_id]
 
         # prepare copying test reads
         if test_dir[-1] != '/': test_dir += '/'
-        test_species_list.append(species)
+        sample_id_list.append(sample_id)
         read_list = [test_dir + fn for fn in os.listdir(test_dir)]
         chunk_size = len(read_list) // args.nb_repeats
         for ns in range(args.nb_repeats):
             mp_list.append((read_list[chunk_size * ns:chunk_size * (ns+1)], f'{test_read_dir}{species}/{sample_id}/', args.max_test_reads, ns))
-with mp.Pool(min(len(test_species_list), args.cores)) as pool:
+with mp.Pool(min(len(sample_id_list), args.cores)) as pool:
     pool.map(copy_dir, mp_list)
 
 for species in species_dict:
@@ -84,11 +84,17 @@ for species in species_dict:
         if bgs == species: continue
         os.symlink(species_dict[bgs], f'{cur_bg_dir}{bgs}.fasta')
 
-
 # kmer nns for different species need to be separated for simultaneous running. Pregenerate folders for that.
 nn_dir = parse_output_path(out_dir + 'kmer_nns')
-for species in test_species_list:
+for species in list(sample_dict):
     _ = parse_output_path(nn_dir + species)
+if args.kmer_mod_dir:
+    kmer_mod_dir = args.kmer_mod_dir
+    if kmer_mod_dir[-1] != '/': kmer_mod_dir += '/'
+    kmer_mods_species = os.listdir(kmer_mod_dir)
+    for species in list(sample_dict):
+        if species in kmer_mods_species:
+            shutil.copytree(f'{kmer_mod_dir}{species}/nns', f'{nn_dir}{species}/nns')
 
 with open(args.parameter_file, 'r') as fh: param_dict = yaml.load(fh, Loader=yaml.FullLoader)
 
@@ -107,7 +113,6 @@ sf_txt = Template(template_txt).render(
     nn_dir=nn_dir,
     kmer_size=param_dict['kmer_size'],
     baseless_location=baseless_location,
-    test_species_list=test_species_list,
     sample_dict=sample_dict,
     model_size=args.model_size,
     nb_repeats=args.nb_repeats,
@@ -138,4 +143,4 @@ for ac in abundance_csv_list:
 
 
 analysis_species_list = list(species_dict)
-analyse_abundance_results(freq_tables_dict, analysis_sample_dict, out_dir + 'analysis')
+analyse_abundance_results(freq_tables_dict, analysis_sample_dict, analysis_species_list, out_dir + 'analysis')
diff --git a/validation/quick_benchmark_abundance_estimation.sf b/validation/quick_benchmark_abundance_estimation.sf
index 958c9db..7108f90 100644
--- a/validation/quick_benchmark_abundance_estimation.sf
+++ b/validation/quick_benchmark_abundance_estimation.sf
@@ -18,7 +18,6 @@ nn_dir = '{{ nn_dir }}'  # needs to have species-level subfolders created!
 baseless_location = '{{ baseless_location }}'
 kmer_size = {{ kmer_size }}
 sample_dict = {{ sample_dict }}
-test_species_list = {{ test_species_list }}
 model_size = {{ model_size }}
 nb_repeats = {{ nb_repeats }}
 min_kmer_mod_accuracy= {{ min_kmer_mod_accuracy }}
-- 
GitLab