From b359af7889eb6a0f1e93fbe719750b41bf818417 Mon Sep 17 00:00:00 2001
From: Carlos de Lannoy <carlos.delannoy@wur.nl>
Date: Sun, 10 Apr 2022 23:36:00 +0200
Subject: [PATCH] inference memory and batch size are arguments

---
 argparse_dicts.py          | 19 +++++++++++++++++--
 inference/ReadTable.py     |  2 +-
 inference/run_inference.py | 26 +++++++++++++++++++-------
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/argparse_dicts.py b/argparse_dicts.py
index 442a9e2..6d0b582 100644
--- a/argparse_dicts.py
+++ b/argparse_dicts.py
@@ -58,6 +58,20 @@ model = ('--model', {
     'help': 'Combined k-mer model to use for inference.'
 })
 
+mem = ('--mem', {
+    'type': int,
+    'default': -1,
+    'help': 'Amount of RAM to use in MB [default: no limit]'
+    }
+)
+
+batch_size = ('--batch-size', {
+    'type': int,
+    'default': 4,
+    'help': 'Max number of reads for which to run prediction simultaneously, '
+            'decreasing may resolve memory issues [default: 4]'
+})
+
 model_weights = ('--model-weights', {
     'type': str,
     'required': False,
@@ -272,11 +286,12 @@ def get_training_parser():
         parser.add_argument(arg[0], **arg[1])
     return parser
 
+
 def get_run_inference_parser():
     parser = argparse.ArgumentParser(description='Start up inference routine and watch a fast5 directory for reads.')
-    for arg in (fast5_in, out_dir, model, inference_mode):
+    for arg in (fast5_in, out_dir, model, inference_mode, mem, batch_size):
         parser.add_argument(arg[0], **arg[1])
-    parser.add_argument('--continuous-nn', action='store_true',help='Used RNN can handle continuous reads.')
+    # parser.add_argument('--continuous-nn', action='store_true',help='Used RNN can handle continuous reads.')
     return parser
 
 
diff --git a/inference/ReadTable.py b/inference/ReadTable.py
index ae0a187..a09c0fc 100644
--- a/inference/ReadTable.py
+++ b/inference/ReadTable.py
@@ -32,7 +32,7 @@ class ReadTable(object):
         manager_process.start()
         return manager_process
 
-    def get_read_to_predict(self, batch_size=1):
+    def get_read_to_predict(self, batch_size):
         """Get read and a kmer for which to scan the read file
 
         :return: Tuple of: (path to read fast5 file,
diff --git a/inference/run_inference.py b/inference/run_inference.py
index fa5d964..74eacf6 100644
--- a/inference/run_inference.py
+++ b/inference/run_inference.py
@@ -2,6 +2,7 @@ import sys, signal, os, h5py
 from pathlib import Path
 from datetime import datetime
 import tensorflow as tf
+from tensorflow.python.saved_model import tag_constants
 import numpy as np
 
 sys.path.append(f'{list(Path(__file__).resolve().parents)[1]}')
@@ -19,8 +20,9 @@ def main(args):
 
     tf.config.set_soft_device_placement(True)
     gpus = tf.config.experimental.list_physical_devices('GPU')
-    tf.config.experimental.set_virtual_device_configuration(gpus[0], [
-        tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
+    if args.mem > 0:
+        tf.config.experimental.set_virtual_device_configuration(gpus[0], [
+            tf.config.experimental.VirtualDeviceConfiguration(memory_limit=args.mem)])
     # if gpus:
     #     try:
     #         # Currently, memory growth needs to be the same across GPUs
@@ -33,10 +35,20 @@
     #         print(e)
 
     print('Loading model...')
-    mod = tf.keras.models.load_model(args.model)
-    with h5py.File(args.model, 'r') as fh:
-        model_type = fh.attrs['model_type']
-        kmer_list = fh.attrs['kmer_list'].split(',')
+
+    if args.model.endswith('.h5'):
+        mod = tf.keras.models.load_model(args.model)
+        with h5py.File(args.model, 'r') as fh:
+            model_type = fh.attrs['model_type']
+            kmer_list = fh.attrs['kmer_list'].split(',')
+    else:  # assume its a trt model
+        mod_dir = args.model
+        if not mod_dir.endswith('/'): mod_dir += '/'
+        _mod = tf.saved_model.load(args.model, tags=[tag_constants.SERVING])
+        mod = _mod.signatures['serving_default']
+        with open(f'{mod_dir}baseless_params.txt', 'r') as fh:
+            model_type, km = fh.read().split('\n')
+            kmer_list = km.split(',')
     print(f'Done! Model type is {model_type}')
 
     pos_reads_dir = parse_output_path(args.out_dir + 'pos_reads')
@@ -70,7 +82,7 @@ def main(args):
     start_time = datetime.now()
     # Start inference loop
     while end_condition():
-        read_id, read = read_table.get_read_to_predict()
+        read_id, read = read_table.get_read_to_predict(batch_size=args.batch_size)
         if read is None: continue
         pred = mod(read).numpy()
         if abundance_mode:
-- 
GitLab
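
Note for reviewers: the sketch below exercises the two behaviours this patch introduces, the optional GPU memory cap (--mem) and the branch between loading a Keras .h5 model and a TensorRT-converted SavedModel directory. It is a minimal stand-alone example, not code from the repository; the helper names configure_gpu_memory and load_model and the stripped-down argument parser are illustrative only, while the baseless_params.txt layout follows the diff above.

# Stand-alone sketch (illustrative, not part of the patch): exercise the new
# --mem handling and the two model-loading branches from run_inference.py.
import argparse

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def configure_gpu_memory(mem_mb):
    # Cap GPU memory only when a positive limit was requested, as in the patch.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if mem_mb > 0 and gpus:
        tf.config.experimental.set_virtual_device_configuration(gpus[0], [
            tf.config.experimental.VirtualDeviceConfiguration(memory_limit=mem_mb)])


def load_model(model_path):
    # .h5 -> plain Keras model; anything else -> SavedModel directory (e.g. a
    # TF-TRT export), from which the default serving signature is the callable.
    if model_path.endswith('.h5'):
        return tf.keras.models.load_model(model_path)
    saved = tf.saved_model.load(model_path, tags=[tag_constants.SERVING])
    return saved.signatures['serving_default']


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--mem', type=int, default=-1)        # MB; -1 = no limit
    parser.add_argument('--batch-size', type=int, default=4)  # reads per prediction call
    args = parser.parse_args()

    configure_gpu_memory(args.mem)
    mod = load_model(args.model)
    # In run_inference.py the inference loop would now call
    # read_table.get_read_to_predict(batch_size=args.batch_size) and feed the
    # returned reads to mod(read).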