Commit 5911a83d authored by Lannoy, Carlos de

abundance validation: allow multiple samples per species

parent e06ba693
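
For illustration, the index CSV passed via --index-csv now carries a sample id as its second column (species, sample_id, genome_fn, test_dir); rows for the same species must list the same genome, and a sample without test reads lists 'None' as its test directory. All names and paths here are hypothetical:

    e_coli,sample_a,/data/genomes/e_coli.fasta,/data/test_reads/e_coli_a/
    e_coli,sample_b,/data/genomes/e_coli.fasta,/data/test_reads/e_coli_b/
    b_subtilis,sample_a,/data/genomes/b_subtilis.fasta,None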
@@ -5,6 +5,9 @@ import pandas as pd
 import matplotlib.pyplot as plt
 import seaborn as sns
+from itertools import chain
 
 __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
 baseless_location = os.path.realpath(f'{__location__}/..')
 sys.path.append(baseless_location)
@@ -12,13 +15,12 @@ sys.path.append(baseless_location)
 from helper_functions import parse_output_path
 
-def get_freqtable(species, freq_tables_fn_list, kmers):
-    cur_ft_list = [ft for ft in freq_tables_fn_list if ft.endswith('freq_table_' + species + '.csv')]
-    if not len(cur_ft_list): return
-    ft_df = pd.read_csv(cur_ft_list[0], index_col=0).loc[kmers, :]
+def get_freqtable(ft_fn, kmers):
+    ft_df = pd.read_csv(ft_fn, index_col=0).loc[kmers, :]
     ft_df = ft_df[~ft_df.index.duplicated(keep='first')]
     return ft_df
 
 
 def get_rank_score(s1, s2):
     sp1_list = np.array(s1.sort_values().index)
     sp2_list = np.array(s2.sort_values().index)
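
The body of get_rank_score is truncated in this hunk. Purely as a hypothetical sketch, assuming the score averages absolute rank displacements between the two k-mer series (the real implementation may differ):

    import numpy as np
    import pandas as pd

    def rank_difference_score_sketch(s1: pd.Series, s2: pd.Series) -> float:
        # rank k-mers by value in each series, as in the two lines above
        sp1_list = np.array(s1.sort_values().index)
        sp2_list = np.array(s2.sort_values().index)
        # hypothetical continuation: mean absolute difference in rank position
        pos2 = {kmer: i for i, kmer in enumerate(sp2_list)}
        diffs = [abs(i - pos2[k]) for i, k in enumerate(sp1_list) if k in pos2]
        return float(np.mean(diffs))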
@@ -26,46 +28,49 @@ def get_rank_score(s1, s2):
 
 
 def format_species_names(x):
-    genus, species = x.split('_')
+    element_list = x.split('_')
+    genus, species = element_list[0], element_list[1]
     if len(genus) == 1: genus = genus.upper() + '.'
     out_name = f'$\it{{{genus} {species}}}$'
     return out_name
 
 
-def analyse_abundance_results(abundance_csv_list, freq_tables_list, species_list, out_dir):
+def analyse_abundance_results(freq_tables_dict, analysis_sample_dict, out_dir):
     out_dir = parse_output_path(out_dir)
     timeseries_list = []
-    test_species_list = []
-    rankscore_df = pd.DataFrame(columns=['species', 'sample_nb'] + species_list).set_index(['species', 'sample_nb'])
-    for ae_fn in abundance_csv_list:
-        species, sample_txt = re.search('(?<=inference_)[^/]+', ae_fn).group(0).rsplit('_', 1)
-        if species not in test_species_list:
-            test_species_list.append(species)
-        sample_nb = int(re.search('(?<=sample)[0-9]+', sample_txt).group(0))
-        ae_df = pd.read_csv(ae_fn, index_col=0)
-        kmers = list(ae_df.columns)
-        ae_cumsum_df = ae_df.cumsum()
-        ts_df = pd.DataFrame({'species': species, 'sample_nb': sample_nb, 'rank_score': None}, index=ae_cumsum_df.index)
-        ft_df = get_freqtable(species, freq_tables_list, kmers)
-        for bg_sp in species_list:
-            col_idx = 'target' if bg_sp == species else bg_sp
-            ts_df.loc[:, bg_sp] = ae_cumsum_df.apply(lambda x: get_rank_score(x, ft_df.loc[:, col_idx]), axis=1)
-            rankscore_df.loc[(species, sample_nb), bg_sp] = ts_df.iloc[-1].loc[bg_sp]
-        timeseries_list.append(ts_df)
+    # test_species_list = list(analysis_sample_dict)
+    all_species_list = list(freq_tables_dict)
+    sample_list = list(chain.from_iterable([list(analysis_sample_dict[sp]) for sp in analysis_sample_dict]))
+    rankscore_df = pd.DataFrame(columns=['sample_id', 'rep_nb', 'species'] + all_species_list).set_index(['sample_id', 'rep_nb'])
+    for species in analysis_sample_dict:
+        for sa in analysis_sample_dict[species]:
+            for rep in analysis_sample_dict[species][sa]:
+                ae_fn = analysis_sample_dict[species][sa][rep]
+                ae_df = pd.read_csv(ae_fn, index_col=0)
+                kmers = list(ae_df.columns)
+                ae_cumsum_df = ae_df.cumsum()
+                ts_df = pd.DataFrame({'species': species, 'sample_id': sa, 'rep': rep, 'rank_score': None}, index=ae_cumsum_df.index)
+                ft_df = get_freqtable(freq_tables_dict[species], kmers)
+                for bg_sp in freq_tables_dict:
+                    col_idx = 'target' if bg_sp == species else bg_sp
+                    ts_df.loc[:, bg_sp] = ae_cumsum_df.apply(lambda x: get_rank_score(x, ft_df.loc[:, col_idx]), axis=1)
+                    rankscore_df.loc[(sa, rep), bg_sp] = ts_df.iloc[-1].loc[bg_sp]
+                timeseries_list.append(ts_df)
 
     # --- heatmap ---
-    heat_df = pd.DataFrame(columns=species_list, index=species_list[::-1], dtype="float")
+    heat_df = pd.DataFrame(index=sample_list, columns=all_species_list, dtype="float")
     heat_df_sd = heat_df.copy()
-    for sp_true in test_species_list:
-        for sp_pred in species_list:
-            heat_df.loc[sp_true, sp_pred] = rankscore_df.loc[(sp_true, ), sp_pred].mean()
-            heat_df_sd.loc[sp_true, sp_pred] = rankscore_df.loc[(sp_true, ), sp_pred].std()
+    for sp_true in analysis_sample_dict:
+        for sid in analysis_sample_dict[sp_true]:
+            for sp_pred in all_species_list:
+                heat_df.loc[sid, sp_pred] = rankscore_df.loc[(sid, ), sp_pred].mean()
+                heat_df_sd.loc[sid, sp_pred] = rankscore_df.loc[(sid, ), sp_pred].std()
     # heat_df = heat_df.div(heat_df.sum(axis=1), axis=0)  # row-normalization
     annot_mat = (heat_df.round(2).astype(str) + '\n±' + heat_df_sd.round(2).astype(str))
-    heat_df.index = [format_species_names(x) for x in heat_df.index]
+    # heat_df.index = [format_species_names(x) for x in heat_df.index]
     heat_df.columns = [format_species_names(x) for x in heat_df.columns]
     heat_df.rename_axis('Truth', axis='rows', inplace=True)
     heat_df.rename_axis('Predicted', axis='columns', inplace=True)
@@ -82,15 +87,13 @@ def analyse_abundance_results(abundance_csv_list, freq_tables_list, species_list, out_dir):
     timeseries_df.rename({'index': 'nb_reads'}, inplace=True, axis=1)
 
     fig, axes = plt.subplots(1, 3, figsize=(18, 5))
-    mid_plot_idx = len(test_species_list) // 2
-    for si, sp in enumerate(test_species_list):
-        tsmelt_df = timeseries_df.query(f'species == "{sp}"').melt(id_vars='nb_reads', value_vars=species_list).rename({'value': 'RDS', 'variable': 'species', 'nb_reads': '# reads'}, axis=1)
+    mid_plot_idx = len(sample_list) // 2
+    for si, sa in enumerate(sample_list):
+        tsmelt_df = timeseries_df.query(f'sample_id == "{sa}"').melt(id_vars='nb_reads', value_vars=all_species_list).rename({'value': 'RDS', 'variable': 'species', 'nb_reads': '# reads'}, axis=1)
         tsmelt_df.species = tsmelt_df.species.apply(lambda x: format_species_names(x))
-        print(si)
-        print(sp)
         sns.lineplot(x='# reads', y='RDS', hue='species', data=tsmelt_df, ax=axes[si])
-        axes[si].set_title(format_species_names(sp))
-        if si != len(test_species_list) - 1: axes[si].legend().remove()
+        axes[si].set_title(sa)
+        if si != len(sample_list) - 1: axes[si].legend().remove()
         if si != 0: axes[si].set_ylabel('')
         if si != mid_plot_idx: axes[si].set_xlabel('')
     plt.savefig(f'{out_dir}timeseries_rankdiff.svg', dpi=400)
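
After this change, analyse_abundance_results no longer parses species and sample numbers out of file names; the caller passes two mappings instead. A minimal sketch of the expected shapes, with hypothetical species, sample ids and paths:

    freq_tables_dict = {'e_coli': 'kmer_lists/freq_table_e_coli.csv'}  # species -> freq table csv
    analysis_sample_dict = {                                           # species -> sample_id -> rep -> csv
        'e_coli': {
            'sample_a': {0: 'inference/inference_e_coli/sample_a/0/abundance_estimation.csv'},
        },
    }
    analyse_abundance_results(freq_tables_dict, analysis_sample_dict, 'analysis/')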

@@ -1,4 +1,4 @@
-import argparse, sys, os, shutil, yaml
+import argparse, sys, os, shutil, yaml, re
 from snakemake import snakemake
 from glob import glob
 from jinja2 import Template
@@ -14,7 +14,7 @@ from inference.analyse_abundance_results import analyse_abundance_results
 
 def copy_dir(tup):
     read_list, out_species_dir, max_reads, rep = tup
-    out_sample_dir = out_species_dir + f'sample{rep}/'
+    out_sample_dir = out_species_dir + f'{rep}/'
     shuffle(read_list)
     read_list = read_list[:max_reads]
     os.makedirs(out_sample_dir, exist_ok=True)
@@ -28,7 +28,7 @@ parser.add_argument('--index-csv', type=str, required=True,
                     help='Index file containing paths to genome and accompanying test reads directory.')
 parser.add_argument('--max-test-reads', type=int, default=2000,
                     help='Maximum number of reads to sample from given test reads [default: 2000]')
-parser.add_argument('--nb-samplings', type=int, default=5,
+parser.add_argument('--nb-repeats', type=int, default=5,
                     help='Number of repeated samplings to perform from given test reads [default: 5]')
 parser.add_argument('--training-read-dir', type=str, required=True,
                     help='Directory containing resquiggled training reads')
@@ -39,6 +39,8 @@ parser.add_argument('--parameter-file', type=str, default=f'{baseless_location}/
 parser.add_argument('--model-size', type=int, default=25,
                     help='Number of k-mers of which abundance is estimated [default: 25]')
 parser.add_argument('--cores', type=int, default=4)
+parser.add_argument('--min-kmer-mod-accuracy', type=float, default=0.85,
+                    help='Filter k-mer models on validation accuracy. May cause model size to shrink! [default: 0.85]')
 parser.add_argument('--nb-gpus', type=int, default=0,
                     help='Number of GPUs that can be simultaneously engaged [default: 0]')
 parser.add_argument('--dryrun', action='store_true')
@@ -52,20 +54,26 @@ genomes_dir = parse_output_path(out_dir + 'genomes')
 
 species_dict = {}
 test_species_list = []
+sample_dict = {}
 mp_list = []
 with open(args.index_csv, 'r') as fh:
     for line in fh.readlines():
-        species, genome_fn, test_dir = line.strip().split(',')
-        species_dict[species] = genome_fn
+        species, sample_id, genome_fn, test_dir = line.strip().split(',')
+        if species in species_dict:
+            assert genome_fn == species_dict[species]  # two samples of the same species cannot list different genomes
+        else:
+            species_dict[species] = genome_fn
         if test_dir == 'None': continue
+        sample_dict[species] = sample_dict.get(species, []) + [sample_id]
 
         # prepare copying test reads
        if test_dir[-1] != '/': test_dir += '/'
         test_species_list.append(species)
         read_list = [test_dir + fn for fn in os.listdir(test_dir)]
-        chunk_size = len(read_list) // args.nb_samplings
-        for ns in range(args.nb_samplings):
-            mp_list.append((read_list[chunk_size * ns:chunk_size * (ns + 1)], f'{test_read_dir}{species}/', args.max_test_reads, ns))
+        chunk_size = len(read_list) // args.nb_repeats
+        for ns in range(args.nb_repeats):
+            mp_list.append((read_list[chunk_size * ns:chunk_size * (ns + 1)], f'{test_read_dir}{species}/{sample_id}/', args.max_test_reads, ns))
 
 with mp.Pool(min(len(test_species_list), args.cores)) as pool:
     pool.map(copy_dir, mp_list)
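
With the added sample level, test reads are now laid out per species, per sample, per repeat; copy_dir shuffles each chunk and caps it at --max-test-reads. A hypothetical resulting tree:

    {test_read_dir}
        e_coli/
            sample_a/
                0/    # repeat 0: reads from the first chunk
                1/    # repeat 1: reads from the second chunk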
@@ -100,8 +108,10 @@ sf_txt = Template(template_txt).render(
     kmer_size=param_dict['kmer_size'],
     baseless_location=baseless_location,
     test_species_list=test_species_list,
+    sample_dict=sample_dict,
     model_size=args.model_size,
-    nb_samplings=args.nb_samplings
+    nb_repeats=args.nb_repeats,
+    min_kmer_mod_accuracy=args.min_kmer_mod_accuracy
 )
 
 sf_fn = f'{out_dir}quick_benchmark_abundance_estimation.sf'
@@ -110,10 +120,22 @@ resources = {}
 if args.nb_gpus > 0:
     resources['gpu'] = args.nb_gpus
 snakemake(sf_fn, cores=args.cores, keepgoing=True, dryrun=args.dryrun, resources=resources)
 if args.dryrun: exit(0)
 
-abundance_csv_list = glob(f'{out_dir}inference/*/abundance_estimation.csv')
+abundance_csv_list = glob(f'{out_dir}inference/*/*/*/abundance_estimation.csv')
 abundance_csv_list.sort()
 freq_tables_list = glob(f'{out_dir}kmer_lists/freq_table*')
 freq_tables_list.sort()
+freq_tables_dict = {re.search(r'(?<=freq_table_).+(?=\.csv)', ft).group(0): ft for ft in freq_tables_list}
+
+# map each abundance csv to its species -> sample -> repeat combination
+analysis_sample_dict = {}
+for ac in abundance_csv_list:
+    _, sp, sa, rep, _ = ac.rsplit('/', 4)
+    sp = re.search('(?<=inference_).+', sp).group(0)
+    if sp not in analysis_sample_dict: analysis_sample_dict[sp] = {}
+    if sa not in analysis_sample_dict[sp]: analysis_sample_dict[sp][sa] = {}  # check the nested dict, not the top level
+    analysis_sample_dict[sp][sa][int(rep)] = ac
 
-analysis_species_list = list(species_dict)
-analyse_abundance_results(abundance_csv_list, freq_tables_list, analysis_species_list, out_dir + 'analysis')
+analyse_abundance_results(freq_tables_dict, analysis_sample_dict, out_dir + 'analysis')
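
For illustration, the rsplit above decomposes one (hypothetical) output path as:

    '/out/inference/inference_e_coli/sample_a/0/abundance_estimation.csv'.rsplit('/', 4)
    # -> ['/out/inference', 'inference_e_coli', 'sample_a', '0', 'abundance_estimation.csv']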

@@ -17,32 +17,34 @@ nn_dir = '{{ nn_dir }}' # needs to have species-level subfolders created!
 
 # --- params ---
 baseless_location = '{{ baseless_location }}'
 kmer_size = {{ kmer_size }}
+sample_dict = {{ sample_dict }}
 test_species_list = {{ test_species_list }}
 model_size = {{ model_size }}
-nb_samplings = {{ nb_samplings }}
+nb_repeats = {{ nb_repeats }}
+min_kmer_mod_accuracy = {{ min_kmer_mod_accuracy }}
+
+target_list = []
+for sp in sample_dict:
+    for sa in sample_dict[sp]:
+        target_list.extend([f'{inference_dir}inference_{sp}/{sa}/{rep}/abundance_estimation.csv' for rep in range(nb_repeats)])
 
 rule target:
     input:
-        abundance_files=expand('{{ inference_dir }}inference_{species}_sample{sample}/abundance_estimation.csv',
-                               species=test_species_list, sample=range(nb_samplings))
-    # output:
-    #     corr_matrix_svg=''
-    # run:
-    #
+        abundance_files=target_list
 
 rule run_inference:
     input:
-        fast5_in='{{ test_read_dir }}{species}/sample{sample}/',
+        fast5_in='{{ test_read_dir }}{species}/{sample}/{rep}',
         model='{{ mod_dir }}compiled_model_{species}.h5'
     threads: 3
     resources:
         gpu=1
-    benchmark: '{{ benchmark_dir }}benchmark_{species}_sample{sample}.tsv'
+    benchmark: '{{ benchmark_dir }}benchmark_{species}_sample{sample}_rep{rep}.tsv'
     params:
-        out_dir='{{ inference_dir }}inference_{species}_sample{sample}/'
+        out_dir='{{ inference_dir }}inference_{species}/{sample}/{rep}'
     output:
-        out_csv='{{ inference_dir }}inference_{species}_sample{sample}/abundance_estimation.csv'
+        out_csv='{{ inference_dir }}inference_{species}/{sample}/{rep}/abundance_estimation.csv'
     shell:
         """
        python {baseless_location} run_inference_evaluation \
@@ -70,7 +72,7 @@ rule compile_model:
         --out-dir {nn_dir}{wildcards.species}/ \
         --cores {threads} \
         --model-type abundance \
-        --accuracy-threshold 0.85 \
+        --accuracy-threshold {min_kmer_mod_accuracy} \
         --parameter-file {parameter_file} &> {logs_dir}compile_model_{wildcards.species}.log
         """