Commit 3b525fc2 authored by Noordijk, Ben

Run production pipeline can now be used for 5-fold CV with an index file

parent 25792824
@@ -41,6 +41,14 @@ read_index = ('--read-index', {
'help': 'Supply index file denoting which reads should be training reads and which should be test.'
})
train_reads_index = ('--train-reads-index', {
'type': str,
'required': False,
'default': None,
'help': 'Txt file that specifies file names to be used for training. '
'Used for k-fold cross validation'
})
kmer_list = ('--kmer-list', {
'type': str,
# 'required': True,
@@ -239,7 +247,7 @@ def get_run_production_pipeline_parser():
parser = argparse.ArgumentParser(description='Generate DBs from read sets and generate RNNs for several k-mers '
'at once')
for arg in (training_reads, test_reads, out_dir, kmer_list, cores,
parameter_file, hdf_path, uncenter_kmer):
parameter_file, hdf_path, uncenter_kmer, train_reads_index):
parser.add_argument(arg[0], **arg[1])
return parser
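To feed the new `--train-reads-index` option, one index file per fold is needed. Below is a minimal sketch of how such files might be generated with scikit-learn's KFold; the `make_fold_indices` helper, the 'fn' header line and the output file names are illustrative assumptions, not part of this commit.

```python
# Hypothetical helper (not in this repo): write one training index file per fold.
# Assumption: the index file is a single column of fast5 file names with an 'fn'
# header, so that pandas.read_csv can pick it up with default settings.
from pathlib import Path
from sklearn.model_selection import KFold

def make_fold_indices(fast5_dir, out_dir, n_splits=5):
    files = sorted(p.name for p in Path(fast5_dir).glob('*.fast5'))
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    for fold, (train_idx, _test_idx) in enumerate(kf.split(files)):
        lines = ['fn'] + [files[i] for i in train_idx]
        (out_dir / f'fold_{fold}_train_index.txt').write_text('\n'.join(lines) + '\n')
```

Each resulting file could then be supplied to a separate run of the production pipeline via `--train-reads-index`, yielding one model per fold.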
@@ -265,6 +273,7 @@ def get_training_parser():
parser.add_argument(arg[0], **arg[1])
return parser
def get_run_inference_parser():
parser = argparse.ArgumentParser(description='Start up inference routine and watch a fast5 directory for reads.')
for arg in (fast5_in, out_dir, model, inference_mode):
......
@@ -6,6 +6,7 @@ from os.path import isdir, dirname, basename, splitext
from shutil import rmtree
from pathlib import Path
from random import shuffle
import os
__location__ = dirname(Path(__file__).resolve())
sys.path.extend([__location__, f'{__location__}/..'])
@@ -18,12 +19,10 @@ def main(args):
out_path = parse_output_path(args.db_dir)
if isdir(out_path):
rmtree(out_path)
if args.read_index:
read_index_df = pd.read_csv(args.read_index, index_col=0)
if args.db_type == 'train':
file_list = list(read_index_df.query(f'fold').fn)
else: # test
file_list = list(read_index_df.query(f'fold == False').fn)
read_index_df = pd.read_csv(args.read_index)
file_list = list(read_index_df.squeeze())
else:
file_list = parse_input_path(args.fast5_in, pattern='*.fast5')
if args.randomize: shuffle(file_list)
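The simplified branch above just reads the index file into a single-column DataFrame and squeezes it into a Series of file names. A self-contained sketch of that behaviour, assuming the index file carries a header row (the data shown is illustrative only):

```python
import io
import pandas as pd

# Stand-in for the on-disk index file: a header line plus one file name per line.
index_txt = "fn\nread_0001.fast5\nread_0002.fast5\nread_0003.fast5\n"

read_index_df = pd.read_csv(io.StringIO(index_txt))
file_list = list(read_index_df.squeeze())  # one-column DataFrame -> Series -> list
print(file_list)  # ['read_0001.fast5', 'read_0002.fast5', 'read_0003.fast5']
```

Because the index now holds bare file names rather than full paths, the read loop further down joins them onto `args.fast5_in` with `os.path.join` before opening each file.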
@@ -35,18 +34,19 @@ def main(args):
db = ExampleDb(db_name=db_name, target=args.target, width=args.width)
nb_files = len(file_list)
count_pct_lim = 5
nb_example_reads = 0
count_pct_lim = 1
for i, file in enumerate(file_list):
file = os.path.join(args.fast5_in, file)
try:
with h5py.File(file, 'r') as f:
tr = TrainingRead(f, normalization=args.normalization,
hdf_path=args.hdf_path,
kmer_size=kmer_size)
nb_pos = db.add_training_read(training_read=tr,
uncenter_kmer=args.uncenter_kmer)
if nb_example_reads < args.nb_example_reads and nb_pos > 0:
np.savez(npz_path + splitext(basename(file))[0], base_labels=tr.events, raw=tr.raw)
db.add_training_read(training_read=tr,
uncenter_kmer=args.uncenter_kmer)
if args.store_example_reads:
np.savez(npz_path + splitext(basename(file))[0],
base_labels=tr.events, raw=tr.raw)
if not i+1 % 10: # Every 10 reads remove history of transactions ('pack' the database) to reduce size
db.pack_db()
if db.nb_pos > args.max_nb_examples:
@@ -56,11 +56,10 @@ def main(args):
if not args.silent and percentage_processed >= count_pct_lim:
print(f'{percentage_processed}% of reads processed, {db.nb_pos} positives in DB')
count_pct_lim += 5
except (KeyError, ValueError) as e:
except Exception as e:
with open(error_fn, 'a') as efn:
efn.write('{fn}\t{err}\n'.format(err=e, fn=basename(file)))
continue
db.pack_db()
if db.nb_pos == 0:
raise ValueError(f'No positive examples found for kmer {args.target}')
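The widened `except Exception` means any failure on a single read is appended to the error file and the loop moves on, instead of only catching KeyError and ValueError. A stripped-down sketch of the pattern (the `process_read` callable and the function name are placeholders, not code from this repo):

```python
from os.path import basename

def build_db_tolerantly(file_list, error_fn, process_read):
    """Process reads one by one; log any failure to a TSV error file and continue."""
    for file in file_list:
        try:
            process_read(file)  # stands in for TrainingRead parsing + db.add_training_read
        except Exception as e:
            with open(error_fn, 'a') as efn:
                efn.write('{fn}\t{err}\n'.format(fn=basename(file), err=e))
            continue
```

Catching the bare Exception keeps one malformed fast5 from aborting a long database build, at the cost of programming errors ending up in the error log rather than being raised.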
@@ -18,22 +18,25 @@ def main(args):
with open(args.kmer_list, 'r') as fh: kmer_list = [k.strip() for k in fh.readlines() if len(k.strip())]
with open(args.parameter_file, 'r') as pf: params = yaml.load(pf, Loader=yaml.FullLoader)
# Construct and run snakemake pipeline
with open(f'{__location__}/run_production_pipeline.sf', 'r') as fh: template_txt = fh.read()
sm_text = Template(template_txt).render(
__location__=__location__,
db_dir=db_dir,
nn_dir=nn_dir,
logs_dir=logs_dir,
parameter_file=args.parameter_file,
train_reads=args.training_reads,
test_reads=args.test_reads,
kmer_list=kmer_list,
filter_width=params['filter_width'],
hdf_path=args.hdf_path,
uncenter_kmer=args.uncenter_kmer
)
snakemake_dict = {'__location__': __location__,
'db_dir': db_dir,
'nn_dir': nn_dir,
'logs_dir': logs_dir,
'parameter_file': args.parameter_file,
'train_reads': args.training_reads,
'test_reads': args.test_reads,
'kmer_list': kmer_list,
'filter_width': params['filter_width'],
'hdf_path': args.hdf_path,
'uncenter_kmer': args.uncenter_kmer}
if args.train_reads_index:
snakemake_dict['read_index'] = args.train_reads_index
with open(f'{__location__}/run_production_pipeline.sf', 'r') as fh:
template_txt = fh.read()
sm_text = Template(template_txt).render(snakemake_dict)
sf_fn = f'{args.out_dir}nn_production_pipeline.sf'
with open(sf_fn, 'w') as fh: fh.write(sm_text)
with open(sf_fn, 'w') as fh:
fh.write(sm_text)
snakemake(sf_fn, cores=args.cores, verbose=False, keepgoing=True)
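Collecting the template variables in a dict first makes it straightforward to add the optional `read_index` key only when `--train-reads-index` was given; `jinja2.Template.render()` accepts a dict in the same way as the `dict()` constructor. A small sketch with a simplified stand-in template (not the real run_production_pipeline.sf):

```python
from jinja2 import Template

# Simplified stand-in for the Snakefile template.
template_txt = "build_db {% if read_index %}--read-index {{ read_index }} {% endif %}--out {{ db_dir }}"

render_vars = {'db_dir': 'dbs/'}
train_reads_index = 'fold_0_train_index.txt'  # None when the flag is omitted
if train_reads_index:
    render_vars['read_index'] = train_reads_index

print(Template(template_txt).render(render_vars))
# build_db --read-index fold_0_train_index.txt --out dbs/
```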
# jinja args: location, db_dir, nn_dir, width, hdf_path, uncenter_kmer
# jinja args: location, db_dir, nn_dir, width, hdf_path, uncenter_kmer, read_index
__location__ = "{{ __location__ }}"
db_dir = '{{ db_dir }}'
@@ -53,7 +53,7 @@ rule generate_training_db:
--width {filter_width} \
--hdf-path {hdf_path} \
--uncenter-kmer \
--randomize &> {logs_dir}db_train_{wildcards.target}.log
{% if read_index %} --read-index {{ read_index }} {% endif %} --randomize &> {logs_dir}db_train_{wildcards.target}.log
"""
rule generate_test_db:
@@ -73,5 +73,5 @@ rule generate_test_db:
--width {filter_width} \
--hdf-path {hdf_path} \
--uncenter-kmer \
--randomize &>{logs_dir}db_test_{wildcards.target}.log
--randomize &> {logs_dir}db_test_{wildcards.target}.log
"""