Skip to content
Snippets Groups Projects
Commit b8ada615 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Fixed bug where CNN couldn't handle the 'quiet' kwarg

parent b4da7db6
No related branches found
No related tags found
1 merge request: !1 "All MSc thesis work so far, mainly on CNNs, k-mer design and data preprocessing"
......@@ -6,9 +6,10 @@
# RNN ARCHITECTURE
nn_class: Cnn_test
batch_size: 32
eps_per_kmer_switch: 20
eps_per_kmer_switch: 21
max_sequence_length: 500 # Only for example reads
kernel_size: 10
kernel_size: 3
filters: 4
threshold: 0.95
num_batches: 320
learning_rate: 0.01
......
......@@ -33,10 +33,11 @@ class NeuralNetwork(object):
self.batch_size = kwargs['batch_size']
self.threshold = kwargs['threshold']
self.eps_per_kmer_switch = kwargs['eps_per_kmer_switch']
self.filters = kwargs['filters']
self.initialize(kwargs['weights'])
self.history = {'loss':[], 'binary_accuracy': [], 'precision': [],
'recall': [], 'val_loss': [], 'val_binary_accuracy':[],
'recall': [], 'val_loss': [], 'val_binary_accuracy': [],
'val_precision': [], 'val_recall': []}
def initialize(self, weights):
......@@ -45,7 +46,9 @@ class NeuralNetwork(object):
"""
self.model = models.Sequential()
self.model.add(layers.Conv1D(5, kernel_size=self.kernel_size, activation='relu',
self.model.add(layers.Conv1D(self.filters,
kernel_size=self.kernel_size,
activation='relu',
input_shape=(self.max_sequence_length, 1)))
# self.model.add(layers.MaxPool1D(2)) # Might use this later
self.model.add(layers.Flatten())
......@@ -53,8 +56,9 @@ class NeuralNetwork(object):
self.model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=['BinaryAccuracy', 'Precision', 'Recall'])
print(self.model.summary())
def train(self, x, y, x_val, y_val, eps_per_kmer_switch=100):
def train(self, x, y, x_val, y_val, quiet=False, eps_per_kmer_switch=100):
"""Train the network. x_val/y_val may be used for validation/early
stopping mechanisms.
......@@ -62,19 +66,23 @@ class NeuralNetwork(object):
:param y: Ground truth labels of the read
:param x_val: Input reads to use for validation
:param y_val: Ground truth labels to use for validation
:param quiet: If True, suppress the per-epoch training output printed to the console
"""
# Pad input sequences
x_pad = np.expand_dims(pad_sequences(x, maxlen=self.max_sequence_length,
padding='post', truncating='post'), -1)
x_val_pad = np.expand_dims(pad_sequences(x_val, maxlen=self.max_sequence_length,
x_val_pad = np.expand_dims(pad_sequences(x_val,
maxlen=self.max_sequence_length,
padding='post', truncating='post'), -1)
# Create tensorflow dataset
tfd = tf.data.Dataset.from_tensor_slices((x_pad, y)).batch(
self.batch_size).shuffle(x_pad.shape[0],
reshuffle_each_iteration=True)
# Train the model
self.model.fit(tfd, epochs=self.eps_per_kmer_switch,
validation_data=(x_val_pad, y_val))
validation_data=(x_val_pad, y_val),
verbose=[2, 0][quiet])
for hv in self.model.history.history:
self.history[hv].extend(self.model.history.history[hv])
......
......@@ -14,7 +14,8 @@ import reader
from helper_functions import load_db, parse_output_path, plot_timeseries
def train(parameter_file, training_data, test_data, plots_path=None, save_model=None, model_weights=None, quiet=False):
def train(parameter_file, training_data, test_data, plots_path=None,
save_model=None, model_weights=None, quiet=False):
timestamp = datetime.now().strftime('%y-%m-%d_%H:%M:%S')
# Load parameter file
if type(parameter_file) == str:
......@@ -45,9 +46,10 @@ def train(parameter_file, training_data, test_data, plots_path=None, save_model=
save_weights_only=True,
save_freq=params['batch_size'])
# create rnn
# create nn
nn_class = importlib.import_module(f'nns.{params["nn_class"]}').NeuralNetwork
rnn = nn_class(**params, target=train_db.target, weights=model_weights, cp_callback=cp_callback)
nn = nn_class(**params, target=train_db.target, weights=model_weights,
cp_callback=cp_callback)
# Start training
worst_kmers = []
......@@ -55,7 +57,8 @@ def train(parameter_file, training_data, test_data, plots_path=None, save_model=
for epoch_index in range(1, params['num_kmer_switches'] + 1):
x_train, y_train = train_db.get_training_set(nb_examples, [])
# x_train, y_train = train_db.get_training_set(nb_examples, worst_kmers) # todo worst_kmers mechanism errors out, fix
rnn.train(x_train, y_train, x_val, y_val, params['eps_per_kmer_switch'], quiet=quiet)
nn.train(x_train, y_train, x_val, y_val,
eps_per_kmer_switch=params['eps_per_kmer_switch'], quiet=quiet)
# Run on whole training reads for selection of top 5 wrongly classified k-mers
# random.shuffle(train_npzs)
......@@ -63,8 +66,8 @@ def train(parameter_file, training_data, test_data, plots_path=None, save_model=
# squiggle_count = 0
# for npz in train_npzs:
# x, sequence_length, kmers = reader.npz_to_tf_and_kmers(npz, params['max_sequence_length'], target_kmer=train_db.target) # todo here too, heckin buggy
# if x.shape[0] > rnn.filter_width and train_db.target in kmers:
# y_hat = rnn.predict(x, clean_signal=True)
# if x.shape[0] > nn.filter_width and train_db.target in kmers:
# y_hat = nn.predict(x, clean_signal=True)
# predicted_pos_list.append(kmers[y_hat])
# squiggle_count += 1
# if squiggle_count == params['batch_size']:
......@@ -87,14 +90,14 @@ def train(parameter_file, training_data, test_data, plots_path=None, save_model=
for i, npz in enumerate(ts_npzs):
x, sequence_length, kmers = reader.npz_to_tf_and_kmers(npz, target_kmer=train_db.target)
# x, sequence_length, kmers = reader.npz_to_tf_and_kmers(npz, params['max_sequence_length'], )
if x.shape[0] > rnn.filter_width and (train_db.target in kmers or i+1 == len(ts_npzs)):
if x.shape[0] > nn.filter_width and (train_db.target in kmers or i+1 == len(ts_npzs)):
tr_fn = splitext(basename(npz))[0]
start_idx = np.argwhere(train_db.target == kmers).min()
oh = params['max_sequence_length'] // 2
sidx = np.arange(max(0,start_idx-oh),min(start_idx+oh, len(x)))
x = x[sidx, :]
kmers = kmers[sidx]
y_hat, posterior = rnn.predict(x, clean_signal=True, return_probs=True)
y_hat, posterior = nn.predict(x, clean_signal=True, return_probs=True)
ts_predict_name = ("{path}pred_ep{ep}_ex{ex}.npz".format(path=ts_predict_path,
ep=epoch_index,
ex=tr_fn))
......@@ -114,7 +117,7 @@ def train(parameter_file, training_data, test_data, plots_path=None, save_model=
break
if (i+1) == len(ts_npzs):
raise UserWarning('None of test npzs suitable for testing, no squiggle plot generated.')
return rnn
return nn
def main(args):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment