diff --git a/CNN_PPI.py b/CNN_PPI.py
index 1d28eb5dca647fe91429d8aa15cb632afa9b8a9d..8848888dac964d6f89bc63d6e06d54fce979b2e7 100644
--- a/CNN_PPI.py
+++ b/CNN_PPI.py
@@ -318,15 +318,15 @@ class PPI_CNN(object):
         padding = "SAME"
         # First sequence in the pair
         with tf.variable_scope("FirstSeq"):
-            first_seq_conv = tf.nn.conv1d(sample_seq1, filters=self.filter, stride=1, padding=padding,
+            self.feature_map_1 = tf.nn.conv1d(sample_seq1, filters=self.filter, stride=1, padding=padding,
                                           name="conv", use_cudnn_on_gpu=True)
+            first_pool = tf.reduce_max(self.feature_map_1, axis=1)
-            first_pool = tf.reduce_max(first_seq_conv, axis=1)
         # Second sequence in the pair
         with tf.variable_scope("SecondSeq"):
-            second_seq_conv = tf.nn.conv1d(sample_seq2, filters=self.filter, stride=1, padding=padding,
+            self.feature_map_2 = tf.nn.conv1d(sample_seq2, filters=self.filter, stride=1, padding=padding,
                                            name="conv", use_cudnn_on_gpu=True)
-            second_pool = tf.reduce_max(second_seq_conv, axis=1)
+            second_pool = tf.reduce_max(self.feature_map_2, axis=1)

         combined_pool = tf.concat([first_pool, second_pool], axis=-1, name="combine_pools")
@@ -371,7 +371,11 @@ class PPI_CNN(object):
             print("No valid loss function chosen. Select 'cross-entropy' or 'l2'.")
             exit()

-        self.cost = self.cost + 0.05 * tf.cast(self.sparsity_term, dtype=tf.float64)
+
+        sparsity_coeff = 0.04  #* tf.nn.sigmoid(((20.0*tf.cast(self.iteration,
+                               #  dtype=tf.float64))/self.max_iterations) - 10)
+
+        self.cost = self.cost + sparsity_coeff * tf.cast(self.sparsity_term, dtype=tf.float64)

         tf.summary.scalar("summaries/Cost", self.cost)

         #initial_lr = tf.train.exponential_decay(initial_lr, global_step=self.iteration, decay_steps=1,
@@ -395,7 +399,8 @@ class PPI_CNN(object):
         """
         with tf.variable_scope("Prediction"):
             n = int(0.5*self.n_fil)
-            and_kern_init = tf.exp(tf.ones([n, 1], dtype=tf.float32) * -1.0)
+            #and_kern_init = tf.exp(tf.ones([n, 1], dtype=tf.float32) * -1.0)
+            and_kern_init = tf.sigmoid(tf.ones([n, 1], dtype=tf.float32) * -10.0) * 30.0

             and_sum = tf.layers.dense(and_gates, 1, use_bias=False, activation=None, trainable=True,
                                       kernel_initializer=lambda x, dtype, partition_info: and_kern_init)
@@ -426,13 +431,12 @@ class PPI_CNN(object):
                                                        threads=self.data_threads, shuffle=not evaluation,
                                                        n_prediction_nodes=self.n_prediction_nodes)
         # The batches for the sequences in a pair
-        seq1 = tf.placeholder_with_default(sample_seq1, [self.batch_size, None, 20], "seq1")
-        seq2 = tf.placeholder_with_default(sample_seq2, [self.batch_size, None, 20], "seq2")
+        seq1 = tf.placeholder_with_default(sample_seq1, [None, None, 20], "seq1")
+        seq2 = tf.placeholder_with_default(sample_seq2, [None, None, 20], "seq2")

         # Output of the convolutional layer containing 2*n_fil values per pair
         combined_pool = self._create_convolution_layer(seq1, seq2, dropout_rate=dropout)
-
         # The maximum value resulting from convolution with a filter (on a one-hot encoded sequence) is the sum of
         # the maximum values on the columns.
         self.maximum_filter_activation = tf.reduce_sum(tf.reduce_max(self.filter, axis=1), axis=0,
@@ -440,15 +444,28 @@ class PPI_CNN(object):
         # For each filter there is a trainable variable expressing the rectification threshold as a fraction of the
         # maximum possible output of that filter.
-        fraction = tf.nn.sigmoid(tf.Variable(tf.ones([1, self.n_fil]), dtype=tf.float32) * -2.0)
+        fraction = tf.nn.sigmoid(tf.Variable(tf.ones([1, self.n_fil]) * -5.0, dtype=tf.float32))

         # Give the network some leeway
         rect_bias = fraction * self.maximum_filter_activation
-        #rect_bias = tf.Variable(tf.zeros([1, self.n_fil]), dtype=tf.float32)
-        ##rect_bias = tf.Variable(tf.zeros([1, self.n_fil]), dtype=tf.float32)
         cat_bias = tf.concat([rect_bias, rect_bias], axis=-1)

-        # Rectify the filter output
+        # This is where we get information on where, in the sequence, the convolution operation finds matches with
+        # the filter that are above the learnt threshold. Note that this is not used for training, but only for
+        # evaluation purposes (i.e. for interpretation purposes after the model is trained).
+        rect_feature_map_1 = self.feature_map_1 - rect_bias
+        rect_feature_map_1 = tf.where(rect_feature_map_1 > 0.0, rect_feature_map_1 + rect_bias, tf.zeros_like(rect_feature_map_1))
+
+        self.nonzero_idx_fm1 = tf.where(rect_feature_map_1 > 0.0)
+        self.nonzero_vals_fm1 = tf.gather_nd(self.feature_map_1, self.nonzero_idx_fm1)
+        rect_feature_map_2 = self.feature_map_2 - rect_bias
+        rect_feature_map_2 = tf.where(rect_feature_map_2 > 0.0, rect_feature_map_2 + rect_bias, tf.zeros_like(rect_feature_map_2))
+        self.nonzero_idx_fm2 = tf.where(rect_feature_map_2 > 0.0)
+        self.nonzero_vals_fm2 = tf.gather_nd(self.feature_map_2, self.nonzero_idx_fm2)
+
+
+        # Back to the operations relevant for training
+        # Rectify the filter output
         combined_pool = tf.nn.relu(combined_pool - cat_bias)
         # Pretend rectification did not happen for non-zero outputs, but ignore the gradients for this operation.
         # Otherwise the gradient of the rectification bias with respect to the cost function would always be 0.
@@ -469,8 +486,11 @@ class PPI_CNN(object):
         #TODO: set up experiment to compare sparsity vs attention
         self.sparsity_term = tf.reduce_mean(tf.reduce_sum(corrected_and_gate, axis=-1) - tf.reduce_max(
             corrected_and_gate, axis=-1))
+        #self.sparsity_term = tf.reduce_mean(tf.reduce_sum(corrected_and_gate, axis=-1))
+
+
-        #tf.contrib.layers.l1_regularizer(scale=0.1)
+        #tf.contrib.layers.l1_regularizer(scale=0.1)

         self.corrected_and_gate = corrected_and_gate
@@ -569,7 +589,7 @@ class PPI_CNN(object):
             tf.reset_default_graph()
             restore_sess.close()

-            # Restored model does not need initial 'help' with learning of the filters
+            # Restored model does not need initial help with learning of the filters
             print("Sequential filter start and edge silencing are turned off because a previous model is restored!")
             self.false_start_rate = 0.0
             self.silence_edges = False
@@ -650,9 +670,9 @@ class PPI_CNN(object):
             os.mkdir(save_filters)

         # Save weights
-        self.evaluated_weights = self.sess.run(tf.trainable_variables())
+        trainable_weights = self.sess.run(tf.trainable_variables())

-        for n, weight in enumerate(self.evaluated_weights):
+        for n, weight in enumerate(trainable_weights):
             np.save("%s/weight_%d" % (save_to, n), weight)

         self.evaluated_filters = self.sess.run(self.filter)
@@ -676,7 +696,7 @@ class PPI_CNN(object):

         return(roc, accuracy, precision, specificity)

-    def evaluate(self, evaluation_pair_list, evaluation_label_list, apply_sigmoid=True, batch_size=50):
+    def evaluate(self, evaluation_pair_list, evaluation_label_list, batch_size=50):
         """
         Performs a forward pass through the trained network with a set of evaluation sequences. It will return the
         average difference between the predicted and actual labels for those samples.
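Editor's aside (not part of the patch): the thresholding added above keeps the raw similarity score wherever a filter clears its learnt rectification threshold and zeroes everything else; the indices from tf.where and the values gathered with tf.gather_nd are what evaluate() later assembles into position_info. Below is a minimal NumPy sketch of that operation; the array values are made up and only the mechanics mirror the patch.

    import numpy as np

    def threshold_feature_map(feature_map, rect_bias):
        """Zero activations below the per-filter threshold, but report the original
        (un-shifted) activation wherever the threshold is exceeded."""
        rectified = feature_map - rect_bias                       # subtract the learnt threshold
        return np.where(rectified > 0.0, rectified + rect_bias,   # add the threshold back in
                        np.zeros_like(feature_map))

    feature_map = np.array([[[0.2, 1.5], [0.9, 0.1], [1.1, 2.0]]])  # [sample, position, filter]
    rect_bias = np.array([[1.0, 1.0]])                               # one threshold per filter

    kept = threshold_feature_map(feature_map, rect_bias)
    nonzero_idx = np.argwhere(kept > 0.0)              # (sample, position, filter) triples, cf. tf.where
    nonzero_vals = feature_map[tuple(nonzero_idx.T)]   # raw similarity scores, cf. tf.gather_nd
    print(nonzero_idx)   # [[0 0 1] [0 2 0] [0 2 1]]
    print(nonzero_vals)  # [1.5 1.1 2.0]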
@@ -712,27 +732,48 @@ class PPI_CNN(object):
         predictions = []
         activities = []

+        # Here we store the information on where, in each sequence, the filters find matches.
+        position_info = []
+
         for batches in range(int(len(evaluation_pair_list[0])/batch_size) + 1):
             print("Batch ", batches)
-
-            print(sample_seq1)
             seq1, seq2, labs = sample_sess.run([sample_seq1, sample_seq2, batch_labels])

             prediction = sess.graph.get_tensor_by_name("Prediction/Prediction:0")
             pair_activity = sess.graph.get_tensor_by_name("AndGates:0")

             # Should a sigmoid activation function be applied to the prediction
-            if apply_sigmoid == True:
-                activity, pred = sess.run([pair_activity, tf.nn.sigmoid(prediction)],
-                                          feed_dict={"seq1:0": seq1,
-                                                     "seq2:0": seq2})
-                print(pred)
-            else:
-                pred = sess.run(prediction, feed_dict={"seq1:0": seq1, "seq2:0": seq2})
+            activity, pred, fm1_idx, fm1_vals, fm2_idx, fm2_vals = sess.run([pair_activity, tf.nn.sigmoid(
+                prediction), self.nonzero_idx_fm1, self.nonzero_vals_fm1, self.nonzero_idx_fm2,
+                self.nonzero_vals_fm2],
+                feed_dict={"seq1:0": seq1,
+                           "seq2:0": seq2})
+
+            # Set the sample number correctly
+            fm1_idx[:, 0] += batches * batch_size
+            fm2_idx[:, 0] += batches * batch_size
+
+            # First sequence in the pair
+            fm1_vals = np.reshape(fm1_vals, [-1, 1])
+            fm1_position_info = np.hstack([fm1_idx, fm1_vals])
+            # Add binary value indicating which sequence in the pair (=0)
+            fm1_position_info = np.hstack([np.zeros([fm1_vals.shape[0], 1]), fm1_position_info])
+
+            # Second sequence in the pair
+            fm2_vals = np.reshape(fm2_vals, [-1, 1])
+            fm2_position_info = np.hstack([fm2_idx, fm2_vals])
+            # Add binary value indicating which sequence in the pair (=1)
+            fm2_position_info = np.hstack([np.ones([fm2_vals.shape[0], 1]), fm2_position_info])
+
+            sample_position_info = np.vstack((fm1_position_info, fm2_position_info))
+
+            position_info.append(sample_position_info)

             activities.append(activity)
             predictions.append(pred)

+        position_info = np.vstack(position_info)
         predictions = np.array(predictions)
+
         # Output values of each filter pair for each sample
         activities = np.array(activities)
@@ -745,7 +786,7 @@ class PPI_CNN(object):
         roc, accuracy, precision, specificity = self._evaluation_metrics(real, predictions)
         area = auc(roc[0], roc[1])

-        return(roc, area, activities, predictions, accuracy, precision, specificity)
+        return(roc, area, activities, predictions, accuracy, precision, specificity, position_info)

     def filters_to_seqlogo(self, save_dir):
diff --git a/data_processing.py b/data_processing.py
index 0a0085f893a5394346c95b1c6884389ed29ede5e..c42966e12b515951849fccc241c9e421799b5207 100644
--- a/data_processing.py
+++ b/data_processing.py
@@ -49,20 +49,26 @@ def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real
     plt.cla()
     plt.close()

+def sigmoid(x):
+    return (1.0/(1.0 + np.exp(-x)))
+
 def plot_evaluation_activity(model_dir, pair_activation_matrix, predictions, real_labels, only_positive=False):
     np.save("%sevaluation_activity.npy" % model_dir, pair_activation_matrix)
     np.save("%sevaluation_predictions.npy" % model_dir, predictions)
-    pair_activation_matrix = np.load("%sevaluation_activity.npy" % model_dir)
+
+    activation_matrix = np.load("%sevaluation_activity.npy" % model_dir)
     predictions = np.load("%sevaluation_predictions.npy" % model_dir)
     pair_weights = np.load("%sweight_2.npy" % model_dir)

     if only_positive:
-        pair_weights = np.exp(pair_weights)
+        pair_weights = sigmoid(pair_weights) * 30.0
+
+    prediction_bias = np.exp(np.load("%sweight_3.npy" % model_dir))
-    pair_activation_matrix = np.transpose(pair_activation_matrix)
-    net_activity = np.multiply(pair_activation_matrix, pair_weights)
+    activation_matrix = np.transpose(activation_matrix)
+    net_activity = np.multiply(activation_matrix, pair_weights)

     relative_activity = net_activity / prediction_bias
@@ -97,6 +103,7 @@ def sequence_processing(batch_size, pair_list, labels, shuffle=True, threads=15,
     with tf.variable_scope("DataProcessing"):
         pl1 = tf.constant(pair_list[0])
         pl2 = tf.constant(pair_list[1])
+        labels = np.array(labels)
         labs = tf.constant(labels, dtype=tf.float64)
@@ -113,6 +120,7 @@ def sequence_processing(batch_size, pair_list, labels, shuffle=True, threads=15,
         combined_ds = combined_ds.repeat()  # Make the dataset repeat so we don't run out of sequences
         if shuffle:  # Should be turned off when evaluating since we'll be comparing it to an ordered list of labels
+            print("Shuffling samples. If you see this message during evaluation then something went wrong.")
             combined_ds = combined_ds.shuffle(batch_size)

         # Decode the TFRecord formatted sequences
diff --git a/interpretation.py b/interpretation.py
index b0ef405acdd67febc92e890c4886e57022de0e2d..b3d8b66d06de67f2e7b25b6873953e3681ac4259 100644
--- a/interpretation.py
+++ b/interpretation.py
@@ -3,6 +3,7 @@ import pandas as pd

 from bokeh.io import show, output_file
 from bokeh.layouts import gridplot
+from bokeh import events
 from bokeh.models import (
     ColumnDataSource,
     HoverTool,
@@ -12,6 +13,7 @@ from bokeh.models import (
     ColorBar,
     TapTool,
     OpenURL,
+    CustomJS
 )
 from bokeh.plotting import figure
@@ -50,31 +52,30 @@ def get_custom_hover_tool():
     #<span style="font-size: 10px; color: #696;">($x, $y)</span>
     return (hover)

-def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real_labels, position_activity):
+def _create_prediction_heatmap(filter_activation_matrix, predictions, real_labels):
     predictions = ['+' if np.round(x) == 1 else '-' for x in predictions]
-    #predictions = [predictions[x] + '\n+' if np.round(real_labels[x]) == 1 else predictions[x] + '\n-'for x in
+    # predictions = [predictions[x] + '\n+' if np.round(real_labels[x]) == 1 else predictions[x] + '\n-'for x in
     #                range(len(predictions))]
-
     n_filter_pairs = filter_activation_matrix.shape[0]
-
-    filter_activation_matrix = np.transpose(filter_activation_matrix, [1,0])
+    filter_activation_matrix = np.transpose(filter_activation_matrix, [1, 0])

     colors = ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"]
-    mapper = LinearColorMapper(palette=colors, low=np.min(filter_activation_matrix), high=np.max(filter_activation_matrix))
+    mapper = LinearColorMapper(palette=colors, low=np.min(filter_activation_matrix),
+                               high=np.max(filter_activation_matrix))

-    y_range = ["Filter %d" %n for n in range(n_filter_pairs)]
-    x_range = [str(n) for n in range(len(predictions))]
+    y_range = [str(n) for n in range(n_filter_pairs)]
+    x_range = [str(n) for n in range(len(predictions))]

-    x = ["%d" %n for n in range(len(predictions)) for i in range(n_filter_pairs)]
+    x = ["%d" % n for n in range(len(predictions)) for i in range(n_filter_pairs)]
     y = y_range * len(predictions)

-    values = filter_activation_matrix.flatten()
-    img1 = ['filter_%d.png'%(f*2) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
-    img2 = ['filter_%d.png'%(f*2+1) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
-    source = ColumnDataSource({'x':x, 'y':y, 'value':values, 'img1': img1, 'img2': img2})
+    values = filter_activation_matrix.ravel()
+    img1 = ['filter_%d.png' % (f * 2) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
+    img2 = ['filter_%d.png' % (f * 2 + 1) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
+    source = ColumnDataSource({'x': x, 'y': y, 'value': values, 'img1': img1, 'img2': img2})

     hover = get_custom_hover_tool()
-    TOOLS = [hover, "tap"]#, "save,pan,box_zoom,reset,wheel_zoom"]
+    TOOLS = [hover, "tap"]  # , "save,pan,box_zoom,reset,wheel_zoom"]

     p = figure(title="Prediction Motif Interpretation",
                x_range=x_range, y_range=list(reversed(y_range)),
@@ -85,12 +86,8 @@ def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real
     p.axis.major_label_standoff = 0
     p.toolbar.logo = None

-    p.rect(source=source, x='x', y='y', width=1, height=1, fill_color={'field': 'value', 'transform': mapper})
-
-
-    url = "http://www.colors.commutercreative.com/@color/"
-    taptool = p.select(type=TapTool)
-    taptool.callback = OpenURL(url=url)
+    p.rect(source=source, x='x', y='y', width=1, height=1, fill_color={'field': 'value', 'transform': mapper},
+           name="rectangle")

     color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="5pt",
                          ticker=BasicTicker(desired_num_ticks=len(colors)),
@@ -98,34 +95,153 @@ def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real
                          label_standoff=6, border_line_color=None, location=(0, 0))
     p.add_layout(color_bar, 'right')

-    pos_x_range = [str(x) for x in range(position_activity.shape[2])]
-    pos_x = ["%d" %n for n in range(len(predictions)) for i in range(n_filter_pairs)]
-    pos_y = y_range * len(predictions)
+    return(source, p)

-    pos_values = filter_activation_matrix.flatten()
-    source = ColumnDataSource({'x':x, 'y':y, 'value':values, 'img1': img1, 'img2': img2})
+def _get_JS_callback():
+    code = """
+        var data = source.data,
+            selected = source.selected['1d']['indices'];
+        var complete_data = complete_source.data;
+        var sample_data_1 = sample_source_1.data;
+        var sample_data_2 = sample_source_2.data;
+
+        var selected_sample;
+
+        if(selected.length == 1){
+            // only consider case where one glyph is selected by user
+            selected_filter = data['y'][selected[0]];
+            selected_sample = data['x'][selected[0]];
+            first_filter = selected_filter * 2;
+            second_filter = first_filter + 1;
+
+            var pos_arr = [[],[]];
+            var filter_arr = [[],[]];
+            var similarity_arr = [[],[]];
+            var seq_arr = [[],[]];
+            var sample_arr = [[],[]];
+            var index_arr = [[],[]];
+
+            for (var i = 0; i < complete_data['sample'].length; ++i){
+                if(complete_data['sample'][i] == selected_sample && ((complete_data['filter'][i] == first_filter) ||
+                    (complete_data['filter'][i] == second_filter))){
+
+                    var seq = complete_data['filter'][i] - first_filter;
+
+                    pos_arr[seq].push(complete_data['position'][i]);
+                    filter_arr[seq].push(complete_data['filter'][i]);
+                    similarity_arr[seq].push(complete_data['similarity'][i]);
+                    seq_arr[seq].push(complete_data['seq'][i]);
+                    sample_arr[seq].push(complete_data['sample'][i]);
+                    index_arr[seq].push(complete_data['index'][i]);
+                }
+            }
+
+        }
+
+        sample_data_1['position'] = pos_arr[0];
+        sample_data_1['index'] = index_arr[0];
+        sample_data_1['filter'] = filter_arr[0];
+        sample_data_1['similarity'] = similarity_arr[0];
+        sample_data_1['seq'] = seq_arr[0];
+        sample_data_1['sample'] = sample_arr[0];
+
+        sample_data_2['position'] = pos_arr[1];
+        sample_data_2['index'] = index_arr[1];
+        sample_data_2['filter'] = filter_arr[1];
+        sample_data_2['similarity'] = similarity_arr[1];
+        sample_data_2['seq'] = seq_arr[1];
+        sample_data_2['sample'] = sample_arr[1];
+
+
+        source.selected['1d']['indices'] = [];
+        source.change.emit();
+        complete_source.change.emit();
+        sample_source_1.change.emit();
+        sample_source_2.change.emit();
+    """
+    return code
+
+def _create_motif_map(complete_df, source):
+    colors = ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"]
+    sample_df = complete_df[complete_df['sample'] == 41.0]
+    sample_df_1 = sample_df[sample_df['filter'] == 14.0]
+    sample_df_2 = sample_df[sample_df['filter'] == 15.0]
+
+    # CDS for all samples
+    complete_source = ColumnDataSource(complete_df)
+
+    # CDS for the specific sample to be shown in the motif map
+    sample_source_1 = ColumnDataSource(sample_df_1)
+    sample_source_2 = ColumnDataSource(sample_df_2)
+
+    # Javascript callback code to deal with rectangle selection in the heatmap
+    code = _get_JS_callback()
+
+    callback = CustomJS(args={'source': source, 'sample_source_1': sample_source_1, 'sample_source_2': sample_source_2,
+                              'complete_source': complete_source},
+                        code=code)
+    source.callback = callback
+
+    # Actual plotting
     pos_fig = figure(title="Sequence position activation",
-                     x_range=pos_x_range,
-                     x_axis_location="above", plot_width=900, plot_height=400,
-                     tools=['pan', 'zoom_in'], toolbar_location=None)
+                     x_axis_location="below", plot_width=900, plot_height=400,
+                     tools=['pan', 'zoom_in', 'zoom_out'], toolbar_location=None)
-    pos_fig.toolbar.logo = None
-    current_sample = 0
-    current_filter = 0
-    x_dat = [n for n in range(position_activity.shape[2])]
-    print(position_activity.shape)
-    pos_fig.line(x_dat, position_activity[0,current_sample,:, current_filter])
+    col_mapper = LinearColorMapper(palette=colors, low=0.0, high=1.0)
+    # Plot with glyphs where the y-axis indicates which of the two sequences in the sample (sequence pair)
+    # The x-axis indicates the position in the sequence
+    # The orientation of the triangle indicates which of the two filters in a filter pair
+    # The colour of the triangles indicates the similarity score
-    grid = gridplot([[p],[pos_fig]], toolbar_options={'logo': None})
+    pos_fig.inverted_triangle(x="position", y="seq", legend='filter', fill_alpha='similarity', source=sample_source_1, size=20,
+                              color={'field': 'similarity', 'transform': col_mapper})
+    pos_fig.triangle(x="position", y="seq", legend='filter', fill_alpha='similarity', source=sample_source_2, size=20,
+                     color={'field': 'similarity', 'transform': col_mapper})
+    color_bar = ColorBar(color_mapper=col_mapper, major_label_text_font_size="5pt",
+                         ticker=BasicTicker(desired_num_ticks=len(colors)),
+                         formatter=PrintfTickFormatter(format="%.2f"),
+                         label_standoff=6, border_line_color=None, location=(0, 0))
+    pos_fig.add_layout(color_bar, 'right')
+
+    pos_fig.legend.location = 'top_left'
+    pos_fig.legend.click_policy = 'hide'
+
+
+
+
+    # Axis settings
+    pos_fig.yaxis.axis_label = "Sequence in pair"
+    pos_fig.xaxis.axis_label = "Position"
+
+    pos_fig.yaxis[0].ticker.desired_num_ticks = 2
+    pos_fig.y_range.start = -0.5
+    pos_fig.y_range.end = 1.5
+    pos_fig.x_range.start = 0.0
+    pos_fig.x_range.end = 500.0
+
+    return pos_fig
+
+def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real_labels, position_activity):
+
+    source, p = _create_prediction_heatmap(filter_activation_matrix, predictions, real_labels)
+
+    complete_df = pd.DataFrame(position_activity, columns=["seq", "sample", "position", "filter", "similarity"])
+
+    pos_fig = _create_motif_map(complete_df, source)
+
+    grid = gridplot([[p],[pos_fig]], toolbar_options={'logo': None})

     output_file("%s/interpretation.html"%save_dir)
     show(grid)  # show the plot

+def sigmoid(x):
+    return(1.0/(1.0+np.exp(-x)))
+
 def create_interactive_visualization(model_dir, pair_activation_matrix=None, predictions=None, real_labels=None,
                                      only_positive=False, position_activity=None):
@@ -135,7 +251,7 @@ def create_interactive_visualization(model_dir, pair_activation_matrix=None, pre
     pair_weights = np.load("%sweight_2.npy" % model_dir)

     if only_positive:
-        pair_weights = np.exp(pair_weights)
+        pair_weights = sigmoid(pair_weights) * 20.0

     prediction_bias = np.exp(np.load("%sweight_3.npy" % model_dir))
     pair_activation_matrix = np.transpose(pair_activation_matrix)
@@ -145,11 +261,9 @@ def create_interactive_visualization(model_dir, pair_activation_matrix=None, pre
     plot_filter_activation(relative_activity, predictions, model_dir, real_labels,
                            position_activity=position_activity)

-
 def main():
-    create_interactive_visualization(model_dir="Results/Synthetic/Synthetic_14_length500_frac0.50"
-                                               "/SparsityTerm_InteractiveTest_OnlyPos_36Filters_uniform_True_normalized/0/",
-                                     real_labels="a", only_positive=False)
+    create_interactive_visualization(model_dir="Results/Biological/36Filters_uniform_SparsityTrue_normalized/0/",
+                                     real_labels="a", only_positive=True)

 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/sequence_position_analyzer.py b/sequence_position_analyzer.py
deleted file mode 100644
index cc7e1257f8801aed3f54ca8cd8245bc2895e12e4..0000000000000000000000000000000000000000
--- a/sequence_position_analyzer.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import tensorflow as tf
-import numpy as np
-from data_processing import sequence_processing
-
-allowed_characters = ['H', 'R', 'K', 'I', 'P', 'F', 'Y', 'Q', 'E', 'D',
-                      'C', 'W', 'T', 'L', 'N', 'G', 'A', 'M', 'V', 'S']  #, 'U']
-
-len_allowed_char = len(allowed_characters)
-
-def one_hot(aa):
-    return([0 if allowed_characters[x] != aa else 1 for x in range(len_allowed_char)])
-
-def convert_to_one_hot(fasta_seq):
-    array = []
-    for aa in fasta_seq:
-        array.append(one_hot(aa))
-
-
-    array = np.asarray(array, dtype=np.float32)
-    array = np.reshape(array, [1, -1, len_allowed_char])
-    return(array)
-
-def get_motif_positions_in_sequence(input_pairs, filters, filter_rect_bias, batch_size=100):
-    n_fils = filters.shape[-1]
-    filter_size = filters.shape[0]
-
-
-    with tf.Session().as_default() as sess:
-        zero_labels = [0 for _ in range(len(input_pairs))]  # Labels are not used, but a valid argument has to be passed
-        seq1, seq2, _, iter_initializer = sequence_processing(batch_size=batch_size, pair_list=input_pairs,
-                                                              labels=zero_labels, shuffle=False, threads=1)
-
-        feature_map_1 = tf.nn.conv1d(seq1, filters=filters, stride=1, padding="SAME")
-        feature_map_2 = tf.nn.conv1d(seq2, filters=filters, stride=1, padding="SAME")
-
-        value_map_1 = tf.nn.relu(feature_map_1 - filter_rect_bias) + filter_rect_bias
-        value_map_2 = tf.nn.relu(feature_map_2 - filter_rect_bias) + filter_rect_bias
-
-        sess.run(tf.global_variables_initializer())
-        sess.run(iter_initializer)
-
-        value_maps = []
-        for i in range(int(len(input_pairs)/batch_size) + 1):
-            values = sess.run([value_map_1, value_map_2])
-            value_maps.append(values)
-
-        value_maps = np.vstack(value_maps)
-        print(value_maps.shape)
-
-        #bools = value_maps > 0.0
-
-        #bools = np.reshape(bools, [-1, n_fils])
-        #for f_idx in range(bools.shape[-1]):
-        #    print("Contains motif %d at positions: "%f_idx)
-        #    print(np.argmax(bools[:,f_idx]) - filter_size)
-
-        return(value_maps)
-
-
-def main():
-    model_dir = "../../Biological Results/Sun/SuppCD" \
-                "/12Filters_uniform_AttentionFalse_softmax_Sigmoid_0"
-
-    with open("../../Data/Sun/Fasta/NP_002937.1") as inp:
-        fasta = inp.read()
-
-    print(fasta)
-    print(len(fasta))
-
-    input_seq = convert_to_one_hot(fasta)
-
-    filters = np.load("%s/weight_0.npy"%(model_dir))
-    filter_rect_bias = np.reshape(np.load("%s/weight_1.npy"%(model_dir)), [-1])
-
-    vals, bools = get_motif_positions_in_sequence(input_seq, filters, filter_rect_bias)
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/train.py b/train.py
index 7ad824c0ea02d710cfe7c83d8925a2e7747852b1..9e3c8f58e82d6a3eaad83f055f375e2ecc66641d 100644
--- a/train.py
+++ b/train.py
@@ -1,7 +1,35 @@
 from CNN_PPI import PPI_CNN
 from interpretation import create_interactive_visualization
 from data_processing import *
-from sequence_position_analyzer import get_motif_positions_in_sequence
+
+def create_evaluation_files(save_dir, evaluation_pairs, evaluation_labels, evaluation_results,
+                            only_positive_weights=True):
+    roc, area, activation, predictions, accuracy, precision, specificity, position_scores = evaluation_results
+
+    print("Accuracy:\t", accuracy)
+    print("Precision:\t", precision)
+    print("Specificity:\t", specificity)
+    print("AUC-ROC:\t", area)
+    with open("%s/Evaluation_metric.txt" % save_dir, 'w') as outp:
+        outp.write("Accuracy:\t%.2f\n" % accuracy)
+        outp.write("Precision:\t%.2f\n" % precision)
+        outp.write("Specificity:\t%.2f\n" % specificity)
+        outp.write("AUC-ROC:\t%.2f" % area)
+
+    plot_roc(roc, area, "%s/roc.png" % save_dir)
+
+    plot_evaluation_activity(save_dir, activation, predictions, evaluation_labels, only_positive=only_positive_weights)
+    # Note that filters.npy is different from weight_0.npy
+    # weight_0.npy are the weights before normalization
+    filters = np.load("%s/filters.npy" % save_dir)
+    bias = np.load("%s/weight_1.npy" % save_dir)
+    #position_scores = get_motif_positions_in_sequence(evaluation_pairs, filters, bias)
+    np.save("%s/position_scores" % save_dir, position_scores)
+
+    create_interactive_visualization(save_dir, activation, predictions, evaluation_labels,
+                                     position_activity=position_scores, only_positive=only_positive_weights)
+
+
 def start_training(data_dir, base_save_dir, pair_file, eval_pair_file, attention, silence_edges, lr, fsr, n_filters, \
                    batch_size,
@@ -30,55 +58,32 @@ def start_training(data_dir, base_save_dir, pair_file, eval_pair_file, attention
     nn.filters_to_seqlogo(save_dir)

-    roc, area, activation, predictions, accuracy, precision, specificity = nn.evaluate(evaluation_pairs,
-                                                                                       evaluation_labels,
-                                                                                       apply_sigmoid=True,
-                                                                                       batch_size=batch_size)
-
-    print("Accuracy:\t", accuracy)
-    print("Precision:\t", precision)
-    print("Specificity:\t", specificity)
-    print("AUC-ROC:\t", area)
-    with open("%s/Evaluation_metric.txt" % save_dir, 'w') as outp:
-        outp.write("Accuracy:\t%.2f\n" % accuracy)
-        outp.write("Precision:\t%.2f\n" % precision)
-        outp.write("Specificity:\t%.2f\n" % specificity)
-        outp.write("AUC-ROC:\t%.2f" % area)
-
+    # It is highly recommended to use batch_size=1 here to prevent memory issues
+    # The reason for this is that we are, among others, requesting the entire feature map for all sequences in the
+    # batch and constructing a matrix of all non-zero indices.
+    evaluation_results = nn.evaluate(evaluation_pairs, evaluation_labels, batch_size=1)

     del (nn)

-    plot_evaluation_activity(save_dir, activation, predictions, evaluation_labels)
-    # Note that filters.npy is different from weight_0.npy
-    # weight_0.npy are the weights before normalization
-    filters = np.load("%s/filters.npy"%save_dir)
-    bias = np.load("%s/weight_1.npy"%save_dir)
-    position_scores = get_motif_positions_in_sequence(evaluation_pairs, filters, bias)
-    np.save("%s/position_scores"%save_dir, position_scores)
-
-    create_interactive_visualization(save_dir, activation, predictions, evaluation_labels, position_activity=position_scores)
-
-
-
-    print(position_scores[0].shape)
+    # Create ROC plot, interactive visualization for the filters and .txt files for the evaluation metrics
+    create_evaluation_files(save_dir, evaluation_pairs, evaluation_labels, evaluation_results)

-    plot_roc(roc, area, "%s/roc.png" % save_dir)

 def main():
     # Settings
     attention = True
     silence_edges = True
-    lr = 3e-3  # Learning rate
-    fsr_iter = 5.0  # The number of iterations before the next filter becomes trainable (in false start)
+    lr = 4e-3  # Learning rate
+    fsr_iter = 25.0  # The number of iterations before the next filter becomes trainable (in false start)
     n_filters = 36
-    batch_size = 50
+    batch_size = 25
     init = 'uniform'  # select ['uniform', 'seminormal']
     act = "normalized"  # select ['normalized', 'softmax']
     kern_size = 11
-    iterations = 40000
+    iterations = 5
     fsr = fsr_iter * n_filters / float(iterations)  # False start rate

     # Do we use biological or synthetic data to train?
-    biological = False
+    biological = True

     if biological:
         print("Using biological data.")
@@ -86,14 +91,15 @@ def main():
         data_dir = "Biological Data/Records"
         pair_file = "Biological Data/training_pairs.txt"
         eval_pair_file = "Biological Data/test_pairs.txt"
-        save_dir = "Results/Biological/%dFilters_%s_Attention%s_%s/" % (n_filters, init, str(attention),
+        save_dir = "Results/Biological/%dFilters_%s_Sparsity%s_%s/" % (n_filters, init, str(attention),
                                                                         act)
     else:
         # Synthetic Directories
         if not os.path.exists("Results/Synthetic/"):
             os.mkdir("Results/Synthetic/")
-        # Change synth name to the name of the data you've generated, if necessary
+        # Change synth name to the name of the data you've generated, if necessary.
+        # If you ran create_synthetic_data.py without making any changes, you can leave synth_name as is.
         synth_name = "Synthetic_14_length500_frac0.50"
         data_dir = "Synthetic Data/%s/Records" %synth_name
         pair_file = "Synthetic Data/%s/training_pairs.txt" %synth_name
@@ -102,12 +108,12 @@ def main():
         if not os.path.exists("Results/Synthetic/%s/"%synth_name):
             os.mkdir("Results/Synthetic/%s/"%synth_name)

-        save_dir = "Results/Synthetic/%s/SparsityTerm_InteractiveTest_OnlyPos_%dFilters_%s_%s_%s" \
+        save_dir = "Results/Synthetic/%s/HighLR_BoundPredWeights_SparsityTerm_%dFilters_%s_%s_%s" \
                    "/" % (
             synth_name, n_filters, init, str(attention), act)

     # Do we want to load a previously trained model?
-    restore_model = None  # otherwise use "Results/insert_model_dir_here/"
+    restore_model = save_dir + '/0/'  # otherwise use "Results/insert_model_dir_here/"

     start_training(data_dir, save_dir, pair_file, eval_pair_file, attention, silence_edges, lr, fsr, n_filters, \
                    batch_size, init, act, kern_size, iterations, restore_model)
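Editor's aside (not part of the patch): the position_info rows that PPI_CNN.evaluate() now returns are consumed directly by interpretation.plot_filter_activation as the columns ["seq", "sample", "position", "filter", "similarity"]. The sketch below traces that layout with made-up numbers; only the column order and the hstack/vstack mechanics are taken from the patch.

    import numpy as np
    import pandas as pd

    # tf.where on a [batch, position, filter] feature map yields (sample, position, filter) index
    # triples, and tf.gather_nd yields the matching raw similarity scores.
    fm1_idx = np.array([[0, 17, 3], [0, 251, 3]])          # hits in the first sequence of a pair
    fm1_vals = np.array([4.2, 3.7]).reshape(-1, 1)

    fm1_position_info = np.hstack([fm1_idx, fm1_vals])
    # Prepend the binary "which sequence in the pair" indicator (0 = first sequence).
    fm1_position_info = np.hstack([np.zeros([fm1_vals.shape[0], 1]), fm1_position_info])

    # Rows are therefore [seq, sample, position, filter, similarity], which is exactly the column
    # order expected by pd.DataFrame(position_activity, columns=[...]) in plot_filter_activation.
    df = pd.DataFrame(fm1_position_info, columns=["seq", "sample", "position", "filter", "similarity"])
    print(df)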