Commit ddcc808f authored by Noordijk, Ben
Browse files

Added analysis for guppy accuracy over time

parent f86c2cf2
from compare_benchmark_performance.compare_accuracy_per_read \
import parse_sequencing_summary, get_performance_metrics_from_predictions
import argparse
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def calculate_guppy_performance_over_time(df: pd.DataFrame,
                                          read_interval: int = 10):
    """Calculate guppy performance for an increasing number of reads.

    The rows of ``df`` are shuffled, after which the metrics are computed
    on the first ``read_interval`` reads, then on the first
    ``2 * read_interval`` reads, and so on for every cutoff that is
    strictly smaller than ``len(df)``.

    :param df: dataframe with columns y_true and y_pred, where y_pred
        indicates if the read is assigned a species by guppy, and y_true
        represents the ground truth identity of the read
    :param read_interval: number of reads added between two consecutive
        evaluations
    :return: dataframe with one row per evaluated cutoff and the columns
        nr_reads, f1_score and acc_score; it has no rows when ``df`` holds
        no more than ``read_interval`` reads
    """
    # Shuffle dataframe
    df = df.sample(frac=1).reset_index(drop=True)
    result_dict = {"nr_reads": [],
                   "f1_score": [],
                   "acc_score": []}
    for read_cutoff in range(read_interval, len(df), read_interval):
        # Get subset of all reads
        completed_reads = df.iloc[:read_cutoff]
        y_true = completed_reads['y_true']
        y_pred = completed_reads['y_pred']
        # Calculate performance on this subset of reads
        acc, _, f1 = get_performance_metrics_from_predictions(y_pred, y_true)
        # Record the metrics for this cutoff; without this the returned
        # dataframe would always be empty.
        result_dict["nr_reads"].append(read_cutoff)
        result_dict["f1_score"].append(f1)
        result_dict["acc_score"].append(acc)
    return pd.DataFrame(result_dict)
def main(args):
    """Analyse guppy performance over an increasing number of reads.

    Either loads a previously calculated performance csv or calculates the
    performance for every sequencing summary found under
    ``guppy/*/fold?/`` in the benchmark folder, and then saves two
    lineplots of the f1 score against the number of reads.

    :param args: namespace with the attributes benchmark_path,
        ground_truth, input_performance_csv and out_dir
    :raises ValueError: if the performance has to be calculated but no
        sequencing summary yielded any result
    """
    # Workflow for analysis of guppy on bacterial classification
    pd.options.display.width = 0
    ground_truth = pd.read_csv(args.ground_truth)
    guppy_files = args.benchmark_path.glob(
        "guppy/*/fold?/sequencing_summary.txt")
    # Make sure the csv and the figures can be written
    args.out_dir.mkdir(parents=True, exist_ok=True)

    if args.input_performance_csv:
        # If precalculated csv is provided, use this.
        result_df = pd.read_csv(args.input_performance_csv)
    else:
        # No precalculated performance CSV was found: calculate performance
        # here
        result_df_list = []
        for guppy_file in guppy_files:
            # The glob pattern is guppy/<species>/fold?/sequencing_summary.txt
            # so the third-to-last path component is the species folder
            target_species = guppy_file.parts[-3].replace('_', ' ')
            if 'gingivalis' in target_species:
                # Was not present, so can be ignored
                continue
            # This is the bacteria-specific part. You should replace this
            # with something that works for you.
            # NOTE(review): the call in the reviewed source was cut off
            # after `ground_truth,`; any further arguments are not visible
            # and must be restored here if the helper needs them.
            df = parse_sequencing_summary(guppy_file, ground_truth)
            temp_result_df = calculate_guppy_performance_over_time(df)
            temp_result_df['species'] = target_species
            result_df_list.append(temp_result_df)

        if not result_df_list:
            # pd.concat would raise a less helpful ValueError on an empty list
            raise ValueError(
                'No guppy results to analyse: no usable file matched '
                f'"guppy/*/fold?/sequencing_summary.txt" in '
                f'{args.benchmark_path}')
        # Join results of all k-fold CVs for all species
        result_df = pd.concat(result_df_list, ignore_index=True)
        result_df.to_csv(args.out_dir / 'guppy_performance_over_time.csv')

    # Plot one lineplot where different species are different hues
    sns.lineplot(data=result_df, x='nr_reads', y='f1_score', hue='species')
    plt.savefig(args.out_dir / 'guppy_accuracy_lineplot_hue.svg')
    # Print one separate lineplot in a grid for each different species
    all_species = result_df.species.unique()
    sns.relplot(data=result_df, x='nr_reads', y='f1_score', col='species',
                kind='line', col_wrap=5, col_order=all_species)
    plt.savefig(args.out_dir / 'guppy_accuracy_lineplot_grid.svg')
if __name__ == '__main__':
    # Command line entry point: parse the arguments and run the analysis.
    parser = argparse.ArgumentParser(description="""Plot accuracy of
    guppy for different number of reads it has seen""")
    # NOTE(review): the option names below are reconstructed from the
    # attributes that main() reads (benchmark_path, ground_truth,
    # input_performance_csv, out_dir); confirm the exact flag spelling.
    parser.add_argument('--benchmark-path',
                        help='Path to folder that should contain a folder '
                             'called "guppy" which contains sequencing summary',
                        required=True, type=Path)
    parser.add_argument('--ground-truth',
                        help='Path to csv with ground truth labels. '
                             'It is output by',
                        required=True, type=Path)
    parser.add_argument('--input-performance-csv',
                        help='Path to csv that contains model performance as '
                             'output by earlier run of this script. '
                             'If provided, the script will not recalculate '
                             'all performance statistics manually but '
                             'read them from this csv.',
                        required=False, type=Path)
    parser.add_argument('--out-dir',
                        help='Directory in which to save the figures and'
                             ' csv with model performance',
                        required=True, type=Path)
    args = parser.parse_args()
    # Without this call the script would parse its arguments and exit
    main(args)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment