Skip to content
Snippets Groups Projects
Commit 4dd89189 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Some debug scripts modified / added

parent e5cbd555
No related branches found
No related tags found
No related merge requests found
from nns.keras_metrics_from_logits import precision, recall, binary_accuracy
import pandas as pd
import random
from pathlib import Path
import h5py
from inference.InferenceModel import InferenceModel
......@@ -21,22 +22,33 @@ def loss_fun(y_true, y_pred): # just dummy to satisfy the stupid thing being th
return K.binary_crossentropy(K.cast(y_true, K.floatx()), y_pred_single,
from_logits=True)
def precision_recall_at_threshold(threshold, ax):
def precision_recall_at_threshold(threshold, target_model, target_name, ax):
"""
:param threshold: Logit threshold to use
:param target_model: Path to target model
:param target_name: Name of target species in ground_truth csv
:param ax: ax to plot on
:return:
"""
stride = 125
num_of_reads_to_check = 100
num_of_reads_to_check = 500
ground_truths_list = []
frac_of_kmers_found_list = []
model_path = '/lustre/BIF/nobackup/noord087/mscthesis/baseless/baseless_250_hyperparam/e_coli_from16s_compiled.tar'
mod = InferenceModel(model_path, threshold=threshold)
mod = InferenceModel(target_model, threshold=threshold)
print(f'{len(mod.kmers)=}')
ground_truth_df = pd.read_csv(
'/lustre/BIF/nobackup/noord087/HoiCarlos/16Sreads_mockcommunity/ground_truth_with_read_id_and_perc_id.csv')
start = time.time()
for counter, file in enumerate(
Path('/home/noord087/lustre_link/HoiCarlos/16Sreads_mockcommunity/demultiplexed_reads/files_for_initial_training/test').iterdir()):
file_info = ground_truth_df[ground_truth_df['file name'] == file.name]
read_is_target_species = 'Escherichia coli' in str(
read_is_target_species = target_name in str(
file_info['species'])
print(f"Target in read: {read_is_target_species}")
# if read_is_target_species and random.random() > 0.9:
# # Reduce to 1/10 of what is normal
# continue
with h5py.File(file, 'r') as f:
raw_read_split_reads = Read(f, 'median').get_split_raw_read(250,
stride)
......@@ -57,9 +69,14 @@ def precision_recall_at_threshold(threshold, ax):
def precision_recall_wrapper():
# # Now doing it for rhodobacter
# model_path = '/lustre/BIF/nobackup/noord087/mscthesis/baseless/baseless_250_hyperparam/r_sphaeroides_from16s_compiled.tar'
# target_name = 'Rhodobacter sphaeroides'
model_path = '/lustre/BIF/nobackup/noord087/mscthesis/baseless/baseless_250_hyperparam/e_coli_from16s_compiled.tar'
target_name = 'Escherichia coli'
fig, ax = plt.subplots()
for cutoff in np.arange(3, 7):
precision_recall_at_threshold(cutoff, ax=ax)
for cutoff in np.arange(4, 6):
precision_recall_at_threshold(cutoff, model_path, target_name, ax=ax)
plt.show()
......
import pandas as pd
from pathlib import Path
from inference.compile_model import reverse_complement
kmer_freqs_df = pd.read_parquet(Path(f'{__file__}/../data/ncbi_16S_bacteria_archaea_kmer_counts.parquet').resolve())
kmer_path = '/lustre/BIF/nobackup/noord087/mscthesis/baseless/baseless_250_hyperparam/kmers_in_dbs.txt'
ground_truth_df = pd.read_csv(
'/lustre/BIF/nobackup/noord087/HoiCarlos/16Sreads_mockcommunity/ground_truth_with_read_id_and_perc_id.csv')
with open(kmer_path, 'r') as f:
kmers = f.read().splitlines()
kmers.extend([reverse_complement(kmer) for kmer in kmers])
valid_kmers = [kmer for kmer in kmers if kmer in kmer_freqs_df.columns]
df_kmers_selected = kmer_freqs_df[valid_kmers]
# TODO convert between species in ground truth and species in kmers_freq_df
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment