Commit 9f2fe975 authored by Carlos de Lannoy

fix accidental float-to-int casting, fix predict function

parent 533d0c9c
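
The "accidental float-to-int casting" mentioned in the message most likely stems from Keras' pad_sequences, whose dtype argument defaults to 'int32': applied to raw float-valued signal chunks it silently truncates every sample, which is why the diff below passes dtype='float32'. A minimal sketch of the effect, with made-up signal values:

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Two float-valued signal fragments of unequal length (hypothetical values)
x = [np.array([0.7, 1.2, 0.3]), np.array([0.9, 0.1])]

# Default dtype='int32' truncates the float samples during padding
print(pad_sequences(x, maxlen=4, padding='post'))
# [[0 1 0 0]
#  [0 0 0 0]]

# Passing dtype='float32', as this commit does, preserves the raw values
print(pad_sequences(x, maxlen=4, padding='post', dtype='float32'))
# [[0.7 1.2 0.3 0. ]
#  [0.9 0.1 0.  0. ]]
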
@@ -31,9 +31,9 @@ class NeuralNetwork(object):
    def __init__(self, **kwargs):
        self.target = kwargs['target']
        self.filter_width = kwargs['filter_width']
        self.hfw = (self.filter_width - 1) // 2  # half filter width
        self.kernel_size = kwargs['kernel_size']
        self.max_sequence_length = self.filter_width  # TODO this is an experiment
        # self.max_sequence_length = kwargs['max_sequence_length']
        self.max_sequence_length = kwargs['max_sequence_length']
        self.batch_size = kwargs['batch_size']
        self.threshold = kwargs['threshold']
        self.eps_per_kmer_switch = kwargs['eps_per_kmer_switch']
@@ -66,7 +66,7 @@ class NeuralNetwork(object):
        self.model.add(layers.Conv1D(self.filters,
                                     kernel_size=self.kernel_size,
                                     activation='relu',
                                     input_shape=(self.max_sequence_length, 1)))
                                     input_shape=(self.filter_width, 1)))
        for _ in range(self.num_layers):
            if self.batch_norm:
                self.model.add(layers.BatchNormalization())
@@ -101,11 +101,11 @@ class NeuralNetwork(object):
        :param quiet: If set to true, does not print to console
        """
        # Pad input sequences
        x_pad = np.expand_dims(pad_sequences(x, maxlen=self.max_sequence_length,
                                             padding='post', truncating='post'), -1)
        x_pad = np.expand_dims(pad_sequences(x, maxlen=self.filter_width,
                                             padding='post', truncating='post', dtype='float32'), -1)
        x_val_pad = np.expand_dims(pad_sequences(x_val,
                                                 maxlen=self.max_sequence_length,
                                                 padding='post', truncating='post'), -1)
                                                 maxlen=self.filter_width,
                                                 padding='post', truncating='post', dtype='float32'), -1)
        # Create tensorflow dataset
        tfd = tf.data.Dataset.from_tensor_slices((x_pad, y)).batch(
            self.batch_size).shuffle(x_pad.shape[0],
@@ -121,17 +121,33 @@ class NeuralNetwork(object):
    def predict(self, x, clean_signal=True, return_probs=False):
        """Given sequences input as x, predict if they contain target k-mer
        """
        # Pad the sequences to max sequence length
        x_pad = np.expand_dims(pad_sequences(x, maxlen=self.max_sequence_length,
                                             padding='post', truncating='post'), -1)
        offset = 5
        ho = offset // 2
        lb, rb = self.hfw - ho, self.hfw + ho + 1
        idx = np.arange(self.filter_width, len(x) + offset, offset)
        x_batched = [x[si:ei] for si, ei in zip(idx - self.filter_width, idx)]
        x_pad = pad_sequences(x_batched, padding='post', dtype='float32')
        # Predicted posterior probabilities per read that it contains k-mer
        posteriors = self.model.predict(x_pad)
        # Put predicted class = 1 where posterior is larger than threshold
        true_ids = np.where(posteriors > self.threshold)
        y_out = np.zeros(len(x_pad), dtype=int)
        np.put(y_out, true_ids, 1)
        if return_probs:
            return y_out, np.array(posteriors)
        y_hat = posteriors > self.threshold
        y_out = np.zeros(len(x), dtype=int)
        for i, yh in enumerate(y_hat):
            y_out[lb + i * offset:rb + i * offset] = yh
        # todo include clean signal
        if return_probs:
            posteriors_out = np.zeros(len(x), dtype=float)
            for i, p in enumerate(posteriors):
                posteriors_out[lb + i * offset:rb + i * offset] = p
            return y_out, posteriors_out
        return y_out
        #
        #
        # true_ids = np.where(posteriors > self.threshold)
        # y_out = np.zeros(len(x_pad), dtype=int)
        # np.put(y_out, true_ids, 1)
        # if return_probs:
        #     return y_out, np.array(posteriors)
        #
        # return y_out
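
For orientation, the rewritten predict() above no longer pads each read to one fixed length; it slides a filter_width-sample window over the raw signal in steps of offset samples and writes each window's prediction onto the offset positions around that window's centre (hence self.hfw in __init__). A standalone sketch of the index bookkeeping, with made-up sizes and the model call replaced by random probabilities:

import numpy as np

# Hypothetical sizes; the real values come from the model configuration
filter_width = 21
hfw = (filter_width - 1) // 2            # half filter width, as in __init__
offset = 5                               # stride between consecutive windows
ho = offset // 2
lb, rb = hfw - ho, hfw + ho + 1          # the offset central positions of a window

x = np.random.rand(200)                  # stand-in for one raw signal read
idx = np.arange(filter_width, len(x) + offset, offset)
x_batched = [x[si:ei] for si, ei in zip(idx - filter_width, idx)]

posteriors = np.random.rand(len(x_batched))   # stand-in for self.model.predict
y_hat = posteriors > 0.5                       # stand-in for self.threshold

# Each window's label covers the offset samples around its centre
y_out = np.zeros(len(x), dtype=int)
for i, yh in enumerate(y_hat):
    y_out[lb + i * offset:rb + i * offset] = yh
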