From 7a30f97898bea43d169ee1aa6286ddc78f09110c Mon Sep 17 00:00:00 2001
From: "Fuchs, Pim" <pim.fuchs@wur.nl>
Date: Fri, 30 Mar 2018 12:31:36 +0200
Subject: [PATCH] Added basic interpretation interface

---
 interpretation.py             | 155 ++++++++++++++++++++++++++++++++++
 sequence_position_analyzer.py |  79 +++++++++++++++++
 train.py                      |  17 +++-
 3 files changed, 250 insertions(+), 1 deletion(-)
 create mode 100644 interpretation.py
 create mode 100644 sequence_position_analyzer.py

diff --git a/interpretation.py b/interpretation.py
new file mode 100644
index 0000000..b0ef405
--- /dev/null
+++ b/interpretation.py
@@ -0,0 +1,155 @@
+import numpy as np
+import pandas as pd
+
+from bokeh.io import show, output_file
+from bokeh.layouts import gridplot
+from bokeh.models import (
+    ColumnDataSource,
+    HoverTool,
+    LinearColorMapper,
+    BasicTicker,
+    PrintfTickFormatter,
+    ColorBar,
+    TapTool,
+    OpenURL,
+
+)
+from bokeh.plotting import figure
+
+def get_custom_hover_tool():
+    hover = HoverTool(tooltips="""
+    <div>
+        <div>
+            <span style="font-size: 15px;">Sample: @x</span>
+            <img
+                src="@img1" height="42" alt="@imgs" width="100"
+                style="float: left; margin: 0px 15px 15px 0px;"
+                border="2"
+            ></img>
+        </div>
+        <div>
+            <img
+                src="@img2" height="42" alt="@imgs" width="100"
+                style="float: left; margin: 0px 15px 15px 0px;"
+                border="2"
+            ></img>
+        </div>
+        <div>
+            <span style="font-size: 15px;">Activation:</span>
+            <span style="font-size: 17px; font-weight: bold;">@value</span>
+        </div>
+
+
+        <div>
+
+
+        </div>
+    </div>
+    """
+    )
+    #<span style="font-size: 10px; color: #696;">($x, $y)</span>
+    return (hover)
+
+def plot_filter_activation(filter_activation_matrix, predictions, save_dir, real_labels, position_activity):
+    predictions = ['+' if np.round(x) == 1 else '-' for x in predictions]
+    #predictions = [predictions[x] + '\n+' if np.round(real_labels[x]) == 1 else predictions[x] + '\n-' for x in
+    #               range(len(predictions))]
+
+    n_filter_pairs = filter_activation_matrix.shape[0]
+
+    filter_activation_matrix = np.transpose(filter_activation_matrix, [1,0])
+
+    colors = ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"]
+    mapper = LinearColorMapper(palette=colors, low=np.min(filter_activation_matrix), high=np.max(filter_activation_matrix))
+
+    y_range = ["Filter %d" %n for n in range(n_filter_pairs)]
+    x_range = [str(n) for n in range(len(predictions))]
+
+    x = ["%d" %n for n in range(len(predictions)) for i in range(n_filter_pairs)]
+    y = y_range * len(predictions)
+
+    values = filter_activation_matrix.flatten()
+    img1 = ['filter_%d.png'%(f*2) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
+    img2 = ['filter_%d.png'%(f*2+1) for _ in range(len(predictions)) for f in range(n_filter_pairs)]
+    source = ColumnDataSource({'x':x, 'y':y, 'value':values, 'img1': img1, 'img2': img2})
+
+    hover = get_custom_hover_tool()
+    TOOLS = [hover, "tap"]#, "save,pan,box_zoom,reset,wheel_zoom"]
+
+    p = figure(title="Prediction Motif Interpretation",
+               x_range=x_range, y_range=list(reversed(y_range)),
+               x_axis_location="above", plot_width=900, plot_height=400,
+               tools=TOOLS, toolbar_location=None)
+
+    p.xaxis.major_label_text_font_size = "4pt"
+    p.axis.major_label_standoff = 0
+    p.toolbar.logo = None
+
+    p.rect(source=source, x='x', y='y', width=1, height=1, fill_color={'field': 'value', 'transform': mapper})
+
+
+    url = "http://www.colors.commutercreative.com/@color/"
+    taptool = p.select(type=TapTool)
+    taptool.callback = OpenURL(url=url)
+
+    color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="5pt",
+                         ticker=BasicTicker(desired_num_ticks=len(colors)),
+                         formatter=PrintfTickFormatter(format="%.2f"),
+                         label_standoff=6, border_line_color=None, location=(0, 0))
+    p.add_layout(color_bar, 'right')
+
+    pos_x_range = [str(x) for x in range(position_activity.shape[2])]
+    pos_x = ["%d" %n for n in range(len(predictions)) for i in range(n_filter_pairs)]
+    pos_y = y_range * len(predictions)
+
+    pos_values = filter_activation_matrix.flatten()
+    source = ColumnDataSource({'x':x, 'y':y, 'value':values, 'img1': img1, 'img2': img2})
+
+    pos_fig = figure(title="Sequence position activation",
+                     x_range=pos_x_range,
+                     x_axis_location="above", plot_width=900, plot_height=400,
+                     tools=['pan', 'zoom_in'], toolbar_location=None)
+
+    pos_fig.toolbar.logo = None
+    current_sample = 0
+    current_filter = 0
+    x_dat = [n for n in range(position_activity.shape[2])]
+    print(position_activity.shape)
+    pos_fig.line(x_dat, position_activity[0,current_sample,:, current_filter])
+
+
+    grid = gridplot([[p],[pos_fig]], toolbar_options={'logo': None})
+
+
+    output_file("%s/interpretation.html"%save_dir)
+
+    show(grid)  # show the plot
+
+
+
+def create_interactive_visualization(model_dir, pair_activation_matrix=None, predictions=None, real_labels=None,
+                                     only_positive=False, position_activity=None):
+    pair_activation_matrix = np.load("%sevaluation_activity.npy" % model_dir)
+    predictions = np.load("%sevaluation_predictions.npy" % model_dir)
+    position_activity = np.load("%s/position_scores.npy" % model_dir)
+
+    pair_weights = np.load("%sweight_2.npy" % model_dir)
+    if only_positive:
+        pair_weights = np.exp(pair_weights)
+    prediction_bias = np.exp(np.load("%sweight_3.npy" % model_dir))
+
+    pair_activation_matrix = np.transpose(pair_activation_matrix)
+    net_activity = np.multiply(pair_activation_matrix, pair_weights)
+
+    relative_activity = net_activity / prediction_bias
+
+    plot_filter_activation(relative_activity, predictions, model_dir, real_labels, position_activity=position_activity)
+
+
+def main():
+    create_interactive_visualization(model_dir="Results/Synthetic/Synthetic_14_length500_frac0.50"
+                                               "/SparsityTerm_InteractiveTest_OnlyPos_36Filters_uniform_True_normalized/0/",
+                                     real_labels="a", only_positive=False)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/sequence_position_analyzer.py b/sequence_position_analyzer.py
new file mode 100644
index 0000000..cc7e125
--- /dev/null
+++ b/sequence_position_analyzer.py
@@ -0,0 +1,79 @@
+import tensorflow as tf
+import numpy as np
+from data_processing import sequence_processing
+
+allowed_characters = ['H', 'R', 'K', 'I', 'P', 'F', 'Y', 'Q', 'E', 'D',
+                      'C', 'W', 'T', 'L', 'N', 'G', 'A', 'M', 'V', 'S'] #, 'U']
+
+len_allowed_char = len(allowed_characters)
+
+def one_hot(aa):
+    return([0 if allowed_characters[x] != aa else 1 for x in range(len_allowed_char)])
+
+def convert_to_one_hot(fasta_seq):
+    array = []
+    for aa in fasta_seq:
+        array.append(one_hot(aa))
+
+
+    array = np.asarray(array, dtype=np.float32)
+    array = np.reshape(array, [1, -1, len_allowed_char])
+    return(array)
+
+def get_motif_positions_in_sequence(input_pairs, filters, filter_rect_bias, batch_size=100):
+    n_fils = filters.shape[-1]
+    filter_size = filters.shape[0]
+
+
+    with tf.Session().as_default() as sess:
+        zero_labels = [0 for _ in range(len(input_pairs))] # Labels are not used, but a valid argument has to be passed
+        seq1, seq2, _, iter_initializer = sequence_processing(batch_size=batch_size, pair_list=input_pairs,
+                                                              labels=zero_labels, shuffle=False, threads=1)
+
+        feature_map_1 = tf.nn.conv1d(seq1, filters=filters, stride=1, padding="SAME")
+        feature_map_2 = tf.nn.conv1d(seq2, filters=filters, stride=1, padding="SAME")
+
+        value_map_1 = tf.nn.relu(feature_map_1 - filter_rect_bias) + filter_rect_bias
+        value_map_2 = tf.nn.relu(feature_map_2 - filter_rect_bias) + filter_rect_bias
+
+        sess.run(tf.global_variables_initializer())
+        sess.run(iter_initializer)
+
+        value_maps = []
+        for i in range(int(len(input_pairs)/batch_size) + 1):
+            values = sess.run([value_map_1, value_map_2])
+            value_maps.append(values)
+
+        value_maps = np.vstack(value_maps)
+        print(value_maps.shape)
+
+        #bools = value_maps > 0.0
+
+        #bools = np.reshape(bools, [-1, n_fils])
+        #for f_idx in range(bools.shape[-1]):
+        #    print("Contains motif %d at positions: "%f_idx)
+        #    print(np.argmax(bools[:,f_idx]) - filter_size)
+
+        return(value_maps)
+
+
+def main():
+    model_dir = "../../Biological Results/Sun/SuppCD" \
+                "/12Filters_uniform_AttentionFalse_softmax_Sigmoid_0"
+
+    with open("../../Data/Sun/Fasta/NP_002937.1") as inp:
+        fasta = inp.read()
+
+    print(fasta)
+    print(len(fasta))
+
+    input_seq = convert_to_one_hot(fasta)
+
+    filters = np.load("%s/weight_0.npy"%(model_dir))
+    filter_rect_bias = np.reshape(np.load("%s/weight_1.npy"%(model_dir)), [-1])
+
+    vals = get_motif_positions_in_sequence(input_seq, filters, filter_rect_bias)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/train.py b/train.py
index eb8de2d..7ad824c 100644
--- a/train.py
+++ b/train.py
@@ -1,5 +1,7 @@
 from CNN_PPI import PPI_CNN
+from interpretation import create_interactive_visualization
 from data_processing import *
+from sequence_position_analyzer import get_motif_positions_in_sequence
 
 def start_training(data_dir, base_save_dir, pair_file, eval_pair_file, attention, silence_edges, lr, fsr, n_filters, \
                    batch_size,
@@ -46,6 +48,19 @@ def start_training(data_dir, base_save_dir, pair_file, eval_pair_file, attention
         del (nn)
     plot_evaluation_activity(save_dir, activation, predictions, evaluation_labels)
 
+    # Note that filters.npy is different from weight_0.npy:
+    # weight_0.npy contains the weights before normalization
+    filters = np.load("%s/filters.npy"%save_dir)
+    bias = np.load("%s/weight_1.npy"%save_dir)
+    position_scores = get_motif_positions_in_sequence(evaluation_pairs, filters, bias)
+    np.save("%s/position_scores"%save_dir, position_scores)
+
+    create_interactive_visualization(save_dir, activation, predictions, evaluation_labels, position_activity=position_scores)
+
+
+
+    print(position_scores[0].shape)
+
     plot_roc(roc, area, "%s/roc.png" % save_dir)
 
 def main():
@@ -87,7 +102,7 @@
         if not os.path.exists("Results/Synthetic/%s/"%synth_name):
             os.mkdir("Results/Synthetic/%s/"%synth_name)
 
-        save_dir = "Results/Synthetic/%s/SparsityTerm_HighFSR_OnlyPos_%dFilters_%s_%s_%s" \
+        save_dir = "Results/Synthetic/%s/SparsityTerm_InteractiveTest_OnlyPos_%dFilters_%s_%s_%s" \
                    "/" % (
             synth_name, n_filters, init, str(attention), act)
 
--
GitLab
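
Usage sketch: a minimal way to rebuild the interactive view for a run that has already finished. create_interactive_visualization reloads everything it needs from the run directory (evaluation_activity.npy, evaluation_predictions.npy, position_scores.npy, weight_2.npy, weight_3.npy), so the in-memory arrays that train.py passes are optional here. The run path below is a hypothetical example, not a directory that exists in the repository.

    # Minimal sketch, not part of the patch above: regenerate interpretation.html
    # for an existing run. The directory name is hypothetical; it must end with "/"
    # because the .npy paths inside create_interactive_visualization are built by
    # plain string concatenation ("%sevaluation_activity.npy" % model_dir, etc.).
    from interpretation import create_interactive_visualization

    run_dir = "Results/Synthetic/<synth_name>/<run_name>/0/"  # hypothetical run directory
    create_interactive_visualization(run_dir)

The hover popups reference one image per filter in each pair (filter_0.png, filter_1.png, ...), so those motif plots need to sit next to the generated interpretation.html for the tooltips to render.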