diff --git a/argparse_dicts.py b/argparse_dicts.py
index 442a9e21fef6fe864e09a396d57001e1d8778861..6d0b5820513351cec5e486b999979970dc0a4bfc 100644
--- a/argparse_dicts.py
+++ b/argparse_dicts.py
@@ -58,6 +58,20 @@ model = ('--model', {
     'help': 'Combined k-mer model to use for inference.'
 })
 
+mem = ('--mem', {
+    'type': int,
+    'default': -1,
+    'help': 'Maximum amount of GPU memory to use in MB [default: no limit]'
+    }
+)
+
+batch_size = ('--batch-size', {
+    'type': int,
+    'default': 4,
+    'help': 'Maximum number of reads for which to run prediction simultaneously; '
+            'decreasing this may resolve memory issues [default: 4]'
+})
+
 model_weights = ('--model-weights', {
     'type': str,
     'required': False,
@@ -272,11 +286,12 @@ def get_training_parser():
         parser.add_argument(arg[0], **arg[1])
     return parser
 
+
 def get_run_inference_parser():
     parser = argparse.ArgumentParser(description='Start up inference routine and watch a fast5 directory for reads.')
-    for arg in (fast5_in, out_dir, model, inference_mode):
+    for arg in (fast5_in, out_dir, model, inference_mode, mem, batch_size):
         parser.add_argument(arg[0], **arg[1])
-    parser.add_argument('--continuous-nn', action='store_true',help='Used RNN can handle continuous reads.')
+    # parser.add_argument('--continuous-nn', action='store_true',help='Used RNN can handle continuous reads.')
     return parser
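
The two new option tuples follow the existing (flag, kwargs) convention, so get_run_inference_parser() only has to add them to its loop. A minimal sketch of how they behave once parsed, assuming the mem and batch_size tuples defined above are in scope; note that the -1 default for --mem pairs with the args.mem > 0 guard added in run_inference.py, i.e. omitting the flag leaves GPU memory uncapped:

    import argparse

    parser = argparse.ArgumentParser()
    for arg in (mem, batch_size):                  # same unpacking as in the parser factories
        parser.add_argument(arg[0], **arg[1])

    args = parser.parse_args(['--mem', '4096', '--batch-size', '8'])
    assert args.mem == 4096                        # cap the GPU at 4096 MB
    assert args.batch_size == 8                    # '--batch-size' maps to args.batch_size

    assert parser.parse_args([]).mem == -1         # -1 is read downstream as "no limit"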
 
 
diff --git a/inference/ReadTable.py b/inference/ReadTable.py
index ae0a187b45a4de3af736c0947aa66299cd432b3a..a09c0fce249fe3b841d1853db93dca206668883a 100644
--- a/inference/ReadTable.py
+++ b/inference/ReadTable.py
@@ -32,7 +32,7 @@ class ReadTable(object):
         manager_process.start()
         return manager_process
 
-    def get_read_to_predict(self, batch_size=1):
+    def get_read_to_predict(self, batch_size):
         """Get read and a kmer for which to scan the read file
 
         :return: Tuple of: (path to read fast5 file,
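
Dropping the default makes batch_size a required argument: any caller still relying on the old implicit batch_size=1 now raises a TypeError. A minimal sketch of the updated call site, mirroring the change made in run_inference.py below:

    # read_table.get_read_to_predict()        # TypeError: missing 'batch_size'
    read_id, read = read_table.get_read_to_predict(batch_size=args.batch_size)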
diff --git a/inference/run_inference.py b/inference/run_inference.py
index fa5d9641165c9c46aca148fca137cc72a6191ed2..74eacf6fce2aa13af7552b176fec7ee2f1c3a42d 100644
--- a/inference/run_inference.py
+++ b/inference/run_inference.py
@@ -2,6 +2,7 @@ import sys, signal, os, h5py
 from pathlib import Path
 from datetime import datetime
 import tensorflow as tf
+from tensorflow.python.saved_model import tag_constants
 import numpy as np
 
 sys.path.append(f'{list(Path(__file__).resolve().parents)[1]}')
@@ -19,8 +20,9 @@ def main(args):
     tf.config.set_soft_device_placement(True)
 
     gpus = tf.config.experimental.list_physical_devices('GPU')
-    tf.config.experimental.set_virtual_device_configuration(gpus[0], [
-        tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
+    if args.mem > 0:
+        tf.config.experimental.set_virtual_device_configuration(gpus[0], [
+            tf.config.experimental.VirtualDeviceConfiguration(memory_limit=args.mem)])
     # if gpus:
     #     try:
     #         # Currently, memory growth needs to be the same across GPUs
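
Gating the virtual-device setup on args.mem > 0 leaves the GPU unrestricted unless a cap was requested on the command line. A hedged, standalone sketch of the same pattern (TensorFlow 2.x), additionally guarding against CPU-only hosts, where gpus[0] would raise an IndexError:

    import tensorflow as tf

    def limit_gpu_memory(mem_mb: int) -> None:
        """Cap the first visible GPU at mem_mb MB; non-positive values apply no limit."""
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if mem_mb > 0 and gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=mem_mb)])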
@@ -33,10 +35,20 @@ def main(args):
     #         print(e)
 
     print('Loading model...')
-    mod = tf.keras.models.load_model(args.model)
-    with h5py.File(args.model, 'r') as fh:
-        model_type = fh.attrs['model_type']
-        kmer_list = fh.attrs['kmer_list'].split(',')
+
+    if args.model.endswith('.h5'):
+        mod = tf.keras.models.load_model(args.model)
+        with h5py.File(args.model, 'r') as fh:
+            model_type = fh.attrs['model_type']
+            kmer_list = fh.attrs['kmer_list'].split(',')
+    else:  # assume it's a TensorRT (TRT) SavedModel
+        mod_dir = args.model
+        if not mod_dir.endswith('/'): mod_dir += '/'
+        _mod = tf.saved_model.load(args.model, tags=[tag_constants.SERVING])
+        mod = _mod.signatures['serving_default']
+        with open(f'{mod_dir}baseless_params.txt', 'r') as fh:
+            model_type, km = fh.read().split('\n')
+            kmer_list = km.split(',')
     print(f'Done! Model type is {model_type}')
     pos_reads_dir = parse_output_path(args.out_dir + 'pos_reads')
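
The non-.h5 branch loads a SavedModel (e.g. a TF-TRT converted export) through its serving signature and reads model_type plus the k-mer list from a plain-text sidecar, baseless_params.txt, placed next to the SavedModel. A hypothetical export-side helper, inferred from the read logic above, makes the expected file layout explicit:

    def write_baseless_params(mod_dir: str, model_type: str, kmer_list: list) -> None:
        # Line 1: model_type, line 2: comma-separated k-mers. No trailing newline,
        # since the loader unpacks fh.read().split('\n') into exactly two values.
        if not mod_dir.endswith('/'):
            mod_dir += '/'
        with open(f'{mod_dir}baseless_params.txt', 'w') as fh:
            fh.write(f'{model_type}\n{",".join(kmer_list)}')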
 
@@ -70,7 +82,7 @@ def main(args):
     start_time = datetime.now()
     # Start inference loop
     while end_condition():
-        read_id, read = read_table.get_read_to_predict()
+        read_id, read = read_table.get_read_to_predict(batch_size=args.batch_size)
         if read is None: continue
         pred = mod(read).numpy()
         if abundance_mode:
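
One caveat worth hedging: a SavedModel serving signature typically returns a dict of output tensors keyed by output name, whereas a Keras .h5 model returns the tensor directly, so mod(read).numpy() in the loop above matches the .h5 path. A hypothetical wrapper (not part of this diff) would let both loading branches be called identically:

    def make_predict_fn(mod, is_keras_model: bool):
        # Normalise the two loading paths so the inference loop can stay unchanged.
        def predict(batch):
            out = mod(batch)
            if not is_keras_model:
                # serving_default returns {output_name: tensor}; assume one output head
                out = next(iter(out.values()))
            return out.numpy()
        return predict

    predict = make_predict_fn(mod, args.model.endswith('.h5'))
    # pred = predict(read) would then replace the direct mod(read).numpy() call.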