From 055b870bd770330104ccaabc377737ea173bf2ab Mon Sep 17 00:00:00 2001 From: Carlos de Lannoy <carlos.delannoy@wur.nl> Date: Mon, 11 Apr 2022 00:04:03 +0200 Subject: [PATCH] add script for conversion h5 to trt --- inference/tf2trt.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 inference/tf2trt.py diff --git a/inference/tf2trt.py b/inference/tf2trt.py new file mode 100644 index 0000000..59d4ed8 --- /dev/null +++ b/inference/tf2trt.py @@ -0,0 +1,59 @@ +import argparse, re, h5py + +from tempfile import TemporaryDirectory + +import tensorflow as tf +from tensorflow import keras +from tensorflow.python.compiler.tensorrt import trt_convert as trt +from tensorflow.python.saved_model import tag_constants + + +parser = argparse.ArgumentParser(description='Convert baseLess h5 model to trt model for efficiency') +parser.add_argument('--model', type=str, required=True, help='h5-format input model') +parser.add_argument('--fp-precision', type=int, choices=[16,32], default=32, # todo: 8 bit + help='Floating point precision for output model: 16 or 32 bit [default: 32]') +parser.add_argument('--mem', type=str, default='512MB', + help='Max RAM to use during inference, define MB or GB [default: 512MB]') +parser.add_argument('--out-model', type=str, required=True) + +args = parser.parse_args() + +# --- parameter parsing --- + +out_model = args.out_model +if out_model.endswith('/'): out_model = out_model[:-1] + +mem_int = int(re.search('^[0-9]+', args.mem).group(0)) +if 'MB' in args.mem: + mem_int *= int(1e6) +elif 'GB' in args.mem: + mem_int *= int(1e9) +else: + raise ValueError('Define mem in GB or MB') + +if args.fp_precision == 32: + pm = trt.TrtPrecisionMode.FP32 +elif args.fp_precision == 16: + pm = trt.TrtPrecisionMode.FP16 +else: + raise ValueError('Precision must be 16 or 32') + +conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode=pm, + max_workspace_size_bytes=mem_int) + +# --- load and convert model --- +mod = tf.keras.models.load_model(args.model) +with TemporaryDirectory() as td: + mod.save(td) + converter = trt.TrtGraphConverterV2(input_saved_model_dir=td, + conversion_params=conversion_params) + converter.convert() + converter.save(output_saved_model_dir=args.out_model) + +# --- add custom parameters as a txt file --- +with h5py.File(args.model, 'r') as fh: + model_type = fh.attrs['model_type'] + kmer_list = fh.attrs['kmer_list'] +with open(f'{out_model}/baseless_params.txt', 'w') as fh: + fh.write(f'{model_type}\n{kmer_list}') + -- GitLab