From ba83bc1131f70f2e4b1d7cea948ddd538e86f52f Mon Sep 17 00:00:00 2001 From: Carlos de Lannoy <cvdelannoy@gmail.com> Date: Wed, 2 Feb 2022 21:29:37 +0100 Subject: [PATCH] misc --- inference/analyse_abundance_results.py | 3 +-- inference/compile_model.py | 12 +++++++--- tools/plot_nn_performance.py | 2 ++ .../quick_benchmark_abundance_estimation.py | 23 +++++++++++-------- .../quick_benchmark_abundance_estimation.sf | 1 - 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/inference/analyse_abundance_results.py b/inference/analyse_abundance_results.py index 9b5ed72..5c10cfd 100644 --- a/inference/analyse_abundance_results.py +++ b/inference/analyse_abundance_results.py @@ -47,8 +47,7 @@ def analyse_abundance_results(freq_tables_dict, analysis_sample_dict, all_specie for species in analysis_sample_dict: for sa in analysis_sample_dict[species]: for rep in analysis_sample_dict[species][sa]: - ae_fn = analysis_sample_dict[species][sa][rep] - ae_df = pd.read_csv(ae_fn, index_col=0) + ae_df = pd.read_csv(analysis_sample_dict[species][sa][rep], index_col=0) kmers = list(ae_df.columns) ae_cumsum_df = ae_df.cumsum() ts_df = pd.DataFrame({'species': species, 'sample_id': sa, 'rep': rep, 'rank_score': None}, index=ae_cumsum_df.index) diff --git a/inference/compile_model.py b/inference/compile_model.py index c23b164..6224203 100644 --- a/inference/compile_model.py +++ b/inference/compile_model.py @@ -157,13 +157,18 @@ def train_on_the_fly(kmer_list, available_mod_dict, args): def filter_accuracy(kmer_dict, acc_threshold): out_dict = {} + discard_list = [] for kmd in kmer_dict: perf_fn = kmer_dict[kmd] + '/performance.pkl' if not os.path.isfile(perf_fn): continue with open(perf_fn, 'rb') as fh: perf_dict = pickle.load(fh) - if perf_dict['val_binary_accuracy'][-1] > acc_threshold: + # metric = 2 / ( perf_dict['val_precision'][-1] ** -1 + perf_dict['val_recall'][-1] ** -1) # F1 + metric = perf_dict['val_binary_accuracy'][-1] # plain accuracy + if metric > acc_threshold: out_dict[kmd] = kmer_dict[kmd] - return out_dict + else: + discard_list.append(kmd) + return out_dict, discard_list def main(args): @@ -205,7 +210,8 @@ def main(args): raise ValueError('Either provide --nn-directory or --target-16S') if args.accuracy_threshold: - target_kmer_dict = filter_accuracy(target_kmer_dict, args.accuracy_threshold) + target_kmer_dict, discard_list = filter_accuracy(target_kmer_dict, args.accuracy_threshold) + print(f'discarded {len(discard_list)} k-mers because accuracy is lower than threshold: {discard_list}') if args.model_type == 'binary': mod = compile_model(target_kmer_dict, diff --git a/tools/plot_nn_performance.py b/tools/plot_nn_performance.py index dcf7418..0028dca 100755 --- a/tools/plot_nn_performance.py +++ b/tools/plot_nn_performance.py @@ -39,6 +39,8 @@ def main(): # accuracy sns.violinplot(y='accuracy', data=performance_df, color="0.8", ax=ax_acc) sns.stripplot(y='accuracy', data=performance_df, ax=ax_acc) + for km, tup in performance_df.iterrows(): + ax_acc.text(x=0.1, y=tup.accuracy, s=km, fontsize=5) ll = min(performance_df.precision.min(), performance_df.recall.min()) - 0.01 ax_pr.set_ylim(ll,1); ax_pr.set_xlim(ll,1) ax_pr.set_aspect('equal') diff --git a/validation/quick_benchmark_abundance_estimation.py b/validation/quick_benchmark_abundance_estimation.py index 60e0807..4e7f09e 100644 --- a/validation/quick_benchmark_abundance_estimation.py +++ b/validation/quick_benchmark_abundance_estimation.py @@ -32,6 +32,8 @@ parser.add_argument('--nb-repeats', type=int, default=5, help='Number of repeated samplings to perform from given test reads [default: 5]') parser.add_argument('--training-read-dir', type=str, required=True, help='directory containing resquiggled training reads') +parser.add_argument('--kmer-mod-dir', type=str, required=False, + help='Directory containing pretrained nns') #--- output --- parser.add_argument('--out-dir', type=str, required=True) #--- params --- @@ -51,9 +53,8 @@ out_dir = parse_output_path(args.out_dir, clean=True) test_read_dir = parse_output_path(out_dir + 'test_reads') genomes_dir = parse_output_path(out_dir + 'genomes') - species_dict = {} -test_species_list = [] +sample_id_list = [] sample_dict = {} mp_list = [] with open(args.index_csv, 'r') as fh: @@ -63,18 +64,17 @@ with open(args.index_csv, 'r') as fh: assert genome_fn == species_dict[species] # two samples for same species cannot list different genomes else: species_dict[species] = genome_fn - if test_dir == 'None': continue sample_dict[species] = sample_dict.get(species, []) + [sample_id] # prepare copying test reads if test_dir[-1] != '/': test_dir += '/' - test_species_list.append(species) + sample_id_list.append(sample_id) read_list = [test_dir + fn for fn in os.listdir(test_dir)] chunk_size = len(read_list) // args.nb_repeats for ns in range(args.nb_repeats): mp_list.append((read_list[chunk_size * ns:chunk_size * (ns+1)], f'{test_read_dir}{species}/{sample_id}/', args.max_test_reads, ns)) -with mp.Pool(min(len(test_species_list), args.cores)) as pool: +with mp.Pool(min(len(sample_id_list), args.cores)) as pool: pool.map(copy_dir, mp_list) for species in species_dict: @@ -84,11 +84,17 @@ for species in species_dict: if bgs == species: continue os.symlink(species_dict[bgs], f'{cur_bg_dir}{bgs}.fasta') - # kmer nns for different species need to be separated for simultaneous running. Pregenerate folders for that. nn_dir = parse_output_path(out_dir + 'kmer_nns') -for species in test_species_list: +for species in list(sample_dict): _ = parse_output_path(nn_dir + species) +if args.kmer_mod_dir: + kmer_mod_dir = args.kmer_mod_dir + if kmer_mod_dir[-1] != '/': kmer_mod_dir += '/' + kmer_mods_species = os.listdir(kmer_mod_dir) + for species in list(sample_dict): + if species in kmer_mods_species: + shutil.copytree(f'{kmer_mod_dir}{species}/nns', f'{nn_dir}{species}/nns') with open(args.parameter_file, 'r') as fh: param_dict = yaml.load(fh, Loader=yaml.FullLoader) @@ -107,7 +113,6 @@ sf_txt = Template(template_txt).render( nn_dir=nn_dir, kmer_size=param_dict['kmer_size'], baseless_location=baseless_location, - test_species_list=test_species_list, sample_dict=sample_dict, model_size=args.model_size, nb_repeats=args.nb_repeats, @@ -138,4 +143,4 @@ for ac in abundance_csv_list: analysis_species_list = list(species_dict) -analyse_abundance_results(freq_tables_dict, analysis_sample_dict, out_dir + 'analysis') +analyse_abundance_results(freq_tables_dict, analysis_sample_dict, analysis_species_list, out_dir + 'analysis') diff --git a/validation/quick_benchmark_abundance_estimation.sf b/validation/quick_benchmark_abundance_estimation.sf index 958c9db..7108f90 100644 --- a/validation/quick_benchmark_abundance_estimation.sf +++ b/validation/quick_benchmark_abundance_estimation.sf @@ -18,7 +18,6 @@ nn_dir = '{{ nn_dir }}' # needs to have species-level subfolders created! baseless_location = '{{ baseless_location }}' kmer_size = {{ kmer_size }} sample_dict = {{ sample_dict }} -test_species_list = {{ test_species_list }} model_size = {{ model_size }} nb_repeats = {{ nb_repeats }} min_kmer_mod_accuracy= {{ min_kmer_mod_accuracy }} -- GitLab