Skip to content
Snippets Groups Projects
Commit 338de017 authored by Carlos de Lannoy
Browse files

make logos and squiggle plots for tp and fp events

parent 4c30d016
No related branches found
No related tags found
No related merge requests found
import argparse, os, sys
import numpy as np
import pandas as pd
import logomaker
import matplotlib.pyplot as plt
from os.path import splitext, dirname
from pathlib import Path
sys.path.append(f'{dirname(Path(__file__).resolve())}/..')
from helper_functions import parse_output_path, parse_input_path
def condense_preds(pred):
    """Group consecutive positive predictions into events.

    Walks over ``pred`` once and collects the indices of every uninterrupted
    run of truthy values.

    :param pred: iterable of per-position predictions; each element is
        evaluated for truthiness (e.g. 0/1 ints or booleans).
    :return: list of lists; each inner list holds the consecutive indices of
        one run of positives, in order of appearance. Empty if ``pred`` has
        no positives.
    """
    condensed_pos_idx_list = []
    cur_event = []
    for i, p in enumerate(pred):
        if p:
            cur_event.append(i)
        elif cur_event:
            # A negative closes the run that was being collected.
            condensed_pos_idx_list.append(cur_event)
            cur_event = []
    # A run that extends to the very last position is still open here.
    if cur_event:
        condensed_pos_idx_list.append(cur_event)
    return condensed_pos_idx_list
# Script body: load one example-read npz, split the predicted-positive events
# into true positives (TP) and false positives (FP), then write
#   <npz basename>_logo.svg   - per-position base frequencies of the FP k-mers
#   <npz basename>_traces.svg - raw signal windows around TP and FP events
parser = argparse.ArgumentParser(description='Parse TP and FP example read sections out of npzs and graph')
parser.add_argument('--npz', type=str, required=True,
                    help='npz file of example read, as produced by train_nn')
parser.add_argument('--filter-width', type=int, required=True,
                    help='number measurements to cut out of squiggle around predicted positives.')
args = parser.parse_args()
if args.filter_width < 1:
    parser.error('--filter-width must be a positive integer')

with np.load(args.npz) as fh:
    raw = np.squeeze(fh['raw_excerpt'])  # NOTE(review): assumed 1-D after squeeze - confirm
    base_labels = fh['base_labels_excerpt']
    target = fh['target']
    posterior = fh['posterior']
    pred = fh['labels_predicted']

# Explicit check instead of `assert`, which is stripped when running with -O.
if not (len(raw) == len(base_labels) == len(posterior) == len(pred)):
    raise ValueError(
        f'array lengths differ in {args.npz}: raw={len(raw)}, base_labels={len(base_labels)}, '
        f'posterior={len(posterior)}, labels_predicted={len(pred)}')

hfw = args.filter_width // 2
condensed_idx_list = condense_preds(pred)
fp_kmer_list = []
tp_raw, fp_raw = [], []
nb_skipped = 0
for ci in condensed_idx_list:
    pos_kmer_list = base_labels[ci]
    # An event is a TP if any k-mer label it covers is one of the targets.
    is_tp = bool(np.any(np.isin(target, pos_kmer_list)))
    if not is_tp:
        # Represent an FP event by the k-mer label it covers most often.
        kmers, counts = np.unique(pos_kmer_list, return_counts=True)
        fp_kmer_list.append(kmers[np.argmax(counts)])

    # Window of exactly filter_width samples, starting hfw before the event
    # middle, so it always matches x_coords below (also for odd widths).
    event_mid = int(np.ceil(np.median(ci)))
    win_start = event_mid - hfw
    win_end = win_start + args.filter_width
    if win_start < 0 or win_end > len(raw):
        # A negative start would wrap around and an overshooting end would be
        # truncated; either way the window cannot be stacked with the others.
        nb_skipped += 1
        continue
    cur_raw = raw[win_start:win_end]
    if is_tp:
        tp_raw.append(cur_raw)
    else:
        fp_raw.append(cur_raw)

if nb_skipped:
    print(f'Skipped squiggle of {nb_skipped} event(s) too close to the read edge '
          f'for a window of {args.filter_width} measurements.', file=sys.stderr)

kmer_size = len(target[0])
nb_fp = len(fp_kmer_list)
out_prefix = splitext(args.npz)[0]

# --- plot logo ---
if nb_fp:
    logo_df = pd.DataFrame(0.0, index=np.arange(kmer_size), columns=['A', 'C', 'T', 'G'])
    fp_kmer_mat = np.array([list(km) for km in fp_kmer_list])
    for i in range(kmer_size):
        base, counts = np.unique(fp_kmer_mat[:, i], return_counts=True)
        freqs = pd.Series(counts / nb_fp, index=base)
        # Reindex onto the logo columns: bases absent at this position get 0.0
        # rather than NaN, and symbols outside A/C/T/G are left out.
        logo_df.loc[i, :] = freqs.reindex(logo_df.columns, fill_value=0.0).to_numpy()
    plt.figure()
    logo = logomaker.Logo(logo_df)
    plt.ylabel('frequency')
    plt.xlabel('position')
    plt.savefig(f'{out_prefix}_logo.svg', dpi=400)
    plt.close()
else:
    print('No false-positive events found; skipping logo plot.', file=sys.stderr)

# --- plot squiggles ---
if tp_raw or fp_raw:
    x_coords = np.arange(args.filter_width)
    plt.figure(figsize=(20, 5))
    if fp_raw:
        # Transpose so every column is one event; plt.plot draws one line per column.
        y_coords_fp = np.vstack(fp_raw).T
        y_coords_fp_median = np.median(y_coords_fp, axis=1)
        plt.plot(x_coords, y_coords_fp, alpha=0.2, color='red')
        plt.plot(x_coords, y_coords_fp_median, alpha=1, color='red')
    if tp_raw:
        y_coords_tp = np.vstack(tp_raw).T
        plt.plot(x_coords, y_coords_tp, alpha=1, color='blue')
    plt.axvline(hfw, color='black')  # marks the event middle within the window
    plt.xlabel('measurement #')
    plt.ylabel('norm. signal')
    plt.savefig(f'{out_prefix}_traces.svg', dpi=400)
    plt.close()
else:
    print('No complete TP or FP signal windows found; skipping squiggle plot.', file=sys.stderr)
......@@ -120,7 +120,7 @@ def train(parameter_file, training_data, test_data, plots_path=None,
start=graph_start,
nb_classes=2)
output_file(sample_predictions_path + "pred_ep%s_ex%s.html" % (epoch_index, tr_fn))
reader.add_to_npz(npz, ts_predict_name, [y_hat, posterior], ['labels_predicted', 'posterior'])
reader.add_to_npz(npz, ts_predict_name, [y_hat, posterior, x, kmers, nn.target], ['labels_predicted', 'posterior', 'raw_excerpt', 'base_labels_excerpt', 'target'])
save(ts_plot)
break
if (i+1) == len(ts_npzs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment