Skip to content
Snippets Groups Projects
Commit bc10ef08 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Changed model loading and added docstrings

parent a133f0c7
No related branches found
No related tags found
No related merge requests found
......@@ -39,31 +39,13 @@ class InferenceModel(object):
:param mod_fn: Path to tarred file of compiled model
:return:
"""
# TODO this is ugly and temporary and should be fixed soon
hardcoded_cnn_settings = {
'target': 'CTGCTCCC',
'filter_width': 1000,
'kernel_size': 19,
'max_sequence_length': 5000,
'batch_size': 32,
'threshold': 0.0,
'eps_per_kmer_switch': 25,
'filters': 5,
'learning_rate': 0.002,
'pool_size': 8,
'dropout_keep_prob': 0.9,
'num_layers': 2,
'batch_norm': 0
}
with TemporaryDirectory() as td:
with tarfile.open(mod_fn) as fh:
fh.extractall(td)
out_dict = {}
for mn in Path(td).iterdir():
k_mer = splitext(mn.name)[0]
# TODO this too
location = f'/home/noord087/lustre_link/mscthesis/baseless/baseless_2_on_16s/nns/{k_mer}/nn.h5'
out_dict[k_mer] = NeuralNetwork(weights=location, **hardcoded_cnn_settings)
for saved_model in Path(td).iterdir():
k_mer = saved_model.stem
out_dict[k_mer] = NeuralNetwork(target=k_mer, weights=saved_model)
# Dictionary with kmer string as key and keras.Sequential as value
self._model_dict = out_dict
# List of kmers that the InferenceModel contains models for
......
......@@ -64,6 +64,13 @@ class ReadTable(object):
return None, None, None
def load_read(self, fn):
"""Load read from fast5 file name. Already splits the read into
correct input length
:param fn: file name of fast5 file
:return: Split read of normalized read
:rtype: Read
"""
if fn in self.read_dict: return self.read_dict[fn]
with h5py.File(fn, 'r') as fh:
self.read_dict[fn] = Read(fh, 'median').get_split_raw_read(self.input_length)
......
......@@ -33,7 +33,15 @@ class NeuralNetwork(object):
"""
def __init__(self, **kwargs):
self.history = {'loss': [], 'binary_accuracy': [], 'precision': [],
'recall': [], 'val_loss': [], 'val_binary_accuracy': [],
'val_precision': [], 'val_recall': []}
# Must always provide target keyword
self.target = kwargs['target']
if kwargs['weights']:
self.initialize(kwargs['weights'])
self.filter_width = self.model.layers[0].input_shape[1]
return
self.filter_width = kwargs['filter_width']
self.hfw = (self.filter_width - 1) // 2 # half filter width
self.kernel_size = kwargs['kernel_size']
......@@ -48,11 +56,6 @@ class NeuralNetwork(object):
self.num_layers = kwargs['num_layers']
self.batch_norm = kwargs['batch_norm']
self.initialize(kwargs['weights'])
self.history = {'loss': [], 'binary_accuracy': [], 'precision': [],
'recall': [], 'val_loss': [], 'val_binary_accuracy': [],
'val_precision': [], 'val_recall': []}
def initialize(self, weights=None):
"""Initialize the network.
......@@ -63,7 +66,8 @@ class NeuralNetwork(object):
self.model = tf.keras.models.load_model(weights, custom_objects={
'precision': precision, 'recall': recall,
'binary_accuracy': binary_accuracy})
print('Successfully loaded weights')
print(f'Successfully loaded weights for {self.target}')
return
# First layer
......@@ -128,7 +132,15 @@ class NeuralNetwork(object):
self.history[metric].extend(self.model.history.history[metric])
def predict(self, x, clean_signal=True, return_probs=False):
"""Given sequences input as x, predict if they contain target k-mer
"""Given sequences input as x, predict if they contain target k-mer.
Assumes that read is already split up into smaller batches.
:param x: Squiggle as numeric representation
:type x: np.ndarray
:param clean_signal:
:param return_probs:
:return: unnormalized predicted values
:rtype: np.array of posteriors
"""
x_pad = np.expand_dims(pad_sequences(x, maxlen=self.filter_width,
padding='post', truncating='post',
......@@ -136,14 +148,18 @@ class NeuralNetwork(object):
# Put predicted class = 1 where posterior is larger than threshold
posteriors = self.model.predict(x_pad)
y_hat = posteriors > self.threshold
offset = 5
ho = offset // 2
lb, rb = self.hfw - ho, self.hfw + ho + 1
y_out = np.zeros(len(x), dtype=int)
for i, yh in enumerate(y_hat):
y_out[lb + i * offset: rb + i * offset] = yh
return y_out
return posteriors
# y_hat = posteriors > self.threshold
# offset = 5
# ho = offset // 2
# lb, rb = self.hfw - ho, self.hfw + ho + 1
# y_out = np.zeros(len(x), dtype=int)
# for i, yh in enumerate(y_hat):
# y_out[lb + i * offset: rb + i * offset] = yh
# if return_probs:
# posteriors_out = np.zeros(len(x), dtype=float)
# for i, p in enumerate(posteriors):
# posteriors_out[lb + i * offset: rb + i * offset] = p
# return y_out, posteriors_out
# return y_out
......@@ -118,10 +118,10 @@ def train(parameter_file, training_data, test_data, plots_path=None,
if (i+1) == len(ts_npzs):
raise UserWarning('None of test npzs suitable for testing, no squiggle plot generated.')
# # Uncomment to print confusion matrix
# # Rows are true labels, columns are predicted labels
# prediction = nn.predict(x_val)
# print(tf.math.confusion_matrix(y_val, prediction))
# Uncomment to print confusion matrix
# Rows are true labels, columns are predicted labels
prediction = nn.predict(x_val)
print(tf.math.confusion_matrix(y_val, prediction))
return nn
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment