Skip to content
Snippets Groups Projects
Commit 055b870b authored by Lannoy, Carlos de's avatar Lannoy, Carlos de
Browse files

add script for conversion h5 to trt

parent b359af78
No related branches found
No related tags found
No related merge requests found
import argparse, re, h5py
from tempfile import TemporaryDirectory
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants
parser = argparse.ArgumentParser(description='Convert baseLess h5 model to trt model for efficiency')
parser.add_argument('--model', type=str, required=True, help='h5-format input model')
parser.add_argument('--fp-precision', type=int, choices=[16,32], default=32, # todo: 8 bit
help='Floating point precision for output model: 16 or 32 bit [default: 32]')
parser.add_argument('--mem', type=str, default='512MB',
help='Max RAM to use during inference, define MB or GB [default: 512MB]')
parser.add_argument('--out-model', type=str, required=True)
args = parser.parse_args()
# --- parameter parsing ---
out_model = args.out_model
if out_model.endswith('/'): out_model = out_model[:-1]
mem_int = int(re.search('^[0-9]+', args.mem).group(0))
if 'MB' in args.mem:
mem_int *= int(1e6)
elif 'GB' in args.mem:
mem_int *= int(1e9)
else:
raise ValueError('Define mem in GB or MB')
if args.fp_precision == 32:
pm = trt.TrtPrecisionMode.FP32
elif args.fp_precision == 16:
pm = trt.TrtPrecisionMode.FP16
else:
raise ValueError('Precision must be 16 or 32')
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode=pm,
max_workspace_size_bytes=mem_int)
# --- load and convert model ---
mod = tf.keras.models.load_model(args.model)
with TemporaryDirectory() as td:
mod.save(td)
converter = trt.TrtGraphConverterV2(input_saved_model_dir=td,
conversion_params=conversion_params)
converter.convert()
converter.save(output_saved_model_dir=args.out_model)
# --- add custom parameters as a txt file ---
with h5py.File(args.model, 'r') as fh:
model_type = fh.attrs['model_type']
kmer_list = fh.attrs['kmer_list']
with open(f'{out_model}/baseless_params.txt', 'w') as fh:
fh.write(f'{model_type}\n{kmer_list}')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment