Skip to content
Snippets Groups Projects
Commit fdfba74c authored by Carlos de Lannoy's avatar Carlos de Lannoy
Browse files

better timeseries plot

parent 338de017
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ import tarfile
from pathlib import Path
from os.path import basename, splitext
from glob import glob
from math import nan
from math import nan, log
from pathlib import Path
from statistics import median
from tempfile import TemporaryDirectory
......@@ -172,21 +172,28 @@ def safe_cursor(conn, comm, read_only=True, retries=1000):
raise TimeoutError('writing to sql table failed')
def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2):
def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2, reflines=(0.90, 0.95, 0.99)):
nb_points = y_hat.size
posterior = posterior - posterior.min() + raw.min()
posterior = posterior / posterior.max() * raw.max()
# Main data source
source = ColumnDataSource(dict(
scaling = np.abs(raw.max() - raw.min()) / np.abs(posterior.max() - posterior.min())
posterior = posterior * scaling
translation = - posterior.min() + raw.min()
posterior = posterior + translation
reflines = np.array([log(rl / (1-rl)) for rl in reflines])
reflines = reflines * scaling + translation
source_dict = dict(
raw=raw[:nb_points],
posterior=posterior,
event=list(range(len(y_hat))),
cat=y_hat,
cat_height=np.repeat(np.mean(raw), len(y_hat))
))
)
for ri, r in enumerate(reflines):
source_dict[f'r{ri}'] = np.repeat(r, len(posterior))
# Main data source
source = ColumnDataSource(source_dict)
# Base labels stuff
base_labels_condensed = [base_labels[0]]
......@@ -211,7 +218,7 @@ def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2):
y= np.repeat(raw.max(), len(bl_xcoords))
))
base_labels_labelset = LabelSet(x='x', y='y',text='base_labels',
base_labels_labelset = LabelSet(x='x', y='y', text='base_labels',
source=bl_source,
text_baseline='middle',
angle=0.25*pi)
......@@ -241,6 +248,8 @@ def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2):
ts_plot.add_layout(base_labels_labelset)
ts_plot.line(x='event', y='raw', source=source)
ts_plot.line(x='event', y='posterior', color='red', source=source)
for i in range(len(reflines)):
ts_plot.line(x='event', y=f'r{i}', color='grey', source=source)
ts_plot.plot_width = 1500
ts_plot.plot_height = 500
ts_plot.x_range = Range1d(start, start+100)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment