Commit 969adcc2 authored by Lannoy, Carlos de's avatar Lannoy, Carlos de
Browse files

update

parent 0a47d757
......@@ -164,25 +164,33 @@ def compile_model(kmer_dict, filter_width, threshold, batch_size, parallel_model
layer_index = 1
for il, l in enumerate(mod_first.layers):
if type(l) == tf.keras.layers.Dense:
nl = tf.keras.layers.Dense(l.weights[0].shape[1] * nb_mods, activation=l.activation)
x = tf.keras.layers.Dense(l.weights[0].shape[1] * nb_mods, activation=l.activation)(x)
trained_layers_dict[il] = {'ltype': 'dense', 'layer_index': layer_index, 'weights': []}
elif type(l) == tf.keras.layers.Conv1D:
nl = tf.keras.layers.Conv1D(l.filters * nb_mods, l.kernel_size, activation=l.activation, groups=nb_mods if il>0 else 1)
x = tf.keras.layers.Conv1D(l.filters * nb_mods, l.kernel_size, activation=l.activation,
groups=nb_mods if il > 0 else 1)(x)
trained_layers_dict[il] = {'ltype': 'conv1d', 'layer_index': layer_index, 'weights': []}
elif type(l) == tf.keras.layers.BatchNormalization:
nl = tf.keras.layers.BatchNormalization()
x = tf.keras.layers.BatchNormalization()(x)
trained_layers_dict[il] = {'ltype': 'batchnormalization', 'layer_index': layer_index, 'weights': []}
elif type(l) == tf.keras.layers.Dropout:
nl = tf.keras.layers.Dropout(l.rate)
x = tf.keras.layers.Dropout(l.rate)(x)
elif type(l) == tf.keras.layers.MaxPool1D:
nl = tf.keras.layers.MaxPool1D(l.pool_size)
x = tf.keras.layers.MaxPool1D(l.pool_size)(x)
elif type(l) == tf.keras.layers.Flatten:
nl = lambda x: tf.concat([tf.keras.layers.Flatten()(xs) for xs in tf.split(x, nb_mods, axis=2)], -1)
layer_index += nb_mods + 1
# x = tf.transpose(x, perm=(0,2,1))
nb_filters, t_dim = int(x.shape[2] / nb_mods), x.shape[1]
x = tf.reshape(tf.expand_dims(x, -1), (batch_size, t_dim, nb_mods, nb_filters))
x = tf.transpose(x, perm=[0, 2, 1, 3])
x = tf.keras.layers.Flatten()(x)
layer_index += 3 # additional layer count for transpose
# nb_filters, t_dim = int(x.shape[2] / nb_mods), x.shape[1]
# x = tf.reshape(x, (batch_size, t_dim * nb_filters, nb_mods))
# nl = lambda x: tf.concat([tf.keras.layers.Flatten()(xs) for xs in tf.split(x, nb_mods, axis=2)], -1)
# layer_index += nb_mods + 1
else:
raise ValueError(f'models with layer type {type(l)} cannot be concatenated yet')
layer_index += 1
x = nl(x)
output = K.cast_to_floatx(K.greater(x, threshold))
meta_mod = tf.keras.Model(inputs=input, outputs=output)
......@@ -196,7 +204,8 @@ def compile_model(kmer_dict, filter_width, threshold, batch_size, parallel_model
# fill in weights
for il in trained_layers_dict:
if trained_layers_dict[il]['ltype'] == 'conv1d':
weight_list = concat_conv1d_weights(meta_mod.layers[il+1].weights, trained_layers_dict[il]['weights'])
weight_list = concat_conv1d_weights(meta_mod.layers[trained_layers_dict[il]['layer_index']].weights,
trained_layers_dict[il]['weights'])
elif trained_layers_dict[il]['ltype'] == 'dense':
weight_list = concat_dense_weights(meta_mod.layers[trained_layers_dict[il]['layer_index']].weights,
trained_layers_dict[il]['weights'])
......
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def prettify_species(sp):
genus, species = sp.split(' ')
return f'${genus[0].upper()}. {species}$'
perf_csv = '/home/carlos/Documents/202108_baseLess/figures/16s_benchmark/16s_benchmark.csv'
tool_order = ['baseLess', 'Guppy fast','SquiggleNet', 'DeepNano-blitz', 'UNCALLED']
perf_df = pd.read_csv(perf_csv)
perf_df.query('species != "porphyromonas gingivalis"', inplace=True)
perf_df.sort_values('species', inplace=True)
perf_df.loc[:, 'species'] = perf_df.species.apply(lambda x: prettify_species(x))
tool_dict = {'deepnano': 'DeepNano-blitz',
'uncalled': 'UNCALLED',
'guppy': 'Guppy fast',
'squigglenet': 'SquiggleNet',
'baseless_8_and_9mers': 'baseLess'}
perf_df.loc[:, 'tool'] = perf_df.tool.apply(lambda x: tool_dict[x])
fig, ax = plt.subplots(1,2, figsize=(210/3 * 0.0393701 * 2,7), sharey=True)
sns.barplot(y='species', x='accuracy', hue='tool', data=perf_df, ax=ax[0],
hue_order=tool_order, errwidth=1.5)
ax[0].set_xlim([0,1.0])
ax[0].invert_xaxis()
ax[0].get_legend().remove()
# ax[0].get_yaxis().set_ticks([])
ax[0].get_yaxis().set_ticklabels([])
ax[0].set_ylabel(None)
sns.barplot(y='species', x='f1', hue='tool', data=perf_df, ax=ax[1], hue_order=tool_order,
errwidth=1.5)
ax[1].set_ylabel(None)
ax[1].set_xlim([0,1.0])
# plt.setp(ax[1].get_xticklabels(), rotation=-45, ha="right", rotation_mode="anchor")
plt.tight_layout()
plt.savefig('/home/carlos/Documents/202108_baseLess/figures/16s_benchmark/16S_performance.svg')
import argparse, re, os, sys, h5py
import pandas as pd
import numpy as np
from pathlib import Path
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
sys.path.append(f'{list(Path(__file__).resolve().parents)[1]}')
from low_requirement_helper_functions import parse_input_path
tool_order = ['baseLess', 'Guppy fast','SquiggleNet', 'DeepNano-blitz', 'UNCALLED']
def parse_baseless(pth, hardware):
times = []
for sdir in pth.iterdir():
with open(str(sdir) + '/run_stats.log', 'r') as fh: time_txt = fh.read()
times.append(float(re.search('[0-9]+(?=s)', time_txt).group(0)))
df = pd.DataFrame({'tool': 'baseLess', 'hardware': hardware, 'wall_time': times})
return df
def parse_deepnano(pth, hardware):
times = []
for log_fn in Path(str(pth) + '/logs').iterdir():
with open(str(log_fn), 'r') as fh: time_txt = [line for line in fh.readlines() if line.startswith('real')][0]
times.append(float(re.search('[0-9]+(?=m)', time_txt).group(0)) * 60 + float(
re.search('[0-9,]+(?=s)', time_txt).group(0).replace(',', '.')))
df = pd.DataFrame({'tool': 'DeepNano-blitz', 'hardware': hardware, 'wall_time': times})
return df
def parse_guppy(pth, hardware):
times = []
for sdir in pth.iterdir():
with open(glob(str(sdir)+'/*.log')[0], 'r') as fh: time_txt = fh.read()
times.append(float(re.search('(?<=Caller time:)\s+[0-9]+', time_txt).group(0)) * 0.001)
df = pd.DataFrame({'tool': 'Guppy fast', 'hardware': hardware, 'wall_time': times})
return df
def parse_squigglenet(pth, hardware):
times = []
for log_fn in Path(str(pth) + '/logs').iterdir():
with open(str(log_fn), 'r') as fh: time_txt = [line for line in fh.readlines() if line.startswith('[Step FINAL]')][0]
times.append(float(re.search('[0-9.]+(?= seconds)', time_txt).group(0)))
df = pd.DataFrame({'tool': 'SquiggleNet', 'hardware': hardware, 'wall_time': times})
return df
def parse_uncalled(pth, hardware):
times = []
for log_fn in Path(str(pth) + '/logs').iterdir():
with open(str(log_fn), 'r') as fh: time_txt = [line for line in fh.readlines() if line.startswith('real')][0]
times.append(float(re.search('[0-9]+(?=m)', time_txt).group(0)) * 60 + float(re.search('[0-9,]+(?=s)', time_txt).group(0).replace(',', '.')))
df = pd.DataFrame({'tool': 'UNCALLED', 'hardware': hardware, 'wall_time': times})
return df
parser = argparse.ArgumentParser(description='Parse speed benchmark results')
parser.add_argument('--in-dir-desktop', type=str, required=True)
parser.add_argument('--fast5', type=str, required=True,
help='directory of test reads used in benchmark')
parser.add_argument('--in-dir-jetson', type=str, required=True)
parser.add_argument('--out-svg', type=str,required=True)
args = parser.parse_args()
speed_df_list = []
# parse desktop results
for bm_dir in Path(args.in_dir_desktop).iterdir():
if not bm_dir.is_dir(): continue
if bm_dir.name == 'baseless':
sdf = parse_baseless(bm_dir, 'Desktop')
elif bm_dir.name == 'deepnano':
sdf = parse_deepnano(bm_dir, 'Desktop')
elif bm_dir.name == 'guppy_fast':
sdf = parse_guppy(bm_dir, 'Desktop')
elif bm_dir.name == 'squigglenet':
sdf = parse_squigglenet(bm_dir, 'Desktop')
elif bm_dir.name == 'uncalled':
sdf = parse_uncalled(bm_dir, 'Desktop')
speed_df_list.append(sdf)
# parse jetson results
for bm_dir in Path(args.in_dir_jetson).iterdir():
if not bm_dir.is_dir(): continue
if bm_dir.name == 'baseless':
sdf = parse_baseless(bm_dir, 'Jetson')
elif bm_dir.name == 'guppy_fast':
sdf = parse_guppy(bm_dir, 'Jetson')
elif bm_dir.name == 'squigglenet':
sdf = parse_squigglenet(bm_dir, 'Jetson')
speed_df_list.append(sdf)
speed_df = pd.concat(speed_df_list)
# count bases convert time to speed
nb_bases = 0
for file in parse_input_path(args.fast5):
with h5py.File(file, 'r') as ff:
fq = ff['Analyses/Basecall_1D_000/BaseCalled_template/Fastq'][()].decode('ascii')
nb_bases += len(fq.split('\n')[1])
speed_df.loc[:, 'speed'] = nb_bases / speed_df.wall_time
# plotting
speed_df.loc[speed_df.hardware == 'Jetson', 'speed'] = speed_df.loc[speed_df.hardware == 'Jetson', 'speed'] / 10e3
speed_df.loc[speed_df.hardware == 'Desktop', 'speed'] = speed_df.loc[speed_df.hardware == 'Desktop', 'speed'] / 10e6
fig, axes = plt.subplots(2, 1, figsize=[210/3 * 0.0393701, 7], sharex=True)
sns.barplot(x='tool', y='speed', data=speed_df.query('hardware == "Jetson"'), ax=axes[0],
errwidth=2., order=tool_order)
sns.barplot(x='tool', y='speed', data=speed_df.query('hardware == "Desktop"'), ax=axes[1],
errwidth=2., order=tool_order)
plt.xticks(rotation=-45)
for ax in axes:
ax.set_xlabel('')
ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
axes[0].set_ylabel('10$^4$ bases/s')
axes[1].set_ylabel('10$^6$ bases/s')
plt.tight_layout()
plt.savefig(args.out_svg)
plt.close(fig)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment