Skip to content
Snippets Groups Projects
Commit 7e059cdf authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

CNN can now load weights from nn.h5 file

parent c02c5630
Branches
No related tags found
1 merge request!3Added data preparation, hyperparameter optimisation, benchmarking code and k-mer library visualisation
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras import models, layers
from tensorflow.keras.models import load_model
# from tensorflow.keras.metrics import Accuracy, Precision, Recall
import numpy as np
from helper_functions import clean_classifications
......@@ -15,9 +17,7 @@ class NeuralNetwork(object):
:type target: str
:param kernel_size: Kernel size of CNN
:type kernel_size: int
:param weights: Initial weights to use for the neural network
:param max_sequence_length: Length to which pad or truncate the sequences
:type max_sequence_length: int
:param weights: Path to h5 file that contains model weights, optional
:param batch_size: Batch size to use during training
:param threshold: Assign label to TRUE if probability above this threshold
when doing prediction
......@@ -45,17 +45,16 @@ class NeuralNetwork(object):
self.model = models.Sequential()
self.initialize(kwargs['weights'])
self.history = {'loss':[], 'binary_accuracy': [], 'precision': [],
self.history = {'loss': [], 'binary_accuracy': [], 'precision': [],
'recall': [], 'val_loss': [], 'val_binary_accuracy': [],
'val_precision': [], 'val_recall': []}
def initialize(self, weights):
"""
Initialize the network. If weights are provided, these should be loaded.
def initialize(self, weights=None):
"""Initialize the network.
:param weights: Path to .h5 model summary with weights, optional.
If provided, use this to set the model weights
"""
if weights:
raise NotImplementedError('Cannot provide weights now,'
' still needs to be implemented')
# First layer
self.model.add(layers.Conv1D(self.filters,
kernel_size=self.kernel_size,
......@@ -77,8 +76,12 @@ class NeuralNetwork(object):
self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[binary_accuracy, precision, recall])
# Uncomment to print model summary
print(self.model.summary())
if weights:
self.model.load_weights(weights)
print('Successfully loaded weights')
# # Uncomment to print model summary
# self.model.summary()
def train(self, x, y, x_val, y_val, quiet=False, eps_per_kmer_switch=100):
"""Train the network. x_val/y_val may be used for validation/early
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment