Skip to content
Snippets Groups Projects
Commit fec53958 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Changed predict function so it predicts properly when called from run_inference.py.

However, CNN call is hardcoded now in an ugly way.
parent f4a9f6e2
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@ import tensorflow as tf
from pathlib import Path
from os.path import splitext
from tempfile import TemporaryDirectory
from nns.Cnn_test import NeuralNetwork
from nns.keras_metrics_from_logits import precision, recall, binary_accuracy
......@@ -38,18 +39,32 @@ class InferenceModel(object):
:param mod_fn: Path to tarred file of compiled model
:return:
"""
hardcoded_cnn_settings = {
'target': 'CTGCTCCC',
'filter_width': 1000,
'kernel_size': 19,
'max_sequence_length': 5000,
'batch_size': 32,
'threshold': 0.0,
'eps_per_kmer_switch': 25,
'filters': 5,
'learning_rate': 0.002,
'pool_size': 8,
'dropout_keep_prob': 0.9,
'num_layers': 2,
'batch_norm': 0
}
with TemporaryDirectory() as td:
with tarfile.open(mod_fn) as fh:
fh.extractall(td)
out_dict = {}
for mn in Path(td).iterdir():
out_dict[splitext(mn.name)[0]] = tf.keras.models.load_model(mn,
custom_objects={'precision': precision,
'recall': recall,
'binary_accuracy': binary_accuracy})
k_mer = splitext(mn.name)[0]
location = f'/home/noord087/lustre_link/mscthesis/baseless/baseless_2_on_16s/nns/{k_mer}/nn.h5'
out_dict[k_mer] = NeuralNetwork(weights=location, **hardcoded_cnn_settings)
# Dictionary with kmer string as key and keras.Sequential as value
self._model_dict = out_dict
# List of kmers that the InferenceModel contains models for
self.kmers = list(self._model_dict)
# Length of input that should be given to each individual model
self.input_length = self._model_dict[list(self._model_dict)[0]].layers[0].input_shape[1]
self.input_length = self._model_dict[k_mer].filter_width
......@@ -7,8 +7,6 @@ from inference.InferenceModel import InferenceModel
def main(args):
mod = InferenceModel(args.model) # Load model
# mod_dict = load_model(args.model)
# input_length = mod_dict[list(mod_dict)[0]].layers[0].input_shape[1]
pos_reads_dir = parse_output_path(args.out_dir + 'pos_reads')
# Load read table, start table manager
......@@ -36,7 +34,8 @@ def main(args):
# Start inference loop
while end_condition():
read_id, read, kmer = read_table.get_read_to_predict()
if read is None: continue
if read is None:
continue
pred = mod.predict(read, kmer)
# pred = np.any(mod_dict[kmer].predict(read) > 0.0)
read_table.update_prediction(read_id, kmer, pred)
......
......@@ -130,33 +130,20 @@ class NeuralNetwork(object):
def predict(self, x, clean_signal=True, return_probs=False):
"""Given sequences input as x, predict if they contain target k-mer
"""
offset = 5
ho = offset // 2
lb, rb = self.hfw - ho, self.hfw + ho + 1
idx = np.arange(self.filter_width, len(x) + offset, offset)
x_batched = [x[si:ei] for si, ei in zip(idx - self.filter_width, idx)]
x_pad = pad_sequences(x_batched, padding='post', dtype='float32')
x_pad = np.expand_dims(pad_sequences(x, maxlen=self.filter_width,
padding='post', truncating='post',
dtype='float32'), -1)
posteriors = self.model.predict(x_pad)
# Put predicted class = 1 where posterior is larger than threshold
posteriors = self.model.predict(x_pad)
y_hat = posteriors > self.threshold
offset = 5
ho = offset // 2
lb, rb = self.hfw - ho, self.hfw + ho + 1
y_out = np.zeros(len(x), dtype=int)
for i, yh in enumerate(y_hat):
y_out[lb + i * offset:rb + i * offset] = yh
# todo include clean signal
if return_probs:
posteriors_out = np.zeros(len(x), dtype=float)
for i, p in enumerate(posteriors):
posteriors_out[lb + i * offset:rb + i * offset] = p
return y_out, posteriors_out
y_out[lb + i * offset: rb + i * offset] = yh
return y_out
#
#
# true_ids = np.where(posteriors > self.threshold)
# y_out = np.zeros(len(x_pad), dtype=int)
# np.put(y_out, true_ids, 1)
# if return_probs:
# return y_out, np.array(posteriors)
#
# return y_out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment