Skip to content
Snippets Groups Projects
Commit 7f87d75b authored by Noordijk, Ben
Browse files

Added file used for development

parent 2ecb99bb
No related branches found
No related tags found
1 merge request: !5 compare_accuracy.py now saves confusion matrices and can be called on the...
......@@ -129,7 +129,10 @@ class NeuralNetwork(object):
def predict(self, x, clean_signal=True, return_probs=False):
"""Given sequences input as x, predict if they contain target k-mer.
Assumes the sequence x is a read that has been normalised,
but not cut into smaller chunks
but not cut into smaller chunks.
Function is mainly written to be called from train_nn.py.
Not for final inference.
:param x: Squiggle as numeric representation
:type x: np.ndarray
......
import h5py
from db_building.TrainingRead import Read, TrainingRead
from pathlib import Path
from inference.InferenceModel import InferenceModel
from nns.Cnn_test import NeuralNetwork
import tensorflow.keras.backend as K
import numpy as np
import tensorflow as tf
import time
from nns.keras_metrics_from_logits import precision, recall, binary_accuracy
def main():
    """Compare the speed of ``cnn.predict(x)`` against a direct ``cnn(x)`` call.

    Development helper. It loads one "positive" and one "negative" fast5
    read from hard-coded paths, splits each into windows of
    ``input_length`` samples, loads a trained Keras model from a hard-coded
    path and runs both reads through it twice: once with ``.predict`` and
    once by calling the model directly. The elapsed time of each variant and
    the number of outputs greater than zero are printed to stdout.

    All paths are specific to the original author's environment.
    """
    input_length = 250
    # Only referenced by the (disabled) search snippet below.
    target_kmer = 'AGGAGAGT'

    # How the positive read was found: scan the test directory for the first
    # read whose condensed events contain ``target_kmer``. Kept disabled for
    # reference.
    #
    # for file in Path('/home/noord087/lustre_link/HoiCarlos/16Sreads_mockcommunity/demultiplexed_reads/files_for_initial_training/test').iterdir():
    #     print(f'Scanning {file}')
    #     with h5py.File(file, 'r') as h5_file:
    #         try:
    #             train_read = TrainingRead(h5_file, 'median',
    #                                       'Analyses/RawGenomeCorrected_000', 8)
    #             if [i for i in train_read.condensed_events
    #                     if i[0] == target_kmer]:
    #                 print(f"found in {file}")
    #                 break
    #         except KeyError as e:
    #             print('Got keyerror, continuing')
    #
    # return

    pos_read_path = Path(
        '/home/noord087/lustre_link/HoiCarlos/16Sreads_mockcommunity/demultiplexed_reads/files_for_initial_training/test/L0144169_20181212_FAK22428_MN19628_sequencing_run_16Srhizhome_2_99947_read_129882_ch_413_strand.fast5'
    )
    neg_read_path = Path(
        '/home/noord087/lustre_link/HoiCarlos/16Sreads_mockcommunity/demultiplexed_reads/files_for_initial_training/test/L0144169_20181212_FAK22428_MN19628_sequencing_run_16Srhizosphere_1_66037_read_11482_ch_104_strand.fast5')

    with h5py.File(pos_read_path, 'r') as f:
        pos_read = Read(f, 'median')
    with h5py.File(neg_read_path, 'r') as f:
        neg_read = Read(f, 'median')

    # NOTE(review): an unused local ``stride = input_length // 2`` used to be
    # computed here. The reads are still split with ``stride=input_length``
    # exactly as before; confirm whether a half-window stride was intended.
    split_pos_read = pos_read.get_split_raw_read(input_length,
                                                 stride=input_length)
    split_neg_read = neg_read.get_split_raw_read(input_length,
                                                 stride=input_length)

    def loss_fun(y_true, y_pred):
        """Placeholder loss so ``load_model`` can resolve ``'loss_fun'``.

        Picks column 50 (of 100) of ``y_pred`` along axis 1 and returns the
        binary cross-entropy from logits against ``y_true``. This script
        only runs inference, so the function is merely handed to
        ``custom_objects``.
        """
        msk = np.zeros(100, dtype=bool)
        msk[50] = True
        y_pred_single = tf.boolean_mask(y_pred, msk, axis=1)
        return K.binary_crossentropy(K.cast(y_true, K.floatx()),
                                     y_pred_single,
                                     from_logits=True)

    cnn = tf.keras.models.load_model(
        '/lustre/BIF/nobackup/noord087/mscthesis/baseless/baseless_250_width_uncentered_kmer/nns/AGGAGAGT/nn.h5',
        custom_objects={'precision': precision,
                        'recall': recall,
                        'binary_accuracy': binary_accuracy,
                        'loss_fun': loss_fun})

    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be distorted by system clock adjustments.
    # NOTE(review): this first timed section may also include one-time
    # setup cost of the first inference call — confirm before comparing.
    start = time.perf_counter()
    pos_posteriors = cnn.predict(split_pos_read)
    neg_posteriors = cnn.predict(split_neg_read)
    print(f'.predict method: {time.perf_counter() - start} seconds')

    start = time.perf_counter()
    pos_posteriors = cnn(split_pos_read)
    neg_posteriors = cnn(split_neg_read)
    print(f'Direct call: {time.perf_counter() - start} seconds')

    # Number of outputs above zero for each read (from the direct call).
    # Presumably the outputs are logits, given from_logits=True in the
    # loss — TODO confirm.
    print(f'{np.sum(pos_posteriors > 0)=}')
    print(f'{np.sum(neg_posteriors > 0)=}')


if __name__ == '__main__':
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment