Skip to content
Snippets Groups Projects
Commit f4a9f6e2 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Updated docstrings and removed unused function

parent 5f8e3fcc
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ batch_norm: 0
batch_size: 24
dropout_keep_prob: 0.9
eps_per_kmer_switch: 25
max_sequence_length: 5000
filter_width: 1000
filters: 5
kernel_size: 19
......
......@@ -106,7 +106,7 @@ out_dir = ('--out-dir', {
out_model = ('--out-model', {
'type': str,
'required': True,
'help': 'Produced model'
'help': 'Path to tar file in which to save produced model'
})
ckpt_model = ('--ckpt-model', {
......
......@@ -128,17 +128,6 @@ def load_db(db_dir, read_only=False):
return db, squiggles
def load_model(mod_fn):
    """Load a composite model from a tar archive of saved Keras models.

    The archive is unpacked into a temporary directory; each entry is
    expected to be a saved Keras model file named after the k-mer it
    recognises (e.g. ``AAAAA.h5``).

    :param mod_fn: Path to tar file containing one saved Keras model per k-mer
    :type mod_fn: str
    :return: dict mapping k-mer string (file name without extension) to the
        loaded Keras model
    """
    with TemporaryDirectory() as td:
        with tarfile.open(mod_fn) as fh:
            # NOTE(review): extractall on an untrusted archive permits path
            # traversal; consider tarfile's 'data' extraction filter
            # (Python >= 3.12) if mod_fn can come from outside -- confirm.
            fh.extractall(td)
        out_dict = {}
        for mn in Path(td).iterdir():
            # mn.stem is the file name without its extension, i.e. the k-mer
            out_dict[mn.stem] = tf.keras.models.load_model(
                mn,
                custom_objects={'precision': precision,
                                'recall': recall,
                                'binary_accuracy': binary_accuracy})
    return out_dict
# def set_logfolder(brnn_object, param_base_name, parent_dir, epoch_index):
# """
# Create a folder to store tensorflow metrics for tensorboard and set it up for a specific session.
......
import tarfile
import numpy as np
import tensorflow as tf
......@@ -11,8 +10,15 @@ from nns.keras_metrics_from_logits import precision, recall, binary_accuracy
class InferenceModel(object):
"""Composite model of multiple k-mer recognising NNs."""
def __init__(self, mod_fn, batch_size=32):
    """Build a composite inference model from a saved model archive.

    :param mod_fn: Path to tar file that contains tarred model output by compile_model.py
    :type mod_fn: str
    :param batch_size: Number of inputs fed to the model per batch.
    """
    # Record batch size first, then unpack and load the per-k-mer models.
    self.batch_size = batch_size
    self.load_model(mod_fn)
......@@ -26,8 +32,12 @@ class InferenceModel(object):
# return True
# return np.any(self._model_dict[kmer].predict(read[-1])[:last_idx] > 0.0)
def load_model(self, mod_fn):
"""Load compiled model from path to saved precompiled model
:param mod_fn: Path to tarred file of compiled model
:return:
"""
with TemporaryDirectory() as td:
with tarfile.open(mod_fn) as fh:
fh.extractall(td)
......@@ -37,6 +47,9 @@ class InferenceModel(object):
custom_objects={'precision': precision,
'recall': recall,
'binary_accuracy': binary_accuracy})
# Dictionary with kmer string as key and keras.Sequential as value
self._model_dict = out_dict
# List of kmers that the InferenceModel contains models for
self.kmers = list(self._model_dict)
# Length of input that should be given to each individual model
self.input_length = self._model_dict[list(self._model_dict)[0]].layers[0].input_shape[1]
......@@ -15,6 +15,16 @@ from helper_functions import numeric_timestamp, safe_cursor
class ReadTable(object):
def __init__(self, reads_dir, table_fn, pos_reads_dir, kmers, input_length, batch_size=32):
"""Table that keeps track of all reads in directory and what
kmers they contain
:param reads_dir: Directory that contains fast5 reads on which to run inference
:param table_fn: File name of this table database
:param pos_reads_dir: # TODO don't quite get what this is
:param kmers: List of kmers for which to search
:param input_length: Length of signal that should be passed to model as input
:param batch_size: Batch size for the model
"""
Path(table_fn).unlink(missing_ok=True)
self.input_length = input_length
self.batch_size = batch_size
......@@ -38,6 +48,12 @@ class ReadTable(object):
return manager_process
def get_read_to_predict(self):
"""Get read and a kmer for which to scan the read file
:return: Tuple of: (path to read fast5 file,
Read object for inference (split to desired input length),
k-mer as string for which to scan)
"""
tup_list = safe_cursor(self.conn, "SELECT * FROM read_table ORDER BY ROWID ASC LIMIT 10")
if len(tup_list) == 0: return None, None, None
for tup in tup_list:
......
......@@ -19,6 +19,7 @@ def main(args):
nn_directory = f'{__location__}/../16S_db/'
else:
raise ValueError('Either provide --nn-directory or --target-16S')
# Find models for which k-mer is available
available_mod_kmers = {pth.name: str(pth) for pth in Path(nn_directory).iterdir() if pth.is_dir()}
......@@ -58,4 +59,5 @@ def main(args):
# tar the model
with tarfile.open(args.out_model, 'w') as fh:
for cl in Path(tdo).iterdir(): fh.add(cl, arcname=os.path.basename(cl))
for cl in Path(tdo).iterdir():
fh.add(cl, arcname=os.path.basename(cl))
......@@ -3,7 +3,6 @@ import numpy as np
from helper_functions import parse_output_path, parse_input_path
from inference.ReadTable import ReadTable
from helper_functions import load_model
from inference.InferenceModel import InferenceModel
def main(args):
......
......@@ -17,6 +17,10 @@ class NeuralNetwork(object):
:type target: str
:param kernel_size: Kernel size of CNN
:type kernel_size: int
:param max_sequence_length: Maximum length of read that can be used to
infer from. If read is longer than this length, don't use it.
Currently this is unimplemented
:type max_sequence_length: int
:param weights: Path to h5 file that contains model weights, optional
:param batch_size: Batch size to use during training
:param threshold: Assign label to TRUE if probability above this threshold
......@@ -57,7 +61,8 @@ class NeuralNetwork(object):
"""
if weights:
self.model = tf.keras.models.load_model(weights, custom_objects={
'precision': precision, 'recall': recall, 'binary_accuracy': binary_accuracy})
'precision': precision, 'recall': recall,
'binary_accuracy': binary_accuracy})
print('Successfully loaded weights')
return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment