Commit 806d8816 authored by Carlos de Lannoy

add inference model class, more diverse neg examples, misc

parent d5aa0682
@@ -38,14 +38,15 @@ class ExampleDb(object):
         # --- add positive examples (if any) ---
         pos_examples = training_read.get_pos(self.target, self.width, uncenter_kmer)
         for i, ex in enumerate(pos_examples): conn.root.pos[len(conn.root.pos)] = ex
+        nb_new_positives = len(pos_examples)
         # --- update record nb positive examples ---
         if self._db_empty:
-            if len(pos_examples):
+            if nb_new_positives > 0:
                 self._db_empty = False
         if not self._db_empty:
             self.nb_pos = conn.root.pos.maxKey()
         # --- add negative examples ---
-        neg_examples, neg_kmers = training_read.get_neg(self.target, self.width, len(pos_examples) * 5)  # arbitrarily adding 5x as many neg examples
+        neg_examples, neg_kmers = training_read.get_neg(self.target, self.width, len(pos_examples) * 10)  # arbitrarily adding 10x as many neg examples
         for i, ex in enumerate(neg_examples):
             if neg_kmers[i] in self.neg_kmers:
                 self.neg_kmers[neg_kmers[i]].append(self.nb_neg + i)
@@ -54,6 +55,7 @@ class ExampleDb(object):
                 conn.root.neg[self.nb_neg+i] = ex
         # self.nb_neg += len(neg_examples)
         conn.root.neg_kmers = self.neg_kmers
+        return nb_new_positives

     def get_training_set(self, size=None, includes=[]):
         """
......
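With the method now returning the number of freshly stored positives, a caller can keep a running count while filling the database. A minimal hypothetical sketch (the method name add_training_read, the ExampleDb instance db, the reads iterable and the target count are illustrative assumptions, not part of this commit):

    # hypothetical database-filling loop
    nb_pos_total = 0
    for training_read in reads:
        nb_pos_total += db.add_training_read(training_read, uncenter_kmer=True)
        if nb_pos_total >= target_nb_examples:
            break  # enough positive examples collected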
@@ -226,10 +226,14 @@ class TrainingRead(Persistent):
            # Make sure the negative examples are far enough away from
            # the target k-mer:
            if np.all(distances_to_kmer > width):
-                for ii in range(len(cur_condensed_event[2])):
-                    mid_idx = sum(self.event_length_list[:cur_idx]) + ii + 1
-                    raw_hits_out.append(self.raw[mid_idx - width_l:
-                                                 mid_idx + width_r])
-                    raw_kmers_out.append(cur_condensed_event[0])
+                mid_idx = sum(self.event_length_list[:cur_idx]) + random.randint(0, len(cur_condensed_event[2]))
+                raw_hits_out.append(self.raw[mid_idx - width_l:mid_idx + width_r])
+                raw_kmers_out.append(cur_condensed_event[0])
+                # for ii in range(len(cur_condensed_event[2])):
+                #     mid_idx = sum(self.event_length_list[:cur_idx]) + ii + 1
+                #     raw_hits_out.append(self.raw[mid_idx - width_l:
+                #                                  mid_idx + width_r])
+                #     raw_kmers_out.append(cur_condensed_event[0])
            idx_list.remove(cur_idx)
        return raw_hits_out, raw_kmers_out
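The change above trades exhaustiveness for diversity: instead of emitting one window per raw point of a suitable event, a single window is cut at a random position inside that event. A small self-contained sketch of the idea, with made-up stand-in values for the attributes used above (raw, event_length_list, width_l/width_r and the chosen event index are all illustrative):

    import random
    import numpy as np

    raw = np.random.randn(5000)                # stand-in for self.raw
    event_length_list = [200, 150, 300, 120]   # stand-in for self.event_length_list
    width_l, width_r = 50, 50
    cur_idx = 2                                # event selected as a negative example
    nb_points_in_event = event_length_list[cur_idx]

    # one window, anchored at a random raw position inside the selected event
    mid_idx = sum(event_length_list[:cur_idx]) + random.randint(0, nb_points_in_event)
    window = raw[mid_idx - width_l:mid_idx + width_r]  # single negative example of width_l + width_r points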
import tarfile
import numpy as np
import tensorflow as tf
from pathlib import Path
from os.path import splitext
from tempfile import TemporaryDirectory

from nns.keras_metrics_from_logits import precision, recall, binary_accuracy


class InferenceModel(object):
    """One Keras model per k-mer, loaded from a single tar archive of saved models."""

    def __init__(self, mod_fn, batch_size=32):
        self.load_model(mod_fn)
        self.batch_size = batch_size

    def predict(self, read, kmer):
        # Run the model for the given k-mer; any logit > 0 counts as a detection.
        return np.any(self._model_dict[kmer].predict(read) > 0.0)  # todo make threshold feature?
        # # For fixed batch sizes
        # read, last_idx = read
        # for batch in read[:-1]:
        #     if np.any(self._model_dict[kmer].predict(batch) > 0.0):
        #         return True
        # return np.any(self._model_dict[kmer].predict(read[-1])[:last_idx] > 0.0)

    def load_model(self, mod_fn):
        # Unpack the archive into a temporary directory and load one model per contained
        # file, keyed by file name (minus extension), i.e. by k-mer.
        with TemporaryDirectory() as td:
            with tarfile.open(mod_fn) as fh:
                fh.extractall(td)
            out_dict = {}
            for mn in Path(td).iterdir():
                out_dict[splitext(mn.name)[0]] = tf.keras.models.load_model(
                    mn, custom_objects={'precision': precision,
                                        'recall': recall,
                                        'binary_accuracy': binary_accuracy})
        self._model_dict = out_dict
        self.kmers = list(self._model_dict)
        self.input_length = self._model_dict[self.kmers[0]].layers[0].input_shape[1]
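A hypothetical usage sketch for the class above (the archive file name and the batch shape are assumptions for illustration; kmers and input_length are the attributes set by load_model):

    import numpy as np

    # wrap a tar archive of saved per-k-mer Keras models (file name is a placeholder)
    mod = InferenceModel('compiled_models.tar', batch_size=32)

    # a batch of raw-signal windows shaped (n_windows, input_length, 1); random data, shown only for the shape
    dummy_batch = np.random.randn(8, mod.input_length, 1).astype('float32')
    hit = mod.predict(dummy_batch, mod.kmers[0])  # True if any window yields a positive logit for that k-mer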
@@ -5,6 +5,7 @@ import sqlite3
 from pathlib import Path
 from time import sleep
 from os.path import basename
+from inference.GracefulKiller import GracefulKiller
 from contextlib import closing
 from helper_functions import safe_cursor
@@ -41,7 +42,7 @@ class ReadManager(object):
         new_preds_df = pd.DataFrame(new_predictions, columns=self.table_columns).set_index('read_id')
         for fn, tup in new_preds_df.iterrows():
             if tup.loc[self.kmers].sum() == self.pos_check_total:  # If all positive: move to confirmed positives and remove from table
-                move(fn, self.pos_reads_dir)
+                move(fn, self.pos_reads_dir + basename(fn))
                 safe_cursor(self.conn, f"DELETE from read_table WHERE read_id = '{fn}'", read_only=False)
             elif np.any(tup.loc[self.kmers] < 0):  # if any negative: delete read and remove from table
                 Path(fn).unlink(missing_ok=True)
......
@@ -10,13 +10,14 @@ from db_building.TrainingRead import Read
 from inference.ReadManager import ReadManager
 from helper_functions import numeric_timestamp, safe_cursor
 from contextlib import closing


 class ReadTable(object):

-    def __init__(self, reads_dir, table_fn, pos_reads_dir, kmers, preallocate_size=100):
+    def __init__(self, reads_dir, table_fn, pos_reads_dir, kmers, input_length, batch_size=32):
         Path(table_fn).unlink(missing_ok=True)
+        self.input_length = input_length
+        self.batch_size = batch_size
         self.table_fn = table_fn
         self.pos_reads_dir = pos_reads_dir
         self.reads_dir = reads_dir
@@ -36,21 +37,21 @@ class ReadTable(object):
         manager_process.start()
         return manager_process

-    def get_read_to_predict(self, length):
+    def get_read_to_predict(self):
         tup_list = safe_cursor(self.conn, "SELECT * FROM read_table ORDER BY ROWID ASC LIMIT 10")
         if len(tup_list) == 0: return None, None, None
         for tup in tup_list:
             kmi_list = [ki for ki, kb in enumerate(tup[2:]) if kb == 0]
             if not len(kmi_list): continue
             kmer_index = kmi_list[0]
-            return tup[0], self.load_read(tup[0], length), self.kmers[kmer_index]
+            return tup[0], self.load_read(tup[0]), self.kmers[kmer_index]
         return None, None, None

-    def load_read(self, fn, length):
+    def load_read(self, fn):
         if fn in self.read_dict: return self.read_dict[fn]
         with h5py.File(fn, 'r') as fh:
-            read = Read(fh, 'median').get_split_raw_read(length)
-            self.read_dict[fn] = np.expand_dims(pad_sequences(read, padding='post', dtype='float32'), -1)
+            self.read_dict[fn] = Read(fh, 'median').get_split_raw_read(self.input_length)
+            # self.read_dict[fn] = np.expand_dims(pad_sequences(read, padding='post', dtype='float32'), -1)
         return self.read_dict[fn]

     def update_prediction(self, fn, kmer, pred):
......
@@ -4,15 +4,16 @@ import numpy as np
 from helper_functions import parse_output_path, parse_input_path
 from inference.ReadTable import ReadTable
-from helper_functions import load_model
+from inference.InferenceModel import InferenceModel


 def main(args):
-    mod_dict = load_model(args.model)  # Load model
-    input_length = mod_dict[list(mod_dict)[0]].layers[0].input_shape[1]
+    mod = InferenceModel(args.model)  # Load model
+    # mod_dict = load_model(args.model)
+    # input_length = mod_dict[list(mod_dict)[0]].layers[0].input_shape[1]
     pos_reads_dir = parse_output_path(args.out_dir + 'pos_reads')

     # Load read table, start table manager
-    read_table = ReadTable(args.fast5_in, args.out_dir + 'index_table.db', pos_reads_dir, list(mod_dict))
+    read_table = ReadTable(args.fast5_in, args.out_dir + 'index_table.db', pos_reads_dir, mod.kmers, mod.input_length)
     read_manager_process = read_table.init_table()

     # ensure processes are ended after exit
@@ -35,9 +36,10 @@ def main(args):
     # Start inference loop
     while end_condition():
-        read_id, read, kmer = read_table.get_read_to_predict(input_length)
+        read_id, read, kmer = read_table.get_read_to_predict()
         if read is None: continue
-        pred = np.any(mod_dict[kmer].predict(read) > 0.0)
+        pred = mod.predict(read, kmer)
+        # pred = np.any(mod_dict[kmer].predict(read) > 0.0)
         read_table.update_prediction(read_id, kmer, pred)
     else:
         read_manager_process.terminate()
......
@@ -63,7 +63,7 @@ class NeuralNetwork(object):
         ho = offset // 2
         lb, rb = self.hfw - ho, self.hfw + ho + 1
         idx = np.arange(self.filter_width, len(x) + offset, offset)
-        x_batched = [x[si:ei] for si, ei in zip(idx-100, idx)]
+        x_batched = [x[si:ei] for si, ei in zip(idx-self.filter_width, idx)]
         x_pad = pad_sequences(x_batched, padding='post', dtype='float32')
         posteriors = self.model.predict(x_pad)
         y_hat = posteriors > self.threshold
......
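The hunk above replaces a hard-coded window length of 100 with self.filter_width, so the batching stays correct for models with other receptive fields. A small sketch of the windowing logic in isolation (filter_width, offset and the signal length are illustrative values, not taken from the repository):

    import numpy as np

    filter_width, offset = 150, 30   # illustrative; the old code silently assumed filter_width == 100
    x = np.random.randn(1000)        # stand-in raw signal

    # window end positions, stepping by `offset`; each full window spans exactly `filter_width` points
    idx = np.arange(filter_width, len(x) + offset, offset)
    x_batched = [x[si:ei] for si, ei in zip(idx - filter_width, idx)]
    assert all(len(w) == filter_width for w in x_batched[:-1])  # in the class above, pad_sequences pads the shorter last window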
import argparse, re, os, sys
import numpy as np
import pandas as pd
import multiprocessing as mp
from datetime import datetime
from os.path import basename
from shutil import copy
from sklearn.model_selection import StratifiedShuffleSplit
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
baseless_location = os.path.realpath(f'{__location__}/..')
sys.path.append(baseless_location)
from helper_functions import parse_output_path, parse_input_path
def parse_fast5_list(fast5_list, gt_df, out_queue):
    fast5_basename_list = [basename(f) for f in fast5_list]
    fast5_df = pd.DataFrame({'read_id': fast5_basename_list,
                             'fn': fast5_list,
                             'species': [gt_df.species_short.get(ff, 'unknown') for ff in fast5_basename_list]}
                            ).set_index('read_id')
    fast5_df.drop(fast5_df.query('species == "unknown"').index, axis=0, inplace=True)
    out_queue.put(fast5_df)


def write_reads(fast5_sub_df, read_dir):
    for _, fn in fast5_sub_df.fn.iteritems():
        copy(fn, f'{read_dir}/')
parser = argparse.ArgumentParser(description='Take stratified test/train samples of defined size.')
parser.add_argument('--ground-truth', type=str, required=True,
help='csv denoting which species reads belong to')
parser.add_argument('--fast5-dir', type=str, required=True,
help='fast5 directory from which to take reads')
parser.add_argument('--out-dir', type=str, required=True)
parser.add_argument('--sample-size', type=int, default=1000,
help='Number of reads to take (test + train) [default: 1000]')
parser.add_argument('--split', type=float, default=0.8,
help='Train/test split [default: 0.8]')
parser.add_argument('--cores', type=int, default=4,
help='Number of cores to use, minimum 2 [default: 4]')
args = parser.parse_args()
# --- prep dirs ---
out_dir = parse_output_path(args.out_dir, clean=True)
train_dir = parse_output_path(out_dir + 'train')
test_dir = parse_output_path(out_dir + 'test')
# --- load ground truth ---
gt_df = pd.read_csv(args.ground_truth, header=0)
gt_df.columns = [cn.replace(' ', '_') for cn in gt_df.columns]
gt_df.set_index('file_name', inplace=True)
# --- index fast5 files ---
print(f'{datetime.now()}: parsing fast5 file list...')
fast5_list = np.array(parse_input_path(args.fast5_dir, pattern='*.fast5'))
nb_fast5 = len(fast5_list)
ff_idx_list = np.array_split(np.arange(len(fast5_list)), args.cores)
out_queue = mp.Queue()
ff_workers = [mp.Process(target=parse_fast5_list, args=(fast5_list[ff_idx], gt_df, out_queue)) for ff_idx in ff_idx_list]
for worker in ff_workers: worker.start()
df_list = []
while any(p.is_alive() for p in ff_workers):
    while not out_queue.empty():
        df_list.append(out_queue.get())
# collect any results that arrived after the last liveness check
while not out_queue.empty():
    df_list.append(out_queue.get())
fast5_df = pd.concat(df_list)
print(f'{datetime.now()}: done')
# --- generate split ---
# species_list = list(fast5_df.species.unique())
train_size = round(args.sample_size * args.split)
test_size = args.sample_size - train_size
train_num_idx, test_num_idx = tuple(StratifiedShuffleSplit(n_splits=1, train_size=train_size, test_size=test_size).split(fast5_df.index, fast5_df.species))[0]
train_idx, test_idx = fast5_df.index[train_num_idx], fast5_df.index[test_num_idx]
workers = [mp.Process(target=write_reads, args=(fast5_df.loc[idx, :], cdir))
for idx, cdir in ((train_idx, train_dir), (test_idx, test_dir))]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
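A hypothetical invocation of this sampling script (the script name, file names and paths are placeholders; the flags are the ones defined by the argument parser above):

    python take_train_test_sample.py \
        --ground-truth ground_truth.csv \
        --fast5-dir /path/to/fast5s \
        --out-dir sample_run/ \
        --sample-size 1000 --split 0.8 --cores 4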