From 969adcc26c6489b54d8affed6ac253d2ffa2477d Mon Sep 17 00:00:00 2001 From: Carlos de Lannoy <carlos.delannoy@wur.nl> Date: Sun, 19 Jun 2022 10:35:04 +0200 Subject: [PATCH] update --- inference/compile_model.py | 27 +++-- validation/parse_16s_performance.py | 42 +++++++ validation/parse_speed_benchmark_results.py | 124 ++++++++++++++++++++ 3 files changed, 184 insertions(+), 9 deletions(-) create mode 100644 validation/parse_16s_performance.py create mode 100644 validation/parse_speed_benchmark_results.py diff --git a/inference/compile_model.py b/inference/compile_model.py index e4f0012..0efc935 100644 --- a/inference/compile_model.py +++ b/inference/compile_model.py @@ -164,25 +164,33 @@ def compile_model(kmer_dict, filter_width, threshold, batch_size, parallel_model layer_index = 1 for il, l in enumerate(mod_first.layers): if type(l) == tf.keras.layers.Dense: - nl = tf.keras.layers.Dense(l.weights[0].shape[1] * nb_mods, activation=l.activation) + x = tf.keras.layers.Dense(l.weights[0].shape[1] * nb_mods, activation=l.activation)(x) trained_layers_dict[il] = {'ltype': 'dense', 'layer_index': layer_index, 'weights': []} elif type(l) == tf.keras.layers.Conv1D: - nl = tf.keras.layers.Conv1D(l.filters * nb_mods, l.kernel_size, activation=l.activation, groups=nb_mods if il>0 else 1) + x = tf.keras.layers.Conv1D(l.filters * nb_mods, l.kernel_size, activation=l.activation, + groups=nb_mods if il > 0 else 1)(x) trained_layers_dict[il] = {'ltype': 'conv1d', 'layer_index': layer_index, 'weights': []} elif type(l) == tf.keras.layers.BatchNormalization: - nl = tf.keras.layers.BatchNormalization() + x = tf.keras.layers.BatchNormalization()(x) trained_layers_dict[il] = {'ltype': 'batchnormalization', 'layer_index': layer_index, 'weights': []} elif type(l) == tf.keras.layers.Dropout: - nl = tf.keras.layers.Dropout(l.rate) + x = tf.keras.layers.Dropout(l.rate)(x) elif type(l) == tf.keras.layers.MaxPool1D: - nl = tf.keras.layers.MaxPool1D(l.pool_size) + x = tf.keras.layers.MaxPool1D(l.pool_size)(x) elif type(l) == tf.keras.layers.Flatten: - nl = lambda x: tf.concat([tf.keras.layers.Flatten()(xs) for xs in tf.split(x, nb_mods, axis=2)], -1) - layer_index += nb_mods + 1 + # x = tf.transpose(x, perm=(0,2,1)) + nb_filters, t_dim = int(x.shape[2] / nb_mods), x.shape[1] + x = tf.reshape(tf.expand_dims(x, -1), (batch_size, t_dim, nb_mods, nb_filters)) + x = tf.transpose(x, perm=[0, 2, 1, 3]) + x = tf.keras.layers.Flatten()(x) + layer_index += 3 # additional layer count for transpose + # nb_filters, t_dim = int(x.shape[2] / nb_mods), x.shape[1] + # x = tf.reshape(x, (batch_size, t_dim * nb_filters, nb_mods)) + # nl = lambda x: tf.concat([tf.keras.layers.Flatten()(xs) for xs in tf.split(x, nb_mods, axis=2)], -1) + # layer_index += nb_mods + 1 else: raise ValueError(f'models with layer type {type(l)} cannot be concatenated yet') layer_index += 1 - x = nl(x) output = K.cast_to_floatx(K.greater(x, threshold)) meta_mod = tf.keras.Model(inputs=input, outputs=output) @@ -196,7 +204,8 @@ def compile_model(kmer_dict, filter_width, threshold, batch_size, parallel_model # fill in weights for il in trained_layers_dict: if trained_layers_dict[il]['ltype'] == 'conv1d': - weight_list = concat_conv1d_weights(meta_mod.layers[il+1].weights, trained_layers_dict[il]['weights']) + weight_list = concat_conv1d_weights(meta_mod.layers[trained_layers_dict[il]['layer_index']].weights, + trained_layers_dict[il]['weights']) elif trained_layers_dict[il]['ltype'] == 'dense': weight_list = concat_dense_weights(meta_mod.layers[trained_layers_dict[il]['layer_index']].weights, trained_layers_dict[il]['weights']) diff --git a/validation/parse_16s_performance.py b/validation/parse_16s_performance.py new file mode 100644 index 0000000..73dbfd7 --- /dev/null +++ b/validation/parse_16s_performance.py @@ -0,0 +1,42 @@ +import pandas as pd + + +import matplotlib.pyplot as plt +import seaborn as sns + + +def prettify_species(sp): + genus, species = sp.split(' ') + return f'${genus[0].upper()}. {species}$' + +perf_csv = '/home/carlos/Documents/202108_baseLess/figures/16s_benchmark/16s_benchmark.csv' +tool_order = ['baseLess', 'Guppy fast','SquiggleNet', 'DeepNano-blitz', 'UNCALLED'] + +perf_df = pd.read_csv(perf_csv) +perf_df.query('species != "porphyromonas gingivalis"', inplace=True) +perf_df.sort_values('species', inplace=True) +perf_df.loc[:, 'species'] = perf_df.species.apply(lambda x: prettify_species(x)) + +tool_dict = {'deepnano': 'DeepNano-blitz', + 'uncalled': 'UNCALLED', + 'guppy': 'Guppy fast', + 'squigglenet': 'SquiggleNet', + 'baseless_8_and_9mers': 'baseLess'} +perf_df.loc[:, 'tool'] = perf_df.tool.apply(lambda x: tool_dict[x]) + +fig, ax = plt.subplots(1,2, figsize=(210/3 * 0.0393701 * 2,7), sharey=True) +sns.barplot(y='species', x='accuracy', hue='tool', data=perf_df, ax=ax[0], + hue_order=tool_order, errwidth=1.5) +ax[0].set_xlim([0,1.0]) +ax[0].invert_xaxis() +ax[0].get_legend().remove() +# ax[0].get_yaxis().set_ticks([]) +ax[0].get_yaxis().set_ticklabels([]) +ax[0].set_ylabel(None) +sns.barplot(y='species', x='f1', hue='tool', data=perf_df, ax=ax[1], hue_order=tool_order, + errwidth=1.5) +ax[1].set_ylabel(None) +ax[1].set_xlim([0,1.0]) +# plt.setp(ax[1].get_xticklabels(), rotation=-45, ha="right", rotation_mode="anchor") +plt.tight_layout() +plt.savefig('/home/carlos/Documents/202108_baseLess/figures/16s_benchmark/16S_performance.svg') diff --git a/validation/parse_speed_benchmark_results.py b/validation/parse_speed_benchmark_results.py new file mode 100644 index 0000000..1e582b0 --- /dev/null +++ b/validation/parse_speed_benchmark_results.py @@ -0,0 +1,124 @@ +import argparse, re, os, sys, h5py +import pandas as pd +import numpy as np +from pathlib import Path +from glob import glob + +import matplotlib.pyplot as plt +import seaborn as sns + +sys.path.append(f'{list(Path(__file__).resolve().parents)[1]}') +from low_requirement_helper_functions import parse_input_path + +tool_order = ['baseLess', 'Guppy fast','SquiggleNet', 'DeepNano-blitz', 'UNCALLED'] + +def parse_baseless(pth, hardware): + times = [] + for sdir in pth.iterdir(): + with open(str(sdir) + '/run_stats.log', 'r') as fh: time_txt = fh.read() + times.append(float(re.search('[0-9]+(?=s)', time_txt).group(0))) + df = pd.DataFrame({'tool': 'baseLess', 'hardware': hardware, 'wall_time': times}) + return df + + +def parse_deepnano(pth, hardware): + times = [] + for log_fn in Path(str(pth) + '/logs').iterdir(): + with open(str(log_fn), 'r') as fh: time_txt = [line for line in fh.readlines() if line.startswith('real')][0] + times.append(float(re.search('[0-9]+(?=m)', time_txt).group(0)) * 60 + float( + re.search('[0-9,]+(?=s)', time_txt).group(0).replace(',', '.'))) + df = pd.DataFrame({'tool': 'DeepNano-blitz', 'hardware': hardware, 'wall_time': times}) + return df + + +def parse_guppy(pth, hardware): + times = [] + for sdir in pth.iterdir(): + with open(glob(str(sdir)+'/*.log')[0], 'r') as fh: time_txt = fh.read() + times.append(float(re.search('(?<=Caller time:)\s+[0-9]+', time_txt).group(0)) * 0.001) + df = pd.DataFrame({'tool': 'Guppy fast', 'hardware': hardware, 'wall_time': times}) + return df + + +def parse_squigglenet(pth, hardware): + times = [] + for log_fn in Path(str(pth) + '/logs').iterdir(): + with open(str(log_fn), 'r') as fh: time_txt = [line for line in fh.readlines() if line.startswith('[Step FINAL]')][0] + times.append(float(re.search('[0-9.]+(?= seconds)', time_txt).group(0))) + df = pd.DataFrame({'tool': 'SquiggleNet', 'hardware': hardware, 'wall_time': times}) + return df + + +def parse_uncalled(pth, hardware): + times = [] + for log_fn in Path(str(pth) + '/logs').iterdir(): + with open(str(log_fn), 'r') as fh: time_txt = [line for line in fh.readlines() if line.startswith('real')][0] + times.append(float(re.search('[0-9]+(?=m)', time_txt).group(0)) * 60 + float(re.search('[0-9,]+(?=s)', time_txt).group(0).replace(',', '.'))) + df = pd.DataFrame({'tool': 'UNCALLED', 'hardware': hardware, 'wall_time': times}) + return df + + +parser = argparse.ArgumentParser(description='Parse speed benchmark results') +parser.add_argument('--in-dir-desktop', type=str, required=True) +parser.add_argument('--fast5', type=str, required=True, + help='directory of test reads used in benchmark') +parser.add_argument('--in-dir-jetson', type=str, required=True) +parser.add_argument('--out-svg', type=str,required=True) +args = parser.parse_args() + +speed_df_list = [] + +# parse desktop results +for bm_dir in Path(args.in_dir_desktop).iterdir(): + if not bm_dir.is_dir(): continue + if bm_dir.name == 'baseless': + sdf = parse_baseless(bm_dir, 'Desktop') + elif bm_dir.name == 'deepnano': + sdf = parse_deepnano(bm_dir, 'Desktop') + elif bm_dir.name == 'guppy_fast': + sdf = parse_guppy(bm_dir, 'Desktop') + elif bm_dir.name == 'squigglenet': + sdf = parse_squigglenet(bm_dir, 'Desktop') + elif bm_dir.name == 'uncalled': + sdf = parse_uncalled(bm_dir, 'Desktop') + speed_df_list.append(sdf) + +# parse jetson results +for bm_dir in Path(args.in_dir_jetson).iterdir(): + if not bm_dir.is_dir(): continue + if bm_dir.name == 'baseless': + sdf = parse_baseless(bm_dir, 'Jetson') + elif bm_dir.name == 'guppy_fast': + sdf = parse_guppy(bm_dir, 'Jetson') + elif bm_dir.name == 'squigglenet': + sdf = parse_squigglenet(bm_dir, 'Jetson') + speed_df_list.append(sdf) +speed_df = pd.concat(speed_df_list) + +# count bases convert time to speed +nb_bases = 0 +for file in parse_input_path(args.fast5): + with h5py.File(file, 'r') as ff: + fq = ff['Analyses/Basecall_1D_000/BaseCalled_template/Fastq'][()].decode('ascii') + nb_bases += len(fq.split('\n')[1]) +speed_df.loc[:, 'speed'] = nb_bases / speed_df.wall_time + +# plotting +speed_df.loc[speed_df.hardware == 'Jetson', 'speed'] = speed_df.loc[speed_df.hardware == 'Jetson', 'speed'] / 10e3 +speed_df.loc[speed_df.hardware == 'Desktop', 'speed'] = speed_df.loc[speed_df.hardware == 'Desktop', 'speed'] / 10e6 + +fig, axes = plt.subplots(2, 1, figsize=[210/3 * 0.0393701, 7], sharex=True) +sns.barplot(x='tool', y='speed', data=speed_df.query('hardware == "Jetson"'), ax=axes[0], + errwidth=2., order=tool_order) +sns.barplot(x='tool', y='speed', data=speed_df.query('hardware == "Desktop"'), ax=axes[1], + errwidth=2., order=tool_order) + +plt.xticks(rotation=-45) +for ax in axes: + ax.set_xlabel('') + ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0)) +axes[0].set_ylabel('10$^4$ bases/s') +axes[1].set_ylabel('10$^6$ bases/s') +plt.tight_layout() +plt.savefig(args.out_svg) +plt.close(fig) \ No newline at end of file -- GitLab