Skip to content
Snippets Groups Projects
Commit 055b870b authored by Lannoy, Carlos de's avatar Lannoy, Carlos de
Browse files

add script for conversion h5 to trt

parent b359af78
Branches coa_analysis
No related tags found
No related merge requests found
import argparse, re, h5py
from tempfile import TemporaryDirectory
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants
parser = argparse.ArgumentParser(description='Convert baseLess h5 model to trt model for efficiency')
parser.add_argument('--model', type=str, required=True, help='h5-format input model')
parser.add_argument('--fp-precision', type=int, choices=[16,32], default=32, # todo: 8 bit
help='Floating point precision for output model: 16 or 32 bit [default: 32]')
parser.add_argument('--mem', type=str, default='512MB',
help='Max RAM to use during inference, define MB or GB [default: 512MB]')
parser.add_argument('--out-model', type=str, required=True)
args = parser.parse_args()
# --- parameter parsing ---
out_model = args.out_model
if out_model.endswith('/'): out_model = out_model[:-1]
mem_int = int(re.search('^[0-9]+', args.mem).group(0))
if 'MB' in args.mem:
mem_int *= int(1e6)
elif 'GB' in args.mem:
mem_int *= int(1e9)
else:
raise ValueError('Define mem in GB or MB')
if args.fp_precision == 32:
pm = trt.TrtPrecisionMode.FP32
elif args.fp_precision == 16:
pm = trt.TrtPrecisionMode.FP16
else:
raise ValueError('Precision must be 16 or 32')
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode=pm,
max_workspace_size_bytes=mem_int)
# --- load and convert model ---
mod = tf.keras.models.load_model(args.model)
with TemporaryDirectory() as td:
mod.save(td)
converter = trt.TrtGraphConverterV2(input_saved_model_dir=td,
conversion_params=conversion_params)
converter.convert()
converter.save(output_saved_model_dir=args.out_model)
# --- add custom parameters as a txt file ---
with h5py.File(args.model, 'r') as fh:
model_type = fh.attrs['model_type']
kmer_list = fh.attrs['kmer_list']
with open(f'{out_model}/baseless_params.txt', 'w') as fh:
fh.write(f'{model_type}\n{kmer_list}')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment