Skip to content
Snippets Groups Projects
Commit 9386da11 authored by Carlos de Lannoy's avatar Carlos de Lannoy
Browse files

less error-prone model loading

parent c3e2abd2
No related merge requests found
......@@ -43,7 +43,6 @@ class NeuralNetwork(object):
self.dropout_remove_prob = 1 - kwargs['dropout_keep_prob']
self.num_layers = kwargs['num_layers']
self.batch_norm = kwargs['batch_norm']
self.model = models.Sequential()
self.initialize(kwargs['weights'])
self.history = {'loss': [], 'binary_accuracy': [], 'precision': [],
......@@ -56,7 +55,14 @@ class NeuralNetwork(object):
:param weights: Path to .h5 model summary with weights, optional.
If provided, use this to set the model weights
"""
if weights:
self.model = tf.keras.models.load_model(weights, custom_objects={
'precision': precision, 'recall': recall, 'binary_accuracy': binary_accuracy})
print('Successfully loaded weights')
return
# First layer
self.model = models.Sequential()
self.model.add(layers.Conv1D(self.filters,
kernel_size=self.kernel_size,
activation='relu',
......@@ -77,9 +83,9 @@ class NeuralNetwork(object):
self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[binary_accuracy, precision, recall])
if weights:
self.model.load_weights(weights)
print('Successfully loaded weights')
# if weights:
# self.model.load_weights(weights)
# print('Successfully loaded weights')
# # Uncomment to print model summary
# self.model.summary()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment