Commit 4c30d016 authored by Carlos de Lannoy

max test size for quick benchmark, continuous nn support

parent a5f45ffe
@@ -269,6 +269,7 @@ def get_run_inference_parser():
     parser = argparse.ArgumentParser(description='Start up inference routine and watch a fast5 directory for reads.')
     for arg in (fast5_in, out_dir, model, inference_mode):
         parser.add_argument(arg[0], **arg[1])
+    parser.add_argument('--continuous-nn', action='store_true', help='The RNN used can handle continuous reads.')
     return parser
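
For reference, a minimal standalone sketch of how the new flag behaves (simplified parser, not the project's actual `get_run_inference_parser`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--continuous-nn', action='store_true',
                    help='The RNN used can handle continuous reads.')

# store_true defaults to False; dashes become underscores on the namespace
assert parser.parse_args(['--continuous-nn']).continuous_nn is True
assert parser.parse_args([]).continuous_nn is False
```

Downstream, `args.continuous_nn` is what gets passed to `get_read_to_predict(unsplit=...)` in the inference loop further below.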
@@ -2,6 +2,7 @@
 import tarfile
 import numpy as np
 import tensorflow as tf
+import tensorflow.keras.backend as K
 from pathlib import Path
 from os.path import splitext
@@ -32,11 +33,21 @@ class InferenceModel(object):
         with tarfile.open(mod_fn) as fh:
             fh.extractall(td)
         out_dict = {}
+        def loss_fun(y_true, y_pred):  # dummy loss; only present so load_model can deserialize the compiled model
+            msk = np.zeros(100, dtype=bool)
+            msk[50] = True
+            y_pred_single = tf.boolean_mask(y_pred, msk, axis=1)
+            return K.binary_crossentropy(K.cast(y_true, K.floatx()), y_pred_single, from_logits=True)
         for mn in Path(td).iterdir():
             out_dict[splitext(mn.name)[0]] = tf.keras.models.load_model(mn,
                                                                         custom_objects={'precision': precision,
                                                                                         'recall': recall,
-                                                                                        'binary_accuracy': binary_accuracy})
+                                                                                        'binary_accuracy': binary_accuracy,
+                                                                                        'loss_fun': loss_fun})
         self._model_dict = out_dict
         self.kmers = list(self._model_dict)
-        self.input_length = self._model_dict[list(self._model_dict)[0]].layers[0].input_shape[1]
+        try:
+            self.input_length = self._model_dict[list(self._model_dict)[0]].layers[0].input_shape[1]
+        except Exception:
+            self.input_length = None  # TODO: find a better solution for NNs with undefined time-series length
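
`load_model` needs every custom object the model was compiled with, which is why the dummy `loss_fun` is registered above. If the models are only ever used for prediction, a possible alternative (a sketch, assuming `precision`/`recall`/`binary_accuracy` are compile-time metrics rather than custom layers) is to skip compilation on load:

```python
import tensorflow as tf

# compile=False restores only the architecture and weights, so no dummy
# loss or custom metrics need to be registered; predict() still works.
model = tf.keras.models.load_model(mn, compile=False)  # mn: model path, as in the loop above
```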
@@ -37,21 +37,23 @@ class ReadTable(object):
         manager_process.start()
         return manager_process
 
-    def get_read_to_predict(self):
+    def get_read_to_predict(self, unsplit):
         tup_list = safe_cursor(self.conn, "SELECT * FROM read_table ORDER BY ROWID ASC LIMIT 10")
         if len(tup_list) == 0: return None, None, None
         for tup in tup_list:
             kmi_list = [ki for ki, kb in enumerate(tup[2:]) if kb == 0]
             if not len(kmi_list): continue
             kmer_index = kmi_list[0]
-            return tup[0], self.load_read(tup[0]), self.kmers[kmer_index]
+            return tup[0], self.load_read(tup[0], unsplit), self.kmers[kmer_index]
         return None, None, None
 
-    def load_read(self, fn):
+    def load_read(self, fn, unsplit=False):
         if fn in self.read_dict: return self.read_dict[fn]
         with h5py.File(fn, 'r') as fh:
-            self.read_dict[fn] = Read(fh, 'median').get_split_raw_read(self.input_length)
-            # self.read_dict[fn] = np.expand_dims(pad_sequences(read, padding='post', dtype='float32'), -1)
+            if unsplit:
+                self.read_dict[fn] = Read(fh, 'median').raw
+            else:
+                self.read_dict[fn] = Read(fh, 'median').get_split_raw_read(self.input_length)
         return self.read_dict[fn]
 
     def update_prediction(self, fn, kmer, pred):
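The k-mer selection in `get_read_to_predict` above scans each row's k-mer columns for the first one still flagged 0 (not yet predicted) and serves that read/k-mer pair. A toy illustration with made-up values, assuming the row layout `(fn, <other field>, kmer_0, kmer_1, ...)` implied by `tup[2:]`:

```python
tup = ('read_001.fast5', 42, 1, 0, 0)  # hypothetical row: filename, one other column, then k-mer flags
kmi_list = [ki for ki, kb in enumerate(tup[2:]) if kb == 0]
kmer_index = kmi_list[0]  # -> 1: the second k-mer is the next one to predict for this read
```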
@@ -36,7 +36,7 @@ def main(args):
     # Start inference loop
     while end_condition():
-        read_id, read, kmer = read_table.get_read_to_predict()
+        read_id, read, kmer = read_table.get_read_to_predict(unsplit=args.continuous_nn)
         if read is None: continue
         pred = mod.predict(read, kmer)
         # pred = np.any(mod_dict[kmer].predict(read) > 0.0)
@@ -22,6 +22,7 @@ def parse_fast5_list(fast5_list, gt_df):
                              'fn': fast5_list,
                              'tp': [gt_df.tp.get(ff, 'unknown') for ff in fast5_basename_list]}
                             ).set_index('read_id')
+    fast5_df.drop(fast5_df.query('tp == "unknown"').index, axis=0, inplace=True)
     fast5_df.tp = fast5_df.tp.astype(bool)
     return fast5_df
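
The added line drops reads whose ground-truth label is missing before the cast to bool. An equivalent, slightly more direct formulation (shown only for comparison, same effect):

```python
# keep only reads with a known ground-truth label, then cast to bool
fast5_df = fast5_df[fast5_df.tp != 'unknown']
fast5_df.tp = fast5_df.tp.astype(bool)
```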
@@ -36,6 +37,8 @@ parser.add_argument('--ground-truth', type=str, required=True,
                     help='Ground truth text file.')
 parser.add_argument('--parameter-file', type=str, required=True,
                     help='YAML file defining network architecture.')
+parser.add_argument('--max-test-size', type=int, default=100,
+                    help='Test read set size [default: 100].')
 parser.add_argument('--out-dir', type=str, required=True,
                     help='Output directory.')
 parser.add_argument('--hdf-path', type=str, default='Analyses/RawGenomeCorrected_000')
@@ -68,11 +71,15 @@ train_num_idx, _ = list(StratifiedShuffleSplit(n_splits=1, test_size=0.1).split(
 train_idx = fast5_df.index[train_num_idx]
 fast5_df.loc[:, 'fold'] = False
 fast5_df.loc[train_idx, 'fold'] = True
+fast5_test_df = fast5_df.query('fold == False')
+if len(fast5_test_df) > args.max_test_size:
+    fast5_df = pd.concat((fast5_df.query('fold').copy(), fast5_test_df.sample(args.max_test_size).copy()))
 fast5_df.to_csv(read_index_fn, columns=['fn', 'tp', 'fold'])
 
 # --- make test reads folder ---
 test_read_dir = parse_output_path(out_dir + 'test_reads')
-for fn, tup in fast5_df.query('fold == False').iterrows():
+for rc, (fn, tup) in enumerate(fast5_df.query('fold == False').iterrows()):
     copy(fast5_dir + fn, test_read_dir)
 
 # --- render snakemake script ---
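
The new block keeps the full training fold but samples at most `--max-test-size` reads for the held-out fold, which is what makes the quick benchmark quick. A self-contained sketch of the split-and-cap logic with toy data (column names from the hunk above, values made up; `test_size=0.2` chosen so the cap actually triggers):

```python
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

fast5_df = pd.DataFrame({'fn': [f'r{i}.fast5' for i in range(1000)],
                         'tp': [i % 2 == 0 for i in range(1000)]})
train_num_idx, _ = next(StratifiedShuffleSplit(n_splits=1, test_size=0.2)
                        .split(fast5_df.index, fast5_df.tp))
fast5_df['fold'] = False
fast5_df.loc[fast5_df.index[train_num_idx], 'fold'] = True  # True == training read

max_test_size = 100  # stand-in for args.max_test_size
fast5_test_df = fast5_df.query('fold == False')  # 200 held-out reads
if len(fast5_test_df) > max_test_size:
    fast5_df = pd.concat((fast5_df.query('fold').copy(),
                          fast5_test_df.sample(max_test_size).copy()))
assert (~fast5_df.fold).sum() == max_test_size
```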
@@ -91,6 +98,7 @@ sm_text = Template(template_txt).render(
     logs_dir=logs_dir,
     parameter_file=args.parameter_file,
     filter_width=params['filter_width'],
+    continuous_nn=params.get('continuous_nn', False),
     hdf_path=args.hdf_path
 )
@@ -51,7 +51,7 @@ rule run_inference:
         python {baseless_location} run_inference \
             --fast5-in {input.fast5_in} \
             --model {input.model} \
-            --out-dir {output.out_dir} \
+            --out-dir {output.out_dir}{% if continuous_nn %} --continuous-nn{% endif %} \
             --inference-mode once > {logs_dir}inference_species.log
         """