From 7a30f97898bea43d169ee1aa6286ddc78f09110c Mon Sep 17 00:00:00 2001
From: "Fuchs, Pim" <pim.fuchs@wur.nl>
Date: Fri, 30 Mar 2018 12:31:36 +0200
Subject: [PATCH] Added basic interpretation interface

---
 interpretation.py             | 155 ++++++++++++++++++++++++++++++++++
 sequence_position_analyzer.py |  79 +++++++++++++++++
 train.py                      |  17 +++-
 3 files changed, 250 insertions(+), 1 deletion(-)
 create mode 100644 interpretation.py
 create mode 100644 sequence_position_analyzer.py

diff --git a/interpretation.py b/interpretation.py
new file mode 100644
index 0000000..b0ef405
--- /dev/null
+++ b/interpretation.py
@@ -0,0 +1,155 @@
+import numpy as np
+import pandas as pd
+
+from bokeh.io import show, output_file
+from bokeh.layouts import gridplot
+from bokeh.models import (
+    ColumnDataSource,
+    HoverTool,
+    LinearColorMapper,
+    BasicTicker,
+    PrintfTickFormatter,
+    ColorBar,
+    TapTool,
+    OpenURL,
+
+)
+from bokeh.plotting import figure
+
+def get_custom_hover_tool():
+    hover = HoverTool(tooltips="""
+    <div>
+        <div>
+            <span style="font-size: 15px;">Sample: @x</span>
+            <img
+                src="@img1" height="42" alt="@imgs" width="100"
+                style="float: left; margin: 0px 15px 15px 0px;"
+                border="2"
+            ></img>
+        </div>
+        <div>
+            <img
+                src="@img2" height="42" alt="@imgs" width="100"
+                style="float: left; margin: 0px 15px 15px 0px;"
+                border="2"
+            ></img>
+        </div>
+        <div>
+            <span style="font-size: 15px;">Activation:</span>
+            <span style="font-size: 17px; font-weight: bold;">@value</span>
+        </div>
+
+
+        <div>
+
+
+        </div>
+    </div>
+    """
+    )
+    #<span style="font-size: 10px; color: #696;">($x, $y)</span>
+    return (hover)
+
+def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real_labels, position_activity):
+    predictions = ['+' if np.round(x) == 1 else '-' for x in predictions]
+    #predictions = [predictions[x] + '\n+' if np.round(real_labels[x]) == 1 else predictions[x] + '\n-' for x in
+    #               range(len(predictions))]
+
+    n_filter_pairs = filter_activation_matrix.shape[0]
+
+    filter_activation_matrix = np.transpose(filter_activation_matrix, [1,0])
+
+    colors = ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"]
+    mapper = LinearColorMapper(palette=colors, low=np.min(filter_activation_matrix), high=np.max(filter_activation_matrix))
+
+    y_range = ["Filter %d" %n for n in range(n_filter_pairs)]
+    x_range = [str(n) for n in range(len(predictions))]
+
+    x = ["%d" %n for n in range(len(predictions)) for i in range(n_filter_pairs)]
+    y = y_range * len(predictions)
+
+    values = filter_activation_matrix.flatten()
+    img1 = ['filter_%d.png'%(f*2) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
+    img2 = ['filter_%d.png'%(f*2+1) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
+    source = ColumnDataSource({'x':x, 'y':y, 'value':values, 'img1': img1, 'img2': img2})
+
+    hover = get_custom_hover_tool()
+    TOOLS = [hover, "tap"]#, "save,pan,box_zoom,reset,wheel_zoom"]
+
+    p = figure(title="Prediction Motif Interpretation",
+               x_range=x_range, y_range=list(reversed(y_range)),
+               x_axis_location="above", plot_width=900, plot_height=400,
+               tools=TOOLS, toolbar_location=None)
+
+    p.xaxis.major_label_text_font_size = "4pt"
+    p.axis.major_label_standoff = 0
+    p.toolbar.logo = None
+
+    p.rect(source=source, x='x', y='y', width=1, height=1, fill_color={'field': 'value', 'transform': mapper})
+
+
+    url = "http://www.colors.commutercreative.com/@color/"
+    taptool = p.select(type=TapTool)
+    taptool.callback = OpenURL(url=url)
+
+    color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="5pt",
+                         ticker=BasicTicker(desired_num_ticks=len(colors)),
+                         formatter=PrintfTickFormatter(format="%.2f"),
+                         label_standoff=6, border_line_color=None, location=(0, 0))
+    p.add_layout(color_bar, 'right')
+
+    pos_x_range = [str(x) for x in range(position_activity.shape[2])]
+    pos_x = ["%d" %n for n in range(len(predictions)) for i in range(n_filter_pairs)]
+    pos_y = y_range * len(predictions)
+
+    pos_values = filter_activation_matrix.flatten()
+    source = ColumnDataSource({'x':x, 'y':y, 'value':values, 'img1': img1, 'img2': img2})
+
+    pos_fig = figure(title="Sequence position activation",
+                     x_range=pos_x_range,
+                     x_axis_location="above", plot_width=900, plot_height=400,
+                     tools=['pan', 'zoom_in'], toolbar_location=None)
+
+    pos_fig.toolbar.logo = None
+    current_sample = 0
+    current_filter = 0
+    x_dat = [n for n in range(position_activity.shape[2])]
+    print(position_activity.shape)
+    pos_fig.line(x_dat, position_activity[0,current_sample,:, current_filter])
+
+
+    grid = gridplot([[p],[pos_fig]], toolbar_options={'logo': None})
+
+
+    output_file("%s/interpretation.html"%save_dir)
+
+    show(grid)  # show the plot
+
+
+
+def create_interactive_visualization(model_dir, pair_activation_matrix=None, predictions=None, real_labels=None,
+                                     only_positive=False, position_activity=None):
+    pair_activation_matrix = np.load("%sevaluation_activity.npy" % model_dir)
+    predictions = np.load("%sevaluation_predictions.npy" % model_dir)
+    position_activity = np.load("%s/position_scores.npy" % model_dir)
+
+    pair_weights = np.load("%sweight_2.npy" % model_dir)
+    if only_positive:
+        pair_weights = np.exp(pair_weights)
+    prediction_bias = np.exp(np.load("%sweight_3.npy" % model_dir))
+
+    pair_activation_matrix = np.transpose(pair_activation_matrix)
+    net_activity = np.multiply(pair_activation_matrix, pair_weights)
+
+    relative_activity = net_activity / prediction_bias
+
+    plot_filter_activation(relative_activity, predictions, model_dir, real_labels, position_activity=position_activity)
+
+
+def main():
+    create_interactive_visualization(model_dir="Results/Synthetic/Synthetic_14_length500_frac0.50"
+                                               "/SparsityTerm_InteractiveTest_OnlyPos_36Filters_uniform_True_normalized/0/",
+                                     real_labels="a", only_positive=False)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/sequence_position_analyzer.py b/sequence_position_analyzer.py
new file mode 100644
index 0000000..cc7e125
--- /dev/null
+++ b/sequence_position_analyzer.py
@@ -0,0 +1,79 @@
+import tensorflow as tf
+import numpy as np
+from data_processing import sequence_processing
+
+allowed_characters = ['H', 'R', 'K', 'I', 'P', 'F', 'Y', 'Q', 'E', 'D',
+                      'C', 'W', 'T', 'L', 'N', 'G', 'A', 'M', 'V', 'S'] #, 'U']
+
+len_allowed_char = len(allowed_characters)
+
+def one_hot(aa):
+    return([0 if allowed_characters[x] != aa else 1 for x in range(len_allowed_char)])
+
+def convert_to_one_hot(fasta_seq):
+    array = []
+    for aa in fasta_seq:
+        array.append(one_hot(aa))
+
+
+    array = np.asarray(array, dtype=np.float32)
+    array = np.reshape(array, [1, -1, len_allowed_char])
+    return(array)
+
+def get_motif_positions_in_sequence(input_pairs, filters, filter_rect_bias, batch_size=100):
+    n_fils = filters.shape[-1]
+    filter_size = filters.shape[0]
+
+
+    with tf.Session().as_default() as sess:
+        zero_labels = [0 for _ in range(len(input_pairs))] # Labels are not used, but a valid argument has to be passed
+        seq1, seq2, _, iter_initializer = sequence_processing(batch_size=batch_size, pair_list=input_pairs,
+                                                              labels=zero_labels, shuffle=False, threads=1)
+
+        feature_map_1 = tf.nn.conv1d(seq1, filters=filters, stride=1, padding="SAME")
+        feature_map_2 = tf.nn.conv1d(seq2, filters=filters, stride=1, padding="SAME")
+
+        value_map_1 = tf.nn.relu(feature_map_1 - filter_rect_bias) + filter_rect_bias
+        value_map_2 = tf.nn.relu(feature_map_2 - filter_rect_bias) + filter_rect_bias
+
+        sess.run(tf.global_variables_initializer())
+        sess.run(iter_initializer)
+
+        value_maps = []
+        for i in range(int(len(input_pairs)/batch_size) + 1):
+            values = sess.run([value_map_1, value_map_2])
+            value_maps.append(values)
+
+        value_maps = np.vstack(value_maps)
+        print(value_maps.shape)
+
+        #bools = value_maps > 0.0
+
+        #bools = np.reshape(bools, [-1, n_fils])
+        #for f_idx in range(bools.shape[-1]):
+        #    print("Contains motif %d at positions: "%f_idx)
+        #    print(np.argmax(bools[:,f_idx]) - filter_size)
+
+        return(value_maps)
+
+
+def main():
+    model_dir = "../../Biological Results/Sun/SuppCD" \
+                "/12Filters_uniform_AttentionFalse_softmax_Sigmoid_0"
+
+    with open("../../Data/Sun/Fasta/NP_002937.1") as inp:
+        fasta = inp.read()
+
+    print(fasta)
+    print(len(fasta))
+
+    input_seq = convert_to_one_hot(fasta)
+
+    filters = np.load("%s/weight_0.npy"%(model_dir))
+    filter_rect_bias = np.reshape(np.load("%s/weight_1.npy"%(model_dir)), [-1])
+
+    vals = get_motif_positions_in_sequence(input_seq, filters, filter_rect_bias)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/train.py b/train.py
index eb8de2d..7ad824c 100644
--- a/train.py
+++ b/train.py
@@ -1,5 +1,7 @@
 from CNN_PPI import PPI_CNN
+from interpretation import create_interactive_visualization
 from data_processing import *
+from sequence_position_analyzer import get_motif_positions_in_sequence
 
 def start_training(data_dir, base_save_dir, pair_file, eval_pair_file, attention, silence_edges, lr, fsr, n_filters, \
                    batch_size,
@@ -46,6 +48,19 @@ def start_training(data_dir, base_save_dir, pair_file, eval_pair_file, attention
         del (nn)
     plot_evaluation_activity(save_dir, activation, predictions, evaluation_labels)
 
+    # Note that filters.npy is different from weight_0.npy:
+    # weight_0.npy contains the weights before normalization
+    filters = np.load("%s/filters.npy"%save_dir)
+    bias = np.load("%s/weight_1.npy"%save_dir)
+    position_scores = get_motif_positions_in_sequence(evaluation_pairs, filters, bias)
+    np.save("%s/position_scores"%save_dir, position_scores)
+
+    create_interactive_visualization(save_dir, activation, predictions, evaluation_labels, position_activity=position_scores)
+
+
+
+    print(position_scores[0].shape)
+
     plot_roc(roc, area, "%s/roc.png" % save_dir)
 
 def main():
@@ -87,7 +102,7 @@
         if not os.path.exists("Results/Synthetic/%s/"%synth_name):
             os.mkdir("Results/Synthetic/%s/"%synth_name)
 
-        save_dir = "Results/Synthetic/%s/SparsityTerm_HighFSR_OnlyPos_%dFilters_%s_%s_%s" \
+        save_dir = "Results/Synthetic/%s/SparsityTerm_InteractiveTest_OnlyPos_%dFilters_%s_%s_%s" \
                    "/" % (
             synth_name, n_filters, init, str(attention), act)
 
--
GitLab
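
Usage sketch: a minimal way to rebuild the interactive view for a run that has already finished. create_interactive_visualization reloads everything it needs from the run directory (evaluation_activity.npy, evaluation_predictions.npy, position_scores.npy, weight_2.npy, weight_3.npy), so the in-memory arrays that train.py passes are optional here. The run path below is a hypothetical example, not a directory that exists in the repository.

    # Minimal sketch, not part of the patch above: regenerate interpretation.html
    # for an existing run. The directory name is hypothetical; it must end with "/"
    # because the .npy paths inside create_interactive_visualization are built by
    # plain string concatenation ("%sevaluation_activity.npy" % model_dir, etc.).
    from interpretation import create_interactive_visualization

    run_dir = "Results/Synthetic/<synth_name>/<run_name>/0/"  # hypothetical run directory
    create_interactive_visualization(run_dir)

The hover popups reference one image per filter in each pair (filter_0.png, filter_1.png, ...), so those motif plots need to sit next to the generated interpretation.html for the tooltips to render.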