Skip to content
Snippets Groups Projects
Commit 8bae9293 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Merge remote-tracking branch 'origin/master' into ben_dev

parents c19ccdf6 dbe6a5d7
No related branches found
No related tags found
1 merge request: !3 "Added data preparation, hyperparameter optimisation, benchmarking code and k-mer library visualisation"
......@@ -87,7 +87,7 @@ fast5_in = ('--fast5-in', {
target_16S = ('--target-16S', {
'type': str,
'help': 'fasta containing to-be recognized 16S sequence'
'help': 'fasta containing to-be recognized 16S sequence(s)'
})
inference_mode = ('--inference-mode', {
......@@ -277,7 +277,6 @@ def get_validate_parser():
parameter_file, hdf_path):
parser.add_argument(arg[0], **arg[1])
parser.add_argument('--ground-truth-16s', required=True, type=str, help='csv denoting which species reads belong to')
parser.add_argument('--target-species', required=True, type=str, help='Species to detect, as noted in ground truth csv')
parser.add_argument('--primed-nn-dir', required=True, type=str, help='Directory containing NNs trained on held-out set')
return parser
......
......@@ -14,4 +14,6 @@ dependencies:
- hyperopt=0.2.5
- seaborn=0.11.1
- numpy=1.19.5
- pytables=3.6.1
- kmer-jellyfish=2.3.0
- scikit-learn=0.24.2
# --- Cross-validation driver: builds per-fold read index files and then
# --- compiles and runs the validate_16S snakemake pipeline.
# NOTE(review): this text is a rendered commit diff, so several statements
# below appear twice (pre-merge and post-merge versions of the same line).
import os
import os, sys
import pandas as pd
from os.path import basename
from shutil import copy
from pathlib import Path
from distutils.dir_util import copy_tree
from sklearn.model_selection import StratifiedKFold
from jinja2 import Template
# Resolve the directory holding this script, then its parent, and put the
# parent on sys.path so sibling modules (argparse_dicts, helper_functions)
# can be imported below.
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
baseless_location = os.path.realpath(f'{__location__}/..')
sys.path.append(baseless_location)
from argparse_dicts import get_validate_parser
from helper_functions import parse_input_path, parse_output_path
from snakemake import snakemake
import yaml
from tempfile import TemporaryDirectory
# Parse command-line arguments as defined by get_validate_parser().
parser = get_validate_parser()
args = parser.parse_args()
# One neural-network target per sub-directory of the primed-NN directory.
nn_target_list = [pth.name for pth in Path(args.primed_nn_dir).iterdir() if pth.is_dir()]
# NOTE(review): duplicate of the __location__/baseless_location assignments
# above — diff artifact from the merge, harmless but redundant.
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
baseless_location = os.path.realpath(f'{__location__}/..')
# Load pipeline hyperparameters from the user-supplied yaml file.
with open(args.parameter_file, 'r') as pf: params = yaml.load(pf, Loader=yaml.FullLoader)
# Normalise the primed-NN directory path to end in a slash for concatenation.
primed_nn_dir = args.primed_nn_dir
if primed_nn_dir[-1] != '/': primed_nn_dir += '/'
......@@ -32,34 +38,44 @@ db_dir = args.out_dir + 'dbs/'
# _ = parse_output_path(f'{db_dir}fold_{nf}')
# --- Generate read index files for folds ---
# NOTE(review): rendered diff — the block down to the old k-fold loop mixes
# pre-merge (species_id / is_target) and post-merge (species) versions of the
# same statements; indentation was stripped by the page scrape.
read_index_dir = parse_output_path(f'{args.out_dir}read_index_files')
fast5_list = parse_input_path(args.fast5_in, pattern='*.fast5')
fast5_basename_list = [basename(f) for f in fast5_list]
# Ground-truth csv maps fast5 file names to species; spaces in column names
# are replaced so they can be used with DataFrame.query / attribute access.
gt_df = pd.read_csv(args.ground_truth_16s, header=0)
gt_df.columns = [cn.replace(' ', '_') for cn in gt_df.columns]
gt_df.set_index('file_name', inplace=True)
# NOTE(review): pre-merge line — single-target selection, superseded below.
target_reads_list = list(gt_df.query(f'species_id == "{args.target_species}"').index)
# NOTE(review): duplicates of the three statements above (diff artifact).
read_index_dir = parse_output_path(f'{args.out_dir}read_index_files')
fast5_list = parse_input_path(args.fast5_in, pattern='*.fast5')
fast5_basename_list = [basename(f) for f in fast5_list]
# Build one row per read: file path plus its species label; reads without a
# ground-truth species are labeled 'unknown' and dropped just below.
fast5_df = pd.DataFrame({'read_id': fast5_basename_list,
'fn': fast5_list,
'species_id': [gt_df.loc[ff, 'species_id'] for ff in fast5_basename_list],
'is_target': [True if f in target_reads_list else False for f in fast5_basename_list]}).set_index('read_id')
'species': [gt_df.species_short.get(ff, 'unknown') for ff in fast5_basename_list]}
).set_index('read_id')
fast5_df.drop(fast5_df.query('species == "unknown"').index, axis=0, inplace=True)
# --- make folders with test reads (inference consumes them, so separately for each species) ---
species_list = list(fast5_df.species.unique())
test_read_dir = parse_output_path(f'{args.out_dir}test_reads/')
# NOTE(review): pre-merge version of the k-fold loop (stratified on
# species_id, writing is_target) — superseded by the TemporaryDirectory
# version below.
for fi, (train_num_idx, _) in enumerate(StratifiedKFold(n_splits=args.nb_folds, shuffle=True).split(fast5_df.index,
fast5_df.species_id)):
train_idx = fast5_df.index[train_num_idx]
test_idx = fast5_df.index.difference(train_idx)
cur_test_read_dir = parse_output_path(f'{test_read_dir}fold_{fi}/')
for _, fn in fast5_df.loc[test_idx, 'fn'].iteritems(): copy(fn, cur_test_read_dir)
fast5_df.loc[:, 'fold'] = False
fast5_df.loc[train_idx, 'fold'] = True
fast5_df.loc[:, ['fn', 'is_target', 'fold']].to_csv(f'{read_index_dir}index_fold{fi}.csv')
# Merged version: stage each fold's test reads in a temp dir, then copy the
# whole tree once per species (inference consumes the reads, so every species
# needs its own copy). Each fold also writes a read-index csv recording which
# reads are train ('fold' == True) vs test.
# NOTE(review): Series.iteritems() is removed in pandas >= 2.0 — items() is
# the forward-compatible spelling; confirm the pinned pandas version.
with TemporaryDirectory() as td:
for fi, (train_num_idx, _) in enumerate(StratifiedKFold(n_splits=args.nb_folds, shuffle=True).split(fast5_df.index,
fast5_df.species)):
train_idx = fast5_df.index[train_num_idx]
test_idx = fast5_df.index.difference(train_idx)
cur_test_read_dir = parse_output_path(f'{td}/fold_{fi}/')
for _, fn in fast5_df.loc[test_idx, 'fn'].iteritems(): copy(fn, cur_test_read_dir)
fast5_df.loc[:, 'fold'] = False
fast5_df.loc[train_idx, 'fold'] = True
fast5_df.loc[:, ['fn', 'species', 'fold']].to_csv(f'{read_index_dir}index_fold{fi}.csv')
for sp in species_list:
copy_tree(td, f'{test_read_dir}{sp}')
# --- compile snakefile and run ---
with open(f'{__location__}/validate_16S.sf', 'r') as fh: template_txt = fh.read()
sm_text = Template(template_txt).render(
baseless_location=baseless_location,
species_list=species_list,
hdf_path=args.hdf_path,
filter_width=params['filter_width'],
parameter_file=args.parameter_file,
......@@ -74,6 +90,7 @@ sm_text = Template(template_txt).render(
inference_out_dir=parse_output_path(args.out_dir + 'inference_out/'),
inference_summary_dir=parse_output_path(args.out_dir + 'inference_summaries/'),
benchmark_dir=parse_output_path(args.out_dir + 'inference_benchmark/'),
target_16S_dir=parse_output_path(args.out_dir + 'target_fastas/'),
reads_dir=args.fast5_in,
test_read_dir=test_read_dir,
read_index_dir=read_index_dir
......@@ -81,4 +98,4 @@ sm_text = Template(template_txt).render(
# --- write the compiled snakefile and launch the pipeline ---
sf_fn = f'{args.out_dir}validate_16S_pipeline.sf'
with open(sf_fn, 'w') as fh:
    fh.write(sm_text)
# Fix: the rendered diff contained both the pre-merge and post-merge
# snakemake() calls, which would launch the pipeline twice. Keep only the
# merged invocation; dryrun=False (the default) is spelled out so it is easy
# to flip on when debugging the compiled snakefile.
snakemake(sf_fn, cores=args.cores, verbose=False, keepgoing=True, dryrun=False)
import os
import re
import pandas as pd
__location__ = "{{ __location__ }}"
......@@ -23,14 +24,16 @@ logs_dir= '{{ logs_dir }}'
# Top-level target: request every per-species, per-fold inference summary csv
# so snakemake pulls the whole pipeline through.
rule target:
input:
# NOTE(review): the next two expand() lines are the pre-merge and post-merge
# versions of the same input (rendered-diff artifact); only the
# species-aware second line belongs in the merged file.
expand('{{ inference_summary_dir }}fold_{fold}.csv', fold=range({{ nb_folds }}))
expand('{{ inference_summary_dir }}species_{species}_fold_{fold}.csv', fold=range({{ nb_folds }}), species={{ species_list }})
rule parse_inference_results:
input:
inference_out_dir='{{ inference_out_dir }}fold_{fold}',
inference_out_dir='{{ inference_out_dir }}species_{species}_fold_{fold}',
index_fold_csv='{{ read_index_dir }}index_fold{fold}.csv'
params:
species='{species}'
output:
summary_file='{{ inference_summary_dir }}fold_{fold}.csv'
summary_file='{{ inference_summary_dir }}species_{species}_fold_{fold}.csv'
run:
read_index_df = pd.read_csv(input.index_fold_csv, index_col=0)
out_df = pd.DataFrame({'is_target': False}, index=read_index_df.query('fold == False').index)
......@@ -40,13 +43,13 @@ rule parse_inference_results:
rule run_inference:
input:
fast5_in='{{ test_read_dir }}fold_{fold}',
model='{{ compiled_mod_dir }}compiled_{fold}.tar'
fast5_in='{{ test_read_dir }}{species}/fold_{fold}',
model='{{ compiled_mod_dir }}compiled_species_{species}_fold_{fold}.tar'
# threads: workflow.cores
threads: 2
benchmark: '{{ benchmark_dir }}inference_benchmark_fold{fold}.tsv'
benchmark: '{{ benchmark_dir }}inference_benchmark_species_{species}_fold_{fold}.tsv'
output:
out_dir=directory('{{ inference_out_dir }}fold_{fold}')
out_dir=directory('{{ inference_out_dir }}species_{species}_fold_{fold}')
shell:
"""
python {baseless_location} run_inference \
......@@ -62,20 +65,34 @@ rule run_inference:
# Compile the per-fold neural networks plus the species' target 16S fasta
# into a single model archive consumed by run_inference.
# NOTE(review): rendered diff — the target_16S, out_mod and log-path lines
# each appear twice (pre-merge then post-merge); the species-aware second
# line of each pair is the merged version.
rule compile_model:
input:
target_16S='{{ target_16S_fasta }}',
target_16S='{{ target_16S_dir }}{species}.fasta',
nn_directories=expand("{{nn_dir}}fold_{{ '{{fold}}' }}/{nn_target}/nn.h5", nn_target={{ nn_target_list }}),
params:
nn_directory='{{ nn_dir }}/fold_{fold}'
output:
out_mod='{{ compiled_mod_dir }}compiled_{fold}.tar'
out_mod='{{ compiled_mod_dir }}compiled_species_{species}_fold_{fold}.tar'
shell:
"""
python {baseless_location} compile_model \
--nn-directory {params.nn_directory} \
--target-16S {input.target_16S} \
--out-model {output.out_mod} &> {logs_dir}compile_fold{wildcards.fold}.log
--out-model {output.out_mod} &> {logs_dir}compile_species_{wildcards.species}_fold{wildcards.fold}.log
"""
# Pull all fasta records for one species out of the combined target-16S
# fasta and write them to a per-species fasta, which compile_model consumes.
rule extract_target_sequences:
input:
target_16S_fasta='{{ target_16S_fasta }}'
params:
species='{species}'
output:
target_fasta='{{ target_16S_dir }}{species}.fasta'
run:
with open(input.target_16S_fasta, 'r') as fh: fasta_txt = fh.read()
# Grab every record whose header starts with the species name: from the
# '>{species}' header up to (not including) the next '>' header.
fa_list = re.findall(f'>{params.species}[^>]+', fasta_txt)
with open(output.target_fasta, 'w') as fh:
fh.write(''.join(fa_list))
rule generate_nns:
input:
target_db_train='{{ db_dir }}fold_{fold}/train/{target}/db.fs',
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment