Commit 3f6cd45b authored by Lannoy, Carlos de

add abundance estimation mode

parent a1aa952f
@@ -232,6 +232,14 @@ hyperopt_parallel_jobs = ('--hyperopt-parallel-jobs', {
'help': 'Number of hyperopt optimizations to run in parallel [default: 4]'
})
model_type = ('--model-type', {
'type': str,
'default': 'binary',
'choices': ['binary', 'abundance'],
'help': 'Specify type of model to compile: [binary] to predict presence of a fraction of k-mers, [abundance] '
'for k-mer abundance estimation [default: binary]'
})
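For reference, a minimal self-contained sketch of how these (flag, kwargs) option tuples are registered on an argparse parser; the parser object and the argv passed here are hypothetical, not part of the diff:

import argparse

model_type = ('--model-type', {
    'type': str,
    'default': 'binary',
    'choices': ['binary', 'abundance'],
    'help': 'Type of model to compile [default: binary]',
})

parser = argparse.ArgumentParser()
parser.add_argument(model_type[0], **model_type[1])  # same unpacking pattern as in the parser getters below
args = parser.parse_args(['--model-type', 'abundance'])
assert args.model_type == 'abundance'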
# --- parser getters ---
def get_run_production_pipeline_parser():
@@ -297,7 +305,7 @@ def get_compile_model_parser():
'help': 'Minimum number of kmer models to include in the multi-network'
})
for arg in (kmer_list, target_16S, nn_directory, out_model, min_nb_models, parameter_file):
for arg in (kmer_list, target_16S, nn_directory, out_model, min_nb_models, parameter_file, model_type):
parser.add_argument(arg[0], **arg[1])
parser.add_argument('--train-required', action='store_true',
help='Train new models as required [default: use only available models]')
......
@@ -39,7 +39,7 @@ class ReadTable(object):
manager_process.start()
return manager_process
def get_read_to_predict(self, batch_size=32):
def get_read_to_predict(self, batch_size=4):
"""Get read and a kmer for which to scan the read file
:return: Tuple of: (path to read fast5 file,
......
import os, re, yaml, subprocess, warnings
import os, re, yaml, subprocess, warnings, h5py
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -34,7 +34,7 @@ def fa2kmers(fa_fn, kmer_size):
with open(tdo + '/target.fasta', 'w') as fh:
fh.write(targets_dict[td])
subprocess.run(
f'jellyfish count -m {kmer_size} -s {4 ** kmer_size} -C -o {tdo}/mer_counts.jf {fa_fn}',
f'jellyfish count -m {kmer_size} -s {4 ** kmer_size} -C -o {tdo}/mer_counts.jf {tdo}/target.fasta',
shell=True)
kmer_dump = subprocess.run(f'jellyfish dump -c {tdo}/mer_counts.jf', shell=True, capture_output=True)
kmer_list = [km.split(' ')[0] for km in kmer_dump.stdout.decode('utf-8').split('\n')]
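As a rough illustration of the jellyfish step (the tool itself is not invoked here), jellyfish dump -c prints one "KMER COUNT" pair per line; a minimal sketch of parsing that text into the shapes used above, with dump_text standing in for kmer_dump.stdout:

dump_text = 'AAAAACGT 12\nACGTACGT 3\n'  # made-up stand-in for kmer_dump.stdout.decode('utf-8')
lines = [ln for ln in dump_text.split('\n') if ln]
kmer_list = [ln.split(' ')[0] for ln in lines]                           # ['AAAAACGT', 'ACGTACGT']
kmer_counts = {ln.split(' ')[0]: int(ln.split(' ')[1]) for ln in lines}  # {'AAAAACGT': 12, 'ACGTACGT': 3}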
@@ -113,43 +113,69 @@ def compile_model(kmer_dict, filter_width, filter_stride, threshold, k_frac):
return meta_mod
def compile_model_abundance(kmer_dict, filter_width, filter_stride, threshold):
input = tf.keras.Input(shape=(None, 1), ragged=True)
input_strided = tf.signal.frame(input.to_tensor(default_value=np.nan), frame_length=filter_width, frame_step=filter_stride, axis=1)
input_strided = tf.keras.layers.Masking(mask_value=np.nan)(input_strided)
ht_list = []
for km in kmer_dict:
mod = tf.keras.models.load_model(f'{kmer_dict[km]}/nn.h5',compile=False)
mod._name = km
h = tf.keras.layers.TimeDistributed(mod)(input_strided)
h = K.cast_to_floatx(K.greater(h, threshold))
h = K.sum(h, axis=0)
h = K.sum(h, axis=0)
ht_list.append(h)
output = tf.keras.layers.concatenate(ht_list)
meta_mod = tf.keras.Model(inputs=input, outputs=output)
meta_mod.compile()
return meta_mod
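A hedged NumPy sketch of what the abundance head computes per k-mer sub-model: every strided window whose detection score clears the threshold contributes one count, summed over reads and windows. Shapes and values below are invented for illustration only.

import numpy as np

def count_kmer_hits(scores, threshold):
    # scores: (batch, n_windows, 1) detection scores from one k-mer model
    hits = (scores > threshold).astype(float)   # binarise per-window detections
    return hits.sum()                           # total hits over reads and windows

scores = np.array([[[0.1], [0.9], [0.7]],
                   [[0.2], [0.8], [0.3]]])      # 2 reads x 3 windows
print(count_kmer_hits(scores, threshold=0.5))   # 3.0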
def train_on_the_fly(kmer_list, available_mod_dict, args):
kmers_no_models = [km for km in kmer_list if km not in available_mod_dict]
if len(kmers_no_models): # train additional models, if required
print(f'No models found for {len(kmers_no_models)} kmers, training on the fly!')
args.kmer_list = kmers_no_models
rpp(args)
# add newly generated models to available model list
for km in kmers_no_models:
if os.path.exists(f'{args.out_dir}nns/{km}/nn.h5'): # Check to filter out failed models
available_mod_dict[km] = f'{args.out_dir}nns/{km}'
else:
warnings.warn(
f'model generation failed for {km}, see {args.out_dir}logs. Continuing compilation without it.')
kmer_list.remove(km)
out_dict = {km: available_mod_dict.get(km, None) for km in kmer_list if km in available_mod_dict}
return out_dict
def main(args):
# List k-mers for which models are available
if args.nn_directory:
nn_directory = args.nn_directory
available_mod_dict = {pth.name: str(pth) for pth in Path(args.nn_directory).iterdir() if pth.is_dir()}
elif args.target_16S:
nn_directory = f'{__location__}/../16S_db/'
available_mod_dict = {pth.name: str(pth) for pth in Path(f'{__location__}/../16S_db/').iterdir() if pth.is_dir()}
else:
raise ValueError('Either provide --nn-directory or --target-16S')
available_mod_dict = {pth.name: str(pth) for pth in Path(nn_directory).iterdir() if pth.is_dir()} # List models for which k-mer is available
available_mod_dict = {}
# Parse target k-mers
if args.kmer_list:
if args.kmer_list: # parse a given list of kmers
with open(args.kmer_list, 'r') as fh:
requested_kmer_list = [km.strip() for km in fh.readlines()]
requested_kmer_list = [km for km in requested_kmer_list if len(km)]
target_kmer_dict = {km: available_mod_dict.get(km, None) for km in requested_kmer_list if km in available_mod_dict}
elif args.target_16S:
if args.train_required:
target_kmer_dict = train_on_the_fly(requested_kmer_list, available_mod_dict, args)
else:
target_kmer_dict = {km: available_mod_dict.get(km, None) for km in requested_kmer_list if km in available_mod_dict}
elif args.target_16S: # estimate salient set of kmers from given 16S sequence
kmer_size = 8
requested_kmer_dict = fa2kmers(args.target_16S, kmer_size) # collect recognizable k-mers per sequence in the target fasta
if args.train_required:
target_kmer_list = get_kmer_candidates_16S(requested_kmer_dict, args.min_nb_models, 0.0001)
kmers_no_models = [km for km in target_kmer_list if km not in available_mod_dict]
if len(kmers_no_models): # train additional models, if required
print(f'No models found for {len(kmers_no_models)} kmers, training on the fly!')
args.kmer_list = kmers_no_models
rpp(args)
# add newly generated models to available model list
for km in kmers_no_models:
if os.path.exists(f'{args.out_dir}nns/{km}/nn.h5'): # Check to filter out failed models
available_mod_dict[km] = f'{args.out_dir}nns/{km}'
else:
warnings.warn(f'model generation failed for {km}, see {args.out_dir}logs. Continuing compilation without it.')
target_kmer_list.remove(km)
target_kmer_dict = train_on_the_fly(target_kmer_list, available_mod_dict, args)
else: # filter out k-mers for which no stored model exists
target_kmer_list = get_kmer_candidates_16S(requested_kmer_dict, args.min_nb_models, 0.0001, filter_list=list(available_mod_dict))
target_kmer_dict = {km: available_mod_dict.get(km, None) for km in target_kmer_list if km in available_mod_dict}
target_kmer_dict = {km: available_mod_dict.get(km, None) for km in target_kmer_list if km in available_mod_dict}
if not len(target_kmer_dict):
raise ValueError('Sequences do not contain any of the available models!')
else:
@@ -157,8 +183,17 @@ def main(args):
with open(args.parameter_file, 'r') as fh:
param_dict = yaml.load(fh, yaml.FullLoader)
mod = compile_model(target_kmer_dict,
param_dict['filter_width'], param_dict['filter_stride'],
param_dict['threshold'], param_dict['k_frac'])
if args.model_type == 'binary':
mod = compile_model(target_kmer_dict,
param_dict['filter_width'], param_dict['filter_stride'],
param_dict['threshold'], param_dict['k_frac'])
elif args.model_type == 'abundance':
mod = compile_model_abundance(target_kmer_dict,
param_dict['filter_width'], param_dict['filter_stride'],
param_dict['threshold'])
else:
raise ValueError(f'--model-type {args.model_type} not implemented')
mod.save(args.out_model)
with h5py.File(args.out_model, 'r+') as fh:
fh.attrs['model_type'] = args.model_type
fh.attrs['kmer_list'] = ','.join(list(target_kmer_dict))
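The compiled model's type and k-mer list travel with the .h5 file as HDF5 attributes so that run_inference can recover them later; a minimal round-trip sketch (the file name is hypothetical, and with older h5py versions string attributes may come back as bytes):

import h5py

with h5py.File('compiled_model.h5', 'a') as fh:   # assumes the Keras model was already saved to this path
    fh.attrs['model_type'] = 'abundance'
    fh.attrs['kmer_list'] = ','.join(['AAAAACGT', 'ACGTACGT'])

with h5py.File('compiled_model.h5', 'r') as fh:
    model_type = fh.attrs['model_type']
    kmer_list = fh.attrs['kmer_list'].split(',')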
import argparse, subprocess, re
from os.path import basename, splitext
from tempfile import TemporaryDirectory
import pandas as pd
COMPLEMENT_DICT = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}
def reverse_complement(km):
return ''.join([COMPLEMENT_DICT[k] for k in km][::-1])
def fa2df(fa_fn, kmer_size, column_name, cores):
"""
take multifasta file, return df of counts of k-mers
"""
with TemporaryDirectory() as tdo:
subprocess.run(
f'jellyfish count -t {cores} -m {kmer_size} -s {4 ** kmer_size} -C -o {tdo}/mer_counts.jf {fa_fn}',
shell=True) # -C: count canonical k-mers (a k-mer and its reverse complement counted as one)
kmer_dump = subprocess.run(f'jellyfish dump -c {tdo}/mer_counts.jf', shell=True, capture_output=True)
kmer_freq_tuples = [km.split(' ') for km in kmer_dump.stdout.decode('utf-8').split('\n')]
kmer_freq_tuples = [km for km in kmer_freq_tuples if len(km) == 2]
kmer_freq_dict = {km[0]: int(km[1]) for km in kmer_freq_tuples}
return pd.DataFrame.from_dict(kmer_freq_dict, orient='index', columns=[column_name])
parser = argparse.ArgumentParser(description='Get set of k-mers for which abundance is maximally different.')
# --- inputs ---
parser.add_argument('--target-fasta', required=True)
parser.add_argument('--background-fastas', nargs='+', required=True)
# --- outputs ---
parser.add_argument('--out-kmer-txt', required=True)
parser.add_argument('--out-freq-table', required=False)
# --- params ---
parser.add_argument('--model-size', type=int, default=25)
parser.add_argument('--cores', type=int, default=4)
args = parser.parse_args()
nb_bg_fastas = len(args.background_fastas)
kmers_per_bg = max(1, args.model_size // (nb_bg_fastas * 2)) # factor 2 because we need kmers from low and high end of ratio spectrum
bg_fn_list = [splitext(basename(fn))[0] for fn in args.background_fastas]
# --- parse fastas into dfs ---
target_df = fa2df(args.target_fasta, 8, 'target', args.cores)
bg_df_list = [fa2df(fn, 8, bn_fn, args.cores) for fn, bn_fn in zip(args.background_fastas, bg_fn_list)]
# --- select kmers based on relative abundances ---
kmer_list = []
count_df = pd.concat([target_df] + bg_df_list, axis=1)
for bg_fn in bg_fn_list:
count_df.loc[:, f'ratio_{bg_fn}'] = count_df.target / count_df.loc[:, bg_fn]
ratio_series = count_df.loc[:, f'ratio_{bg_fn}'].sort_values()
kmer_list.extend(ratio_series.iloc[-kmers_per_bg:].index)
kmer_list.extend(ratio_series.iloc[:kmers_per_bg].index)
# --- write away ---
with open(args.out_kmer_txt, 'w') as fh: fh.write('\n'.join(kmer_list))
if args.out_freq_table:
sub_df = count_df.loc[kmer_list, :].copy()
for cn in ['target'] + bg_fn_list:
sub_df.loc[:, f'rel_{cn}'] = sub_df.loc[:, cn] / sub_df.loc[:, cn].sum()
sub_df.to_csv(args.out_freq_table)
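A toy pandas example of the selection rule above: for each background, k-mers are ranked by target/background count ratio and the most over- and under-represented ones are kept. Counts, k-mers, and the background name are invented.

import pandas as pd

count_df = pd.DataFrame({'target':   [50, 5, 20, 1],
                         'bg_ecoli': [5, 50, 20, 2]},
                        index=['AAAAAAAA', 'CCCCCCCC', 'GGGGGGGG', 'TTTTTTTT'])
kmers_per_bg = 1
ratio_series = (count_df.target / count_df.bg_ecoli).sort_values()
kmer_list = list(ratio_series.index[-kmers_per_bg:]) + list(ratio_series.index[:kmers_per_bg])
print(kmer_list)  # ['AAAAAAAA', 'CCCCCCCC']: most enriched and most depleted in the target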
import sys, signal, os
import sys, signal, os, h5py
from datetime import datetime
from helper_functions import parse_output_path
from inference.ReadTable import ReadTable
from inference.InferenceModel import InferenceModel
import tensorflow as tf
import numpy as np
def main(args):
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)
print('Loading model...')
mod = tf.keras.models.load_model(args.model)
with h5py.File(args.model, 'r') as fh:
model_type = fh.attrs['model_type']
kmer_list = fh.attrs['kmer_list'].split(',')
print(f'Done! Model type is {model_type}')
pos_reads_dir = parse_output_path(args.out_dir + 'pos_reads')
# for abundance estimation mode
abundance_mode = False
if model_type == 'abundance':
abundance_array = np.zeros(len(kmer_list))
abundance_mode = True
# Load read table, start table manager
read_table = ReadTable(args.fast5_in, pos_reads_dir)
read_manager_process = read_table.init_table()
@@ -37,9 +49,19 @@ def main(args):
read_id, read = read_table.get_read_to_predict()
if read is None: continue
pred = mod(read).numpy()
read_table.update_prediction(read_id, pred)
if abundance_mode:
abundance_array += pred
read_table.update_prediction(read_id, np.zeros(len(read_id), dtype=bool))
else:
read_table.update_prediction(read_id, pred)
else:
run_time = datetime.now() - start_time
read_manager_process.terminate()
read_manager_process.join()
print(f'runtime was {run_time.seconds} s')
if abundance_mode:
abundance_array = abundance_array / max(abundance_array.sum(), 1)
abundance_txt = 'kmer,frequency\n' + '\n'.join([f'{km},{ab}' for km, ab in zip(kmer_list, abundance_array)])
with open(f'{args.out_dir}abundance_estimation.csv', 'w') as fh:
fh.write(abundance_txt)
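A small sketch of the final normalisation and CSV layout written in abundance mode, with dummy counts standing in for the accumulated predictions:

import numpy as np

kmer_list = ['AAAAACGT', 'ACGTACGT']
abundance_array = np.array([30.0, 10.0])                      # summed hit counts over all reads
abundance_array = abundance_array / max(abundance_array.sum(), 1)
abundance_txt = 'kmer,frequency\n' + '\n'.join(f'{km},{ab}' for km, ab in zip(kmer_list, abundance_array))
print(abundance_txt)   # kmer,frequency / AAAAACGT,0.75 / ACGTACGT,0.25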
@@ -43,4 +43,4 @@ def main(args):
sf_fn = f'{args.out_dir}nn_production_pipeline.sf'
with open(sf_fn, 'w') as fh: fh.write(sm_text)
snakemake(sf_fn, cores=args.cores, verbose=False, keepgoing=True)
snakemake(sf_fn, cores=args.cores, verbose=False, keepgoing=True, resources={'gpu': 1})
@@ -29,6 +29,8 @@ rule generate_nns:
nn='{{ nn_dir }}{target}/nn.h5'
threads:
1
resources:
gpu=1
shell:
"""
python {__location__} train_nn \
......
import argparse, sys, os, shutil
from snakemake import snakemake
from jinja2 import Template
import multiprocessing as mp
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
baseless_location = os.path.realpath(f'{__location__}/..')
sys.path.append(baseless_location)
from helper_functions import parse_output_path
def copy_dir(tup):
shutil.copytree(tup[0], tup[1])
parser = argparse.ArgumentParser(description='Benchmark baseLess abundance estimation functionality')
# --- input ---
parser.add_argument('--index-csv', type=str, required=True,
help='Index file containing paths to genome and accompanying test reads directory.')
parser.add_argument('--training-read-dir', type=str, required=True,
help='directory containing resquiggled training reads')
#--- output ---
parser.add_argument('--out-dir', type=str, required=True)
#--- params ---
parser.add_argument('--parameter-file', type=str, default=f'{baseless_location}/nns/hyperparams/CnnParameterFile.yaml')
parser.add_argument('--model-size', type=int, default=25,
help='Number of k-mers for which abundance is estimated [default: 25]')
parser.add_argument('--cores', type=int, default=4)
parser.add_argument('--dryrun', action='store_true')
args = parser.parse_args()
out_dir = parse_output_path(args.out_dir, clean=True)
test_read_dir = parse_output_path(out_dir + 'test_reads')
genomes_dir = parse_output_path(out_dir + 'genomes')
species_dict = {}
mp_list = []
with open(args.index_csv, 'r') as fh:
for line in fh.readlines():
species, genome_fn, test_dir = line.strip().split(',')
species_dict[species] = genome_fn
mp_list.append((test_dir, test_read_dir + species))
with mp.Pool(min(len(species_dict), args.cores)) as pool:
pool.map(copy_dir, mp_list)
for species in species_dict:
os.symlink(species_dict[species], f'{genomes_dir}{species}.fasta')
cur_bg_dir = parse_output_path(f'{genomes_dir}bg_{species}')
for bgs in species_dict:
if bgs == species: continue
os.symlink(species_dict[bgs], f'{cur_bg_dir}{bgs}.fasta')
species_list = list(species_dict)
# k-mer NNs for different species must be kept in separate directories to run simultaneously; pre-generate those folders here.
nn_dir = parse_output_path(out_dir + 'kmer_nns')
for species in species_list:
_ = parse_output_path(nn_dir + species)
# --- generate snakemake file ---
with open(f'{__location__}/quick_benchmark_abundance_estimation.sf', 'r') as fh: template_txt = fh.read()
sf_txt = Template(template_txt).render(
parameter_file=args.parameter_file,
training_read_dir=args.training_read_dir,
test_read_dir=test_read_dir,
benchmark_dir=parse_output_path(out_dir + 'benchmarks'),
inference_dir=parse_output_path(out_dir + 'inference'),
kmer_list_dir=parse_output_path(out_dir + 'kmer_lists'),
genomes_dir=genomes_dir,
mod_dir=parse_output_path(out_dir + 'mods'),
logs_dir=parse_output_path(out_dir + 'logs'),
nn_dir=nn_dir,
baseless_location=baseless_location,
species_list=species_list,
model_size=args.model_size
)
sf_fn = f'{out_dir}quick_benchmark_abundance_estimation.sf'
with open(sf_fn, 'w') as fh: fh.write(sf_txt)
snakemake(sf_fn, cores=args.cores, keepgoing=True, dryrun=args.dryrun, resources={'gpu': 1})
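For context, the benchmark Snakefile is produced by plain jinja2 substitution; a toy rendering with a made-up two-line template shows how the {{ placeholders }} in the file below get filled in:

from jinja2 import Template

template_txt = "species_list = {{ species_list }}\nmodel_size = {{ model_size }}"
print(Template(template_txt).render(species_list=['ecoli', 'bsubtilis'], model_size=25))
# species_list = ['ecoli', 'bsubtilis']
# model_size = 25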
# --- input ---
parameter_file = '{{ parameter_file }}'
training_read_dir = '{{ training_read_dir }}'
test_read_dir = '{{ test_read_dir }}'
genomes_dir = '{{ genomes_dir }}'
# --- output ---
benchmark_dir = '{{ benchmark_dir }}'
inference_dir = '{{ inference_dir }}'
kmer_list_dir = '{{ kmer_list_dir }}'
mod_dir = '{{ mod_dir }}'
logs_dir = '{{ logs_dir }}'
nn_dir = '{{ nn_dir }}' # needs to have species-level subfolders created!
# --- params ---
baseless_location = '{{ baseless_location }}'
species_list = {{ species_list }}
model_size = {{ model_size }}
rule target:
input:
abundance_files=expand('{{ inference_dir }}inference_{species}/abundance_estimation.csv',
species=species_list),
freq_tables=expand('{{ kmer_list_dir }}freq_table_{species}.csv',
species=species_list)
# output:
# corr_matrix_svg=''
# run:
#
rule run_inference:
input:
fast5_in='{{ test_read_dir }}{species}/',
model='{{ mod_dir }}compiled_model_{species}.h5'
threads: 3
resources:
gpu=1
benchmark: '{{ benchmark_dir }}benchmark_{species}.tsv'
params:
out_dir='{{ inference_dir }}inference_{species}/'
output:
out_dir='{{ inference_dir }}inference_{species}/abundance_estimation.csv'
shell:
"""
python {baseless_location} run_inference \
--fast5-in {input.fast5_in} \
--model {input.model} \
--out-dir {params.out_dir} \
--inference-mode once > {logs_dir}inference_{wildcards.species}.log
"""
rule compile_model:
input:
kmer_list='{{ kmer_list_dir }}kmer_list_{species}.txt'
# threads: max(workflow.cores // len(species_list), 8)
threads: workflow.cores
resources:
gpu=1
output:
out_mod='{{ mod_dir }}compiled_model_{species}.h5'
shell:
"""
python {baseless_location} compile_model \
--kmer-list {input.kmer_list} \
--train-required \
--training-reads {training_read_dir} \
--out-model {output.out_mod} \
--out-dir {nn_dir}{wildcards.species}/ \
--cores {threads} \
--model-type abundance \
--parameter-file {parameter_file} &> {logs_dir}compile_model_{wildcards.species}.log
"""
rule diff_abundances:
input:
target_fasta='{{ genomes_dir }}{species}.fasta'
threads: workflow.cores
params:
background_fastas='{{ genomes_dir }}bg_{species}/*'
output:
kmer_txt='{{ kmer_list_dir }}kmer_list_{species}.txt',
freq_table='{{ kmer_list_dir }}freq_table_{species}.csv'
shell:
"""
python {baseless_location}/inference/diff_abundance_kmers.py \
--target-fasta {input.target_fasta} \
--background-fastas {params.background_fastas} \
--out-kmer-txt {output.kmer_txt} \
--out-freq-table {output.freq_table} \
--model-size {model_size} \
--cores {threads} > {logs_dir}diff_abundances_{wildcards.species}.log
"""