From 533b6e49cadf8d311791808e7de01a2b194a62c1 Mon Sep 17 00:00:00 2001
From: Carlos de Lannoy <cvdelannoy@gmail.com>
Date: Tue, 14 Jun 2022 02:29:18 +0200
Subject: [PATCH] Fix layer indexing and weight transfer in compile_model

- take candidate k-mers 5:30 of the frequency-sorted list instead of the
  previously commented-out subset (effect of subsetting still unclear)
- use grouped convolutions (groups=nb_mods) for all but the first Conv1D
  layer so each model keeps its own filters
- flatten each model's feature block separately before concatenation
- record the actual Keras layer index of each trained layer and use it when
  copying weights into the meta model, since the flatten replacement adds
  extra layers
---
 inference/compile_model.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)
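
A minimal, self-contained sketch (not code from this repository; layer sizes
and names below are made up) of the two Keras mechanics the compile_model()
changes rely on: a grouped Conv1D keeps the concatenated per-kmer sub-networks
independent, and the Flatten replacement expands into several graph layers,
which is why the explicit layer_index bookkeeping is needed when copying
weights.

    import tensorflow as tf

    nb_mods = 2                                  # hypothetical number of per-kmer models
    inp = tf.keras.Input((100, nb_mods * 4))     # 4 input channels per model (made up)

    # groups=nb_mods convolves each model's block of channels only with its
    # own block of filters, keeping the sub-networks independent.
    x = tf.keras.layers.Conv1D(8 * nb_mods, 5, groups=nb_mods)(inp)

    # The Flatten replacement becomes a split op, nb_mods Flatten layers and a
    # concat op in the functional graph, so enumerate() over the source model's
    # layers no longer lines up with the meta model's layer indices.
    x = tf.concat([tf.keras.layers.Flatten()(s) for s in tf.split(x, nb_mods, axis=2)], -1)
    out = tf.keras.layers.Dense(nb_mods)(x)

    meta = tf.keras.Model(inp, out)
    print(len(meta.layers))   # more than the 4 layers of a plain Conv/Flatten/Dense stack
    print(meta.output_shape)  # (None, 2)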

diff --git a/inference/compile_model.py b/inference/compile_model.py
index 31b81e0..e4f0012 100644
--- a/inference/compile_model.py
+++ b/inference/compile_model.py
@@ -66,7 +66,7 @@ def get_kmer_candidates_16S(kmer_candidates_dict, min_nb_kmers=5, threshold=0.00
             prefiltered_kmer_list = [km for km in kmer_candidates_dict[kc] if km in filter_list]
         kmer_freqs_dict_cur = {km: kmer_freqs_dict.get(km, 0) for km in prefiltered_kmer_list}
         kmer_list = sorted(prefiltered_kmer_list, key=kmer_freqs_dict_cur.get, reverse=False)
-        # kmer_list = kmer_list[:20]  # todo arbitrarily set to 25, change to parameter
+        kmer_list = kmer_list[5:30]  # todo: subsetting does not seem to have a positive effect; make bounds a parameter
         kmer_candidates_selection_dict[kc] = kmer_list
         # kmer_candidates_selection_dict[kc] = []
         # sub_df = kmer_freqs_df.copy()
@@ -161,24 +161,27 @@ def compile_model(kmer_dict, filter_width, threshold, batch_size, parallel_model
     trained_layers_dict = {}
     mod_first = tf.keras.models.load_model(f'{list(kmer_dict.values())[0]}/nn.h5', compile=False)
     x = input
+    layer_index = 1
     for il, l in enumerate(mod_first.layers):
         if type(l) == tf.keras.layers.Dense:
             nl = tf.keras.layers.Dense(l.weights[0].shape[1] * nb_mods, activation=l.activation)
-            trained_layers_dict[il] = {'ltype': 'dense', 'weights': []}
+            trained_layers_dict[il] = {'ltype': 'dense', 'layer_index': layer_index, 'weights': []}
         elif type(l) == tf.keras.layers.Conv1D:
-            nl = tf.keras.layers.Conv1D(l.filters * nb_mods, l.kernel_size, activation=l.activation)
-            trained_layers_dict[il] = {'ltype': 'conv1d', 'weights': []}
+            nl = tf.keras.layers.Conv1D(l.filters * nb_mods, l.kernel_size, activation=l.activation, groups=nb_mods if il > 0 else 1)
+            trained_layers_dict[il] = {'ltype': 'conv1d', 'layer_index': layer_index, 'weights': []}
         elif type(l) == tf.keras.layers.BatchNormalization:
             nl = tf.keras.layers.BatchNormalization()
-            trained_layers_dict[il] = {'ltype': 'batchnormalization', 'weights': []}
+            trained_layers_dict[il] = {'ltype': 'batchnormalization', 'layer_index': layer_index, 'weights': []}
         elif type(l) == tf.keras.layers.Dropout:
             nl = tf.keras.layers.Dropout(l.rate)
         elif type(l) == tf.keras.layers.MaxPool1D:
             nl = tf.keras.layers.MaxPool1D(l.pool_size)
         elif type(l) == tf.keras.layers.Flatten:
-            nl = tf.keras.layers.Flatten()
+            nl = lambda t: tf.concat([tf.keras.layers.Flatten()(ts) for ts in tf.split(t, nb_mods, axis=2)], -1)
+            layer_index += nb_mods + 1
         else:
             raise ValueError(f'models with layer type {type(l)} cannot be concatenated yet')
+        layer_index += 1
         x = nl(x)
     output = K.cast_to_floatx(K.greater(x, threshold))
     meta_mod = tf.keras.Model(inputs=input, outputs=output)
@@ -195,12 +198,13 @@ def compile_model(kmer_dict, filter_width, threshold, batch_size, parallel_model
         if trained_layers_dict[il]['ltype'] == 'conv1d':
             weight_list = concat_conv1d_weights(meta_mod.layers[il+1].weights, trained_layers_dict[il]['weights'])
         elif trained_layers_dict[il]['ltype'] == 'dense':
-            weight_list = concat_dense_weights(meta_mod.layers[il+1].weights, trained_layers_dict[il]['weights'])
+            weight_list = concat_dense_weights(meta_mod.layers[trained_layers_dict[il]['layer_index']].weights,
+                                               trained_layers_dict[il]['weights'])
         elif trained_layers_dict[il]['ltype'] == 'batchnormalization':
             weight_list = concat_weights(trained_layers_dict[il]['weights'])
         else:
             raise ValueError(f'Weight filling not implemented for layer type {trained_layers_dict[il]["ltype"]}')
-        meta_mod.layers[il + 1].set_weights(weight_list)
+        meta_mod.layers[trained_layers_dict[il]['layer_index']].set_weights(weight_list)
     meta_mod.compile()
     return meta_mod
 
-- 
GitLab