Skip to content
Snippets Groups Projects
Commit fdfba74c authored by Carlos de Lannoy's avatar Carlos de Lannoy
Browse files

better timeseries plot

parent 338de017
Branches
Tags
No related merge requests found
...@@ -12,7 +12,7 @@ import tarfile ...@@ -12,7 +12,7 @@ import tarfile
from pathlib import Path from pathlib import Path
from os.path import basename, splitext from os.path import basename, splitext
from glob import glob from glob import glob
from math import nan from math import nan, log
from pathlib import Path from pathlib import Path
from statistics import median from statistics import median
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
...@@ -172,21 +172,28 @@ def safe_cursor(conn, comm, read_only=True, retries=1000): ...@@ -172,21 +172,28 @@ def safe_cursor(conn, comm, read_only=True, retries=1000):
raise TimeoutError('writing to sql table failed') raise TimeoutError('writing to sql table failed')
def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2): def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2, reflines=(0.90, 0.95, 0.99)):
nb_points = y_hat.size nb_points = y_hat.size
scaling = np.abs(raw.max() - raw.min()) / np.abs(posterior.max() - posterior.min())
posterior = posterior - posterior.min() + raw.min() posterior = posterior * scaling
posterior = posterior / posterior.max() * raw.max() translation = - posterior.min() + raw.min()
posterior = posterior + translation
# Main data source
source = ColumnDataSource(dict( reflines = np.array([log(rl / (1-rl)) for rl in reflines])
reflines = reflines * scaling + translation
source_dict = dict(
raw=raw[:nb_points], raw=raw[:nb_points],
posterior=posterior, posterior=posterior,
event=list(range(len(y_hat))), event=list(range(len(y_hat))),
cat=y_hat, cat=y_hat,
cat_height=np.repeat(np.mean(raw), len(y_hat)) cat_height=np.repeat(np.mean(raw), len(y_hat))
)) )
for ri, r in enumerate(reflines):
source_dict[f'r{ri}'] = np.repeat(r, len(posterior))
# Main data source
source = ColumnDataSource(source_dict)
# Base labels stuff # Base labels stuff
base_labels_condensed = [base_labels[0]] base_labels_condensed = [base_labels[0]]
...@@ -211,7 +218,7 @@ def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2): ...@@ -211,7 +218,7 @@ def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2):
y= np.repeat(raw.max(), len(bl_xcoords)) y= np.repeat(raw.max(), len(bl_xcoords))
)) ))
base_labels_labelset = LabelSet(x='x', y='y',text='base_labels', base_labels_labelset = LabelSet(x='x', y='y', text='base_labels',
source=bl_source, source=bl_source,
text_baseline='middle', text_baseline='middle',
angle=0.25*pi) angle=0.25*pi)
...@@ -241,6 +248,8 @@ def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2): ...@@ -241,6 +248,8 @@ def plot_timeseries(raw, base_labels, posterior, y_hat, start=0, nb_classes=2):
ts_plot.add_layout(base_labels_labelset) ts_plot.add_layout(base_labels_labelset)
ts_plot.line(x='event', y='raw', source=source) ts_plot.line(x='event', y='raw', source=source)
ts_plot.line(x='event', y='posterior', color='red', source=source) ts_plot.line(x='event', y='posterior', color='red', source=source)
for i in range(len(reflines)):
ts_plot.line(x='event', y=f'r{i}', color='grey', source=source)
ts_plot.plot_width = 1500 ts_plot.plot_width = 1500
ts_plot.plot_height = 500 ts_plot.plot_height = 500
ts_plot.x_range = Range1d(start, start+100) ts_plot.x_range = Range1d(start, start+100)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment