Skip to content
Snippets Groups Projects
Commit 841ebff8 authored by Noordijk, Ben's avatar Noordijk, Ben
Browse files

Save confusion matrix plots

parent cb35d96e
No related branches found
No related tags found
No related merge requests found
......@@ -2,9 +2,10 @@ import argparse
from pathlib import Path
from itertools import chain
from multiprocessing import Pool
import pickle
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
......@@ -154,7 +155,7 @@ def parse_squigglenet_output(in_dir, ground_truth):
y_flipped = 1 - y_pred
f1 = f1_score(y_true, y_flipped)
accuracy = accuracy_score(y_true, y_flipped)
cm = confusion_matrix(y_true, y_pred)
cm = confusion_matrix(y_true, y_flipped)
return tool, target_species, fold, accuracy, f1, cm
......@@ -203,11 +204,15 @@ def main(args):
# Create confusion matrices
cm_series = df.groupby(['tool', 'species'])['confusion matrix'].aggregate(
np.sum)
# TODO save to file here
for tool, dataframe in cm_series.groupby(level=0):
print(tool)
print(dataframe)
with open(args.out_dir / f'confusion_matrices.pickle', 'wb') as f:
pickle.dump(cm_series, f)
# Confusion matrix out directory
cm_out_dir = args.out_dir / 'confusion_matrices'
cm_out_dir.mkdir(exist_ok=True)
for (tool, species), cm in cm_series.iteritems():
ConfusionMatrixDisplay(cm).plot()
plt.savefig(cm_out_dir / f'{tool}_{species}.png')
filename = 'all_algorithms_performance'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment