diff --git a/bin/caretta-app b/bin/caretta-app
index 3496041d0f3787d89851bd73e6fcb6a74f858428..56801f86c32654e40005f44361443aa5c2d0302e 100755
--- a/bin/caretta-app
+++ b/bin/caretta-app
@@ -1,508 +1,73 @@
 #!/usr/bin/env python3
 
-import base64
-import os
-import pickle
 from pathlib import Path
-from zipfile import ZipFile
 
 import dash
-import dash_bio as dashbio
-import dash_core_components as dcc
-import dash_html_components as html
-import fire
-import flask
-import numpy as np
+import typer
 from cryptography.fernet import Fernet
 
-from caretta import helper, msa_numba
-from caretta.pfam import PdbEntry
+from caretta import helper, multiple_alignment
+from caretta.app import app_helper, app_layout, app_callbacks
 
-key = Fernet.generate_key()
-suite = Fernet(key)
+# Fernet encrypts/decrypts (it does not compress); a fresh key is generated on every start and SUITE is passed to the layout and callbacks
+KEY = Fernet.generate_key()
+SUITE = Fernet(KEY)
 
-if not Path("static").exists():
-    Path("static").mkdir()
+STATIC = Path("static")  # relative path, so it resolves against the current working directory
+if not STATIC.exists():  # NOTE(review): STATIC.mkdir(exist_ok=True) would make this check unnecessary
+    STATIC.mkdir()
 
-external_stylesheets = ["https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.css"]
+ADDITION = ""  # NOTE(review): not referenced anywhere else in this script — confirm it is still needed
 
-app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
-
-
-def feature_heatmap(data, zeros=False):
-    if zeros:
-        length = 2
-        z = np.zeros((length, length))
-    else:
-        keys = list(data.keys())
-        length = len(data[keys[0]])
-        z = np.zeros((len(data), length))
-        for i in range(len(data)):
-            for j in range(length):
-                z[i, j] = data[keys[i]][j]
-    return dict(data=[dict(z=z, type="heatmap", showscale=False)], layout=dict(margin=dict(l=25, r=25, t=25, b=25)))
-
-
-def feature_line(features, alignment):
-    length = len(features[list(features.keys())[0]])
-    z = np.zeros((len(features), length))
-    keys = list(features.keys())
-    for i in range(len(features)):
-        for j in range(length):
-            if alignment[keys[i]][j] is not "-":
-                z[i, j] = features[keys[i]][j]
-            else:
-                z[i, j] = np.NaN
-    y = np.array([np.nanmean(z[:, x]) for x in range(z.shape[1])])
-    y_se = np.array([np.nanstd(z[:, x]) / np.sqrt(z.shape[1]) for x in range(z.shape[1])])
-
-    data = [dict(y=list(y + y_se) + list(y - y_se)[::-1], x=list(range(length)) + list(range(length))[::-1],
-                 fillcolor="lightblue", fill="toself", type="scatter", mode="lines", name="Standard error",
-                 line=dict(color='lightblue')),
-            dict(y=y, x=np.arange(length), type="scatter", mode="lines", name="Mean",
-                 line=dict(color='blue'))]
-    return dict(data=data, layout=dict(legend=dict(x=0.5, y=1.2), margin=dict(l=25, r=25, t=25, b=25)))
-
-
-def structure_plot(coord_dict):
-    data = []
-    for k, v in coord_dict.items():
-        x, y, z = v[:, 0], v[:, 1], v[:, 2]
-        data.append(dict(
-            x=x,
-            y=y,
-            z=z,
-            mode='lines',
-            type='scatter3d',
-            text=None,
-            name=str(k),
-            line=dict(
-                width=3,
-                opacity=0.8)))
-    layout = dict(margin=dict(l=20, r=20, t=20, b=20), clickmode='event+select',
-                  scene=dict(xaxis=dict(visible=False, showgrid=False, showline=False),
-                             yaxis=dict(visible=False, showgrid=False, showline=False),
-                             zaxis=dict(visible=False, showgrid=False, showline=False)))
-    return dict(data=data, layout=layout)
-
-
-def check_gap(sequences, i):
-    for seq in sequences:
-        if seq[i] == "-":
-            return True
-    return False
-
-
-def get_feature_z(features, alignments):
-    core_indices = []
-    sequences = list(alignments.values())
-    for i in range(len(sequences[0])):
-        if not check_gap(sequences, i):
-            core_indices.append(i)
-        else:
-            continue
-    return {x: features[x][np.arange(len(sequences[0]))] for x in features}
-
-
-def write_as_csv(feature_dict, file_name):
-    with open(file_name, "w") as f:
-        for protein_name, features in feature_dict.items():
-            f.write(";".join([protein_name] + [str(x) for x in list(features)]) + "\n")
-
-
-def write_as_csv_all_features(feature_dict, file_name):
-    with open(file_name, "w") as f:
-        for protein_name, feature_dict in feature_dict.items():
-            for feature_name, feature_values in feature_dict.items():
-                f.write(";".join([protein_name, feature_name] + [str(x) for x in list(feature_values)]) + "\n")
-
-
-box_style = {"box-shadow": "1px 3px 20px -4px rgba(0,0,0,0.75)",
-             "border-radius": "5px", "background-color": "#f9f7f7"}
-
-box_style_lg = {"top-margin": 25,
-                "border-style": "solid",
-                "border-color": "rgb(187, 187, 187)",
-                "border-width": "1px",
-                "border-radius": "5px",
-                "background-color": "#edfdff"}
-
-box_style_lr = {"top-margin": 25,
-                "border-style": "solid",
-                "border-color": "rgb(187, 187, 187)",
-                "border-width": "1px",
-                "border-radius": "5px",
-                "background-color": "#ffbaba"}
-
-
-def compress_object(raw_object):
-    return base64.b64encode(suite.encrypt(pickle.dumps(raw_object, protocol=4))).decode("utf-8")
-
-
-def decompress_object(compressed_object):
-    return pickle.loads(suite.decrypt(base64.b64decode(compressed_object)))
-
-
-def protein_to_aln_index(protein_index, aln_seq):
-    n = 0
-    for i in range(len(aln_seq)):
-        if protein_index == n:
-            return i
-        elif aln_seq[i] == "-":
-            pass
-        else:
-            n += 1
-
-
-def aln_index_to_protein(alignment_index, alignment):
-    res = dict()
-    for k, v in alignment.items():
-        if v[alignment_index] == "-":
-            res[k] = None
-        else:
-            res[k] = alignment_index - v[:alignment_index].count("-")
-    return res
+external_stylesheets = [
+    "https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.css"
+]
 
+# server = Flask(__name__)
+app = dash.Dash(
+    __name__,
+    external_stylesheets=external_stylesheets,
+    url_base_pathname="/",
+)
 
 introduction_text = """Caretta generates multiple structure alignments for a set of input proteins and displays the alignment, the superposed proteins,
 and aligned structural features. All the generated data can further be exported for downstream use. 
 If you have to align more than 100 proteins your browser may lag, please use the command-line tool instead 
 (See https://git.wageningenur.nl/durai001/caretta for instructions)."""
 
-pdb_selection_text = dcc.Markdown("""Possible input options are: 
+input_text = "Enter a folder with PDB files and click on Load Structures"
+placeholder_text = "PDB folder"
+selection_text = """Possible input options are: 
 * Path to a folder containing files
 * List of files (one on each line)
 * List of PDB IDs 
-""")
-structure_alignment_text = """Click on a residue to see its position on the feature alignment in the next section."""
-feature_alignment_text = """Click on a position in the feature alignment to see the corresponding residues in the previous section."""
-
-app.layout = html.Div(children=[html.Div(html.Div([html.H1("Caretta",
-                                                           style={"text-align": "center"}),
-                                                   html.H3("a multiple protein structure alignment and feature extraction suite",
-                                                           style={"text-align": "center"}),
-                                                   html.P(introduction_text, style={"text-align": "left"})], className="row"),
-                                         className="container"),
-                                html.Div([html.Br(),
-                                          html.P(children="", id="feature-data",
-                                                 style={"display": "none"}),
-                                          html.P(children=compress_object(0), id="button1",
-                                                 style={"display": "none"}),
-                                          html.P(children=compress_object(0), id="button2",
-                                                 style={"display": "none"}),
-                                          html.P(children="", id="alignment-data",
-                                                 style={"display": "none"}),
-                                          html.P(children="", id="msa-class",
-                                                 style={"display": "none"}),
-                                          html.Div([html.H3("Choose Structures", className="row", style={"text-align": "center"}),
-                                                    html.P("Input PDB files and click on Load Structures.", className="row"),
-                                                    html.Div([
-                                                        html.Div(
-                                                            dcc.Textarea(placeholder="PDB files", value="", id="custom-folder", required=True),
-                                                            className="four columns"),
-                                                        html.P(pdb_selection_text, className="four columns"),
-                                                        html.Button("(Re)load Structures", className="four columns", id="load-button")
-                                                    ], className="row"),
-                                                    html.Div([html.Div(dcc.Dropdown(placeholder="Gap open penalty (1.0)",
-                                                                                    options=[{"label": np.round(x, decimals=2), "value": x} for x in
-                                                                                             np.arange(0, 5, 0.1)],
-                                                                                    id="gap-open"), className="four columns"),
-                                                              html.Div(dcc.Dropdown(placeholder="Gap extend penalty (0.01)",
-                                                                                    options=[{"label": np.round(x, decimals=3), "value": x} for x in
-                                                                                             np.arange(0, 1, 0.002)],
-                                                                                    id="gap-extend"),
-                                                                       className="four columns")], className="row"),
-                                                    html.Br(),
-                                                    html.Div(html.Button("Align Structures", className="twelve columns", id="align"),
-                                                             className="row"),
-                                                    dcc.Loading(id="loading-1", children=[html.Div(id="output-1", style={"text-align": "center"})],
-                                                                type="default")
-                                                    ],
-                                                   className="container"),
-                                          html.Br()], className="container", style=box_style),
-                                html.Br(),
-                                html.Div(children=[html.Br(),
-                                                   html.H3("Sequence alignment", className="row", style={"text-align": "center"}),
-                                                   html.Div(
-                                                       html.P("", className="row"),
-                                                       className="container"),
-                                                   html.Div([html.Button("Download alignment", className="row", id="alignment-download"),
-                                                             html.Div(children="", className="row", id="fasta-download-link")],
-                                                            className="container"),
-                                                   html.Div(html.P(id="alignment", className="twelve columns"), className="row")],
-                                         className="container", style=box_style),
-                                html.Br(),
-                                html.Div([html.Br(),
-                                          html.H3("Structural alignment", className="row", style={"text-align": "center"}),
-                                          html.Div(html.P(structure_alignment_text,
-                                                          className="row"), className="container"),
-                                          html.Div([html.Button("Download PDB", className="row", id="pdb-download", style={"align": "center"}),
-                                                    html.Div(children="", className="row", id="pdb-download-link")],
-                                                   className="container"),
-                                          html.Div(children=dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True), id="scatter3d"),
-                                                   className="row", id="aligned-proteins"), html.Br()],
-                                         className="container", style=box_style),
-                                html.Br(),
-                                html.Div(
-                                    [html.Br(), html.Div([html.Div([html.H3("Feature alignment", className="row", style={"text-align": "center"}),
-                                                                    html.P(
-                                                                        feature_alignment_text,
-                                                                        className="row"),
-                                                                    dcc.Dropdown(placeholder="Choose a feature", id="feature-selection",
-                                                                                 className="six columns"),
-                                                                    html.Button("Display feature alignment", id="feature-button",
-                                                                                className="six columns")], className="row"),
-                                                          html.Div([html.Div([html.Button("Export feature", id="export"),
-                                                                              html.Button("Export all features", id="export-all")], id="exporter"),
-                                                                    html.Div(html.P(""), id="link-field"),
-                                                                    html.Br()])], className="container"),
-
-                                     html.Div(
-                                         html.Div(dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True), id="feature-line"),
-                                                  id="feature-plot1"),
-                                         className="row"),
-                                     html.Div(
-                                         html.Div(dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True), id="heatmap"), id="feature-plot2"),
-                                         className="row")],
-                                    className="container", style=box_style),
-                                html.Br(), html.Br(), html.Div(id="testi")])
-
-
-def to_fasta_str(alignment):
-    res = []
-    for k, v in alignment.items():
-        res.append(f">{k}")
-        res.append(v)
-    return "\n".join(res)
-
-
-@app.callback(dash.dependencies.Output('fasta-download-link', 'children'),
-              [dash.dependencies.Input('alignment-download', 'n_clicks')],
-              [dash.dependencies.State("alignment-data", "children")])
-def download_alignment(clicked, data):
-    if clicked and data:
-        alignment = decompress_object(data)
-        if not alignment:
-            return ""
-        fasta = to_fasta_str(alignment)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.fasta"
-        with open(fname, "w") as f:
-            f.write(fasta)
-        return html.A(f"Download %s here" % ("alignment" + ".fasta"), href="/%s" % fname)
-    else:
-        return ""
-
-
-@app.callback(dash.dependencies.Output('pdb-download-link', 'children'),
-              [dash.dependencies.Input('pdb-download', 'n_clicks')],
-              [dash.dependencies.State("alignment-data", "children"),
-               dash.dependencies.State("msa-class", "children")])
-def download_pdb(clicked, msa_class):
-    if clicked and msa_class:
-        msa_class = decompress_object(msa_class)
-        if not msa_class:
-            return ""
-        fnum = np.random.randint(0, 1000000000)
-        msa_class.output_files.pdb_folder = Path(f"static/{fnum}_pdb")
-        msa_class.write_files(write_pdb=True, write_fasta=False, write_class=False, write_features=False)
-        fname = f"static/{fnum}_pdb.zip"
-        pdb_zip_file = ZipFile(fname, mode="w")
-        for pdb_file in Path(msa_class.output_files.pdb_folder).glob("*.pdb"):
-            pdb_zip_file.write(str(pdb_file))
-        return html.A(f"Download %s here" % ("pdbs" + ".zip"), href="/%s" % fname)
-    else:
-        return ""
-
-
-@app.callback([dash.dependencies.Output("output-1", "children"),
-               dash.dependencies.Output("alignment", "children"),
-               dash.dependencies.Output("aligned-proteins", "children"),
-               dash.dependencies.Output("feature-data", "children"),
-               dash.dependencies.Output("feature-selection", "options"),
-               dash.dependencies.Output("alignment-data", "children"),
-               dash.dependencies.Output("msa-class", "children")],
-              [dash.dependencies.Input("align", "n_clicks")],
-              [dash.dependencies.State("custom-folder", "value"),
-               dash.dependencies.State("gap-open", "value"),
-               dash.dependencies.State("gap-extend", "value")])
-def align_structures(clicked, input_pdb, gap_open, gap_extend):
-    if clicked and input_pdb:
-        pdb_entries = [PdbEntry.from_user_input(f) for f in msa_numba.parse_pdb_files(input_pdb, "static/cleaned_pdb")]
-        if not gap_open:
-            gap_open = 1
-        if not gap_extend:
-            gap_extend = 0.01
-        msa_class = msa_numba.StructureMultiple.from_pdb_files([p.get_pdb()[1] for p in pdb_entries])
-        alignment = msa_class.align(gap_open_penalty=gap_open, gap_extend_penalty=gap_extend)
-        msa_class.superpose(alignment)
-        fasta = to_fasta_str(alignment)
-        component = dashbio.AlignmentChart(
-            id='my-dashbio-alignmentchart',
-            data=fasta, showconsensus=False, showconservation=False,
-            overview=None, height=300,
-            colorscale="hydrophobicity"
-        )
-        features = {s.name: s.features for s in msa_class.structures}
-        return "", component, dcc.Graph(figure=structure_plot({s.name: s.coords for s in msa_class.structures}),
-                                        id="scatter3d"), compress_object(
-            features), [{"label": x, "value": x} for x in features[list(features.keys())[0]]], compress_object(alignment), compress_object(msa_class)
-    else:
-        return "", "", "", compress_object(np.zeros(0)), [{"label": "no alignment present", "value": "no alignment"}], compress_object(np.zeros(0)), \
-               compress_object(np.zeros(0))
-
-
-@app.callback([dash.dependencies.Output("feature-plot1", "children"),
-               dash.dependencies.Output("feature-plot2", "children")],
-              [dash.dependencies.Input("feature-button", "n_clicks")],
-              [dash.dependencies.State("feature-selection", "value"),
-               dash.dependencies.State("feature-data", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def display_feature(clicked, chosen_feature, feature_dict, aln):
-    if clicked and chosen_feature and feature_dict:
-        alignment = decompress_object(aln)
-        feature_dict = decompress_object(feature_dict)
-        chosen_feature_dict = {x: feature_dict[x][chosen_feature] for x in feature_dict}
-        aln_np = {k: helper.aligned_string_to_array(alignment[k]) for k in alignment}
-        chosen_feature_aln_dict = {x: helper.get_aligned_string_data(aln_np[x], chosen_feature_dict[x]) for x in feature_dict.keys()}
-        z = get_feature_z(chosen_feature_aln_dict, alignment)
-        component1 = dcc.Graph(figure=feature_heatmap(z), id="heatmap")
-        component2 = dcc.Graph(figure=feature_line(z, alignment), id="feature-line")
-        return component2, component1
-    else:
-        return dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                         id="feature-line", style={"display": "none"}), dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                                                  id="heatmap", style={"display": "none"})
-
-
-@app.callback([dash.dependencies.Output("link-field", "children"),
-               dash.dependencies.Output("exporter", "children")],
-              [dash.dependencies.Input("export", "n_clicks"),
-               dash.dependencies.Input("export-all", "n_clicks")],
-              [dash.dependencies.State("feature-selection", "value"),
-               dash.dependencies.State("feature-data", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def write_output(clicked, clicked_all, chosen_feature, feature_dict, aln):
-    if (clicked and chosen_feature and feature_dict and aln) and not clicked_all:
-        alignment = decompress_object(aln)
-        feature_dict = decompress_object(feature_dict)
-        chosen_feature_dict = {x: feature_dict[x][chosen_feature] for x in feature_dict}
-        aln_np = {k: helper.aligned_string_to_array(alignment[k]) for k in alignment}
-        chosen_feature_aln_dict = {x: helper.get_aligned_string_data(aln_np[x], chosen_feature_dict[x]) for x in
-                                   feature_dict.keys()}
-        z = get_feature_z(chosen_feature_aln_dict, alignment)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.csv"
-        write_as_csv(z, fname)
-        return html.A(f"Download %s here" % (str(fnum) + ".csv"), href="/%s" % fname), [html.Button("Export feature", id="export"),
-                                                                                        html.Button("Export all features", id="export-all")]
-    elif (clicked_all and feature_dict and aln) and not clicked:
-        feature_dict = decompress_object(feature_dict)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.csv"
-        write_as_csv_all_features(feature_dict, fname)
-        return html.A(f"Download %s here" % (str(fnum) + ".csv"), href="/%s" % fname), [html.Button("Export feature", id="export"),
-                                                                                        html.Button("Export all features", id="export-all")]
-    else:
-        return "", [html.Button("Export feature", id="export"),
-                    html.Button("Export all features", id="export-all")]
+"""
 
+app.layout = app_layout.get_layout(
+    introduction_text, input_text, placeholder_text, selection_text, SUITE
+)
 
-@app.server.route('/static/<path:path>')
-def download_file(path):
-    root_dir = os.getcwd()
-    return flask.send_from_directory(
-        os.path.join(root_dir, 'static'), path)
 
+def get_pdb_entries_from_folder(folder):  # handed to app_callbacks.register_callbacks below
+    return [  # one app_helper.PdbEntry per item returned by helper.parse_pdb_files(folder)
+        app_helper.PdbEntry.from_user_input(f) for f in helper.parse_pdb_files(folder)
+    ]
 
-@app.callback([dash.dependencies.Output("feature-line", "figure"),
-               dash.dependencies.Output("scatter3d", "figure"),
-               dash.dependencies.Output("button1", "children"),
-               dash.dependencies.Output("button2", "children")],
-              [dash.dependencies.Input("scatter3d", "clickData"),
-               dash.dependencies.Input("feature-line", "clickData")],
-              [dash.dependencies.State("feature-line", "figure"),
-               dash.dependencies.State("scatter3d", "figure"),
-               dash.dependencies.State("button1", "children"),
-               dash.dependencies.State("button2", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def update_features(clickdata_3d, clickdata_feature, feature_data, scatter3d_data, button1, button2, alignment_data):
-    if feature_data and scatter3d_data and clickdata_feature and compress_object(
-            (clickdata_feature["points"][0]["pointNumber"], clickdata_feature["points"][0]["curveNumber"])) != button1:
-        alignment = decompress_object(alignment_data)
-        number_of_structures = len(alignment)
-        clickdata = clickdata_feature
-        idx = clickdata["points"][0]["pointNumber"]
-        protein_index = clickdata["points"][0]["curveNumber"]
-        aln_positions = aln_index_to_protein(idx, alignment)
-        button1 = compress_object((idx, protein_index))
-        # x, y = clickdata["points"][0]["x"], clickdata["points"][0]["y"]
-        try:
-            maxim, minim = np.max(feature_data["data"][0]["y"]), np.min(feature_data["data"][0]["y"])
-        except KeyError:
-            return feature_data, scatter3d_data, button1, button2
-        if len(feature_data["data"]) > 2:
-            feature_data["data"] = feature_data["data"][:-1]
-        feature_data["data"] += [dict(y=[minim, maxim], x=[idx, idx], type="scatter", mode="lines",
-                                      name="selected residue")]
-        if len(scatter3d_data["data"]) > number_of_structures:
-            scatter3d_data["data"] = scatter3d_data["data"][:-1]
-        to_add = []
-        for i in range(len(scatter3d_data["data"])):
-            d = scatter3d_data["data"][i]
-            k = d["name"]
-            p = aln_positions[k]
-            if p is not None:
-                x, y, z = d["x"][p], d["y"][p], d["z"][p]
-                to_add.append((x, y, z))
-            else:
-                continue
-        scatter3d_data["data"] += [dict(x=[x[0] for x in to_add],
-                                        y=[y[1] for y in to_add],
-                                        z=[z[2] for z in to_add], type="scatter3d", mode="markers",
-                                        name="selected residues")]
-        return feature_data, scatter3d_data, button1, button2
-    if feature_data and scatter3d_data and clickdata_3d and compress_object(
-            (clickdata_3d["points"][0]["pointNumber"], clickdata_3d["points"][0]["curveNumber"])) != button2:
-        alignment = decompress_object(alignment_data)
-        number_of_structures = len(alignment)
-        clickdata = clickdata_3d
-        idx = clickdata["points"][0]["pointNumber"]
-        protein_index = clickdata["points"][0]["curveNumber"]
-        button2 = compress_object((idx, protein_index))
-        gapped_sequence = list(alignment.values())[protein_index]
-        aln_index = protein_to_aln_index(idx, gapped_sequence)
-        x, y, z = clickdata["points"][0]["x"], clickdata["points"][0]["y"], clickdata["points"][0]["z"]
-        try:
-            maxim, minim = np.max(feature_data["data"][0]["y"]), np.min(feature_data["data"][0]["y"])
-        except KeyError:
-            return feature_data, scatter3d_data, button1, button2
-        if len(feature_data["data"]) > 2:
-            feature_data["data"] = feature_data["data"][:-1]
-        feature_data["data"] += [dict(y=[minim, maxim], x=[aln_index, aln_index], type="scatter", mode="lines",
-                                      name="selected_residue")]
-        if len(scatter3d_data["data"]) > number_of_structures:
-            scatter3d_data["data"] = scatter3d_data["data"][:-1]
-        scatter3d_data["data"] += [dict(y=[y], x=[x], z=[z], type="scatter3d", mode="markers",
-                                        name="selected residue")]
-        return feature_data, scatter3d_data, button1, button2
 
-    elif feature_data and scatter3d_data:
-        return feature_data, scatter3d_data, button1, button2
+app_callbacks.register_callbacks(app, get_pdb_entries_from_folder, SUITE)
 
 
-def run_server(host="0.0.0.0", port=8888):
+def run(  # CLI entry point, invoked through typer.run(run)
+    host: str = typer.Argument("0.0.0.0", help="host IP to serve the app"),  # positional CLI argument with a default
+    port: int = typer.Argument(8888, help="port"),  # positional CLI argument with a default
+):
     """
     caretta-app is the GUI of caretta, capable of aligning and visualising multiple protein structures
     and allowing extraction of aligned features such as bond angles, residue depths and fluctuations.
-    ----------
-    host
-        host ip (string)
-    port
-        port
     """
+    multiple_alignment.trigger_numba_compilation()  # runs before the server starts; presumably warms up numba JIT — confirm
     app.run_server(host=host, port=port)
 
 
-if __name__ == '__main__':
-    fire.Fire(run_server)
+if __name__ == "__main__":
+    typer.run(run)
diff --git a/bin/caretta-app-demo b/bin/caretta-app-demo
old mode 100644
new mode 100755
index 60b2eaaf54ed49cee96f1f9fd07ee5640a38d753..77345d7164242356d2ebdf1a59f8b1fc6739825a
--- a/bin/caretta-app-demo
+++ b/bin/caretta-app-demo
@@ -1,556 +1,69 @@
 #!/usr/bin/env python3
 
-import base64
-import os
-import pickle
 from pathlib import Path
-from zipfile import ZipFile
 
 import dash
-import dash_bio as dashbio
-import dash_core_components as dcc
-import dash_html_components as html
-import fire
-import flask
-import numpy as np
+import typer
 from cryptography.fernet import Fernet
 
-from caretta import helper
-from caretta.pfam import PfamToPDB
+from caretta import multiple_alignment
+from caretta.app import app_helper, app_layout, app_callbacks
+from flask import Flask, send_from_directory, abort
 
-key = Fernet.generate_key()
-suite = Fernet(key)
+# Fernet key and cipher suite (symmetric encryption, not compression); SUITE is passed to the layout and callbacks
+KEY = Fernet.generate_key()
+SUITE = Fernet(KEY)
 
-if not Path("static").exists():
-    Path("static").mkdir()
+STATIC = Path("static")
+if not STATIC.exists():
+    STATIC.mkdir()
 
-external_stylesheets = ["https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.css"]
+external_stylesheets = [
+    "https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.css"
+]
 
-app = dash.Dash(__name__, external_stylesheets=external_stylesheets, url_base_pathname="/caretta/")
-
-
-def feature_heatmap(data, zeros=False):
-    if zeros:
-        length = 2
-        z = np.zeros((length, length))
-    else:
-        keys = list(data.keys())
-        length = len(data[keys[0]])
-        z = np.zeros((len(data), length))
-        for i in range(len(data)):
-            for j in range(length):
-                z[i, j] = data[keys[i]][j]
-    return dict(data=[dict(z=z, type="heatmap", showscale=False)], layout=dict(margin=dict(l=25, r=25, t=25, b=25)))
-
-
-def feature_line(features, alignment):
-    length = len(features[list(features.keys())[0]])
-    z = np.zeros((len(features), length))
-    keys = list(features.keys())
-    for i in range(len(features)):
-        for j in range(length):
-            if alignment[keys[i]][j] is not "-":
-                z[i, j] = features[keys[i]][j]
-            else:
-                z[i, j] = np.NaN
-    y = np.array([np.nanmean(z[:, x]) for x in range(z.shape[1])])
-    y_se = np.array([np.nanstd(z[:, x]) / np.sqrt(z.shape[1]) for x in range(z.shape[1])])
-
-    data = [dict(y=list(y + y_se) + list(y - y_se)[::-1], x=list(range(length)) + list(range(length))[::-1],
-                 fillcolor="lightblue", fill="toself", type="scatter", mode="lines", name="Standard error",
-                 line=dict(color='lightblue')),
-            dict(y=y, x=np.arange(length), type="scatter", mode="lines", name="Mean",
-                 line=dict(color='blue'))]
-    return dict(data=data, layout=dict(legend=dict(x=0.5, y=1.2), margin=dict(l=25, r=25, t=25, b=25)))
-
-
-def structure_plot(coord_dict):
-    data = []
-    for k, v in coord_dict.items():
-        x, y, z = v[:, 0], v[:, 1], v[:, 2]
-        data.append(dict(
-            x=x,
-            y=y,
-            z=z,
-            mode='lines',
-            type='scatter3d',
-            text=None,
-            name=str(k),
-            line=dict(
-                width=3,
-                opacity=0.8)))
-    layout = dict(margin=dict(l=20, r=20, t=20, b=20), clickmode='event+select',
-                  scene=dict(xaxis=dict(visible=False, showgrid=False, showline=False),
-                             yaxis=dict(visible=False, showgrid=False, showline=False),
-                             zaxis=dict(visible=False, showgrid=False, showline=False)))
-    return dict(data=data, layout=layout)
-
-
-def check_gap(sequences, i):
-    for seq in sequences:
-        if seq[i] == "-":
-            return True
-    return False
-
-
-def get_feature_z(features, alignments):
-    core_indices = []
-    sequences = list(alignments.values())
-    for i in range(len(sequences[0])):
-        if not check_gap(sequences, i):
-            core_indices.append(i)
-        else:
-            continue
-    return {x: features[x][np.arange(len(sequences[0]))] for x in features}
-
-
-def write_as_csv(feature_dict, file_name):
-    with open(file_name, "w") as f:
-        for protein_name, features in feature_dict.items():
-            f.write(";".join([protein_name] + [str(x) for x in list(features)]) + "\n")
-
-
-def write_as_csv_all_features(feature_dict, file_name):
-    with open(file_name, "w") as f:
-        for protein_name, feature_dict in feature_dict.items():
-            for feature_name, feature_values in feature_dict.items():
-                f.write(";".join([protein_name, feature_name] + [str(x) for x in list(feature_values)]) + "\n")
-
-
-box_style = {"box-shadow": "1px 3px 20px -4px rgba(0,0,0,0.75)",
-             "border-radius": "5px", "background-color": "#f9f7f7"}
-
-box_style_lg = {"top-margin": 25,
-                "border-style": "solid",
-                "border-color": "rgb(187, 187, 187)",
-                "border-width": "1px",
-                "border-radius": "5px",
-                "background-color": "#edfdff"}
-
-box_style_lr = {"top-margin": 25,
-                "border-style": "solid",
-                "border-color": "rgb(187, 187, 187)",
-                "border-width": "1px",
-                "border-radius": "5px",
-                "background-color": "#ffbaba"}
-
-
-def compress_object(raw_object):
-    return base64.b64encode(suite.encrypt(pickle.dumps(raw_object, protocol=4))).decode("utf-8")
-
-
-def decompress_object(compressed_object):
-    return pickle.loads(suite.decrypt(base64.b64decode(compressed_object)))
-
-
-def protein_to_aln_index(protein_index, aln_seq):
-    n = 0
-    for i in range(len(aln_seq)):
-        if protein_index == n:
-            return i
-        elif aln_seq[i] == "-":
-            pass
-        else:
-            n += 1
-
-
-def aln_index_to_protein(alignment_index, alignment):
-    res = dict()
-    for k, v in alignment.items():
-        if v[alignment_index] == "-":
-            res[k] = None
-        else:
-            res[k] = alignment_index - v[:alignment_index].count("-")
-    return res
-
-
-pfam_start = PfamToPDB(from_file=False, limit=100)
-pfam_start = list(pfam_start.pfam_to_pdb_ids.keys())
-pfam_start = [{"label": x, "value": x} for x in pfam_start]
-
-introduction_text_web = dcc.Markdown("""This is a demo webserver for *caretta*. It generates multiple structure alignments for proteins from a selected 
-Pfam domain and displays the alignment, the superposed proteins, and aligned structural features. 
+app = dash.Dash(
+    __name__,
+    external_stylesheets=external_stylesheets,
+    url_base_pathname="/caretta/",
+)
 
+introduction_text = """This is a demo webserver for *caretta*. It generates multiple structure alignments for proteins from a selected 
+Pfam domain and displays the alignment, the superposed proteins, and aligned structural features.\n\n
 While the server is restricted to a maximum of 50 proteins and 100 Pfam domains, you can download this GUI and command-line tool from 
-[the git repository](https://git.wageningenur.nl/durai001/caretta) and run it locally to use it on as many proteins as you'd like. 
-
-All the generated data can further be exported for downstream use.""")
-
-pfam_selection_text_web = """Choose a Pfam ID and click on Load Structures. 
-Then use the dropdown box to select which PDB IDs to align."""
-structure_alignment_text = """Click on a residue to see its position on the feature alignment in the next section."""
-feature_alignment_text = """Click on a position in the feature alignment to see the corresponding residues in the previous section."""
-
-app.layout = html.Div(children=[html.Div(html.Div([html.H1("Caretta",
-                                                           style={"text-align": "center"}),
-                                                   html.H3(
-                                                       "a multiple protein structure alignment and feature extraction suite",
-                                                       style={"text-align": "center"}),
-                                                   html.P(introduction_text_web, style={"text-align": "left"})],
-                                                  className="row"),
-                                         className="container"),
-                                html.Div([html.Br(), html.P(children=compress_object(PfamToPDB(from_file=False,
-                                                                                               limit=100)),
-                                                            id="pfam-class",
-                                                            style={"display": "none"}),
-                                          html.P(children="", id="feature-data",
-                                                 style={"display": "none"}),
-                                          html.P(children=compress_object(0), id="button1",
-                                                 style={"display": "none"}),
-                                          html.P(children=compress_object(0), id="button2",
-                                                 style={"display": "none"}),
-                                          html.P(children="", id="alignment-data",
-                                                 style={"display": "none"}),
-                                          html.Div([html.H3("Choose Structures", className="row",
-                                                            style={"text-align": "center"}),
-                                                    html.P(pfam_selection_text_web, className="row"),
-                                                    html.Div([html.Div(dcc.Dropdown(placeholder="Choose Pfam ID",
-                                                                                    options=pfam_start, id="pfam-ids"),
-                                                                       className="four columns"),
-                                                              html.Button("Load Structures", className="four columns",
-                                                                          id="load-button"),
-                                                              # html.Div(
-                                                              #     dcc.Input(placeholder="Custom folder", value="", type="text", id="custom-folder"),
-                                                              #     className="four columns")],
-                                                              ], className="row"),
-                                                    html.Div(
-                                                        [html.Div(dcc.Dropdown(placeholder="Gap open penalty (1.0)",
-                                                                               options=[
-                                                                                   {"label": np.round(x, decimals=2),
-                                                                                    "value": x} for x in
-                                                                                   np.arange(0, 5, 0.1)],
-                                                                               id="gap-open"),
-                                                                  className="four columns"),
-                                                         html.Div(dcc.Dropdown(multi=True, id="structure-selection"),
-                                                                  className="four columns"),
-                                                         html.Div(dcc.Dropdown(placeholder="Gap extend penalty (0.01)",
-                                                                               options=[
-                                                                                   {"label": np.round(x, decimals=3),
-                                                                                    "value": x} for x in
-                                                                                   np.arange(0, 1, 0.002)],
-                                                                               id="gap-extend"),
-                                                                  className="four columns")], className="row"),
-                                                    html.Br(),
-                                                    html.Div(html.Button("Align Structures", className="twelve columns",
-                                                                         id="align"),
-                                                             className="row"),
-                                                    dcc.Loading(id="loading-1", children=[
-                                                        html.Div(id="output-1", style={"text-align": "center"})],
-                                                                type="default")
-                                                    ],
-                                                   className="container"),
-                                          html.Br()], className="container", style=box_style),
-                                html.Br(),
-                                html.Div(children=[html.Br(),
-                                                   html.H3("Sequence alignment", className="row",
-                                                           style={"text-align": "center"}),
-                                                   html.Div(
-                                                       html.P("", className="row"),
-                                                       className="container"),
-                                                   html.Div([html.Button("Download alignment", className="row",
-                                                                         id="alignment-download"),
-                                                             html.Div(children="", className="row",
-                                                                      id="fasta-download-link")],
-                                                            className="container"),
-                                                   html.Div(html.P(id="alignment", className="twelve columns"),
-                                                            className="row")],
-                                         className="container", style=box_style),
-                                html.Br(),
-                                html.Div([html.Br(),
-                                          html.H3("Structural alignment", className="row",
-                                                  style={"text-align": "center"}),
-                                          html.Div(html.P(structure_alignment_text,
-                                                          className="row"), className="container"),
-                                          html.Div([html.Button("Download PDB", className="row", id="pdb",
-                                                                style={"align": "center"}),
-                                                    html.Div(children="", className="row",
-                                                             id="pdb-download-link")],
-                                                   className="container"),
-                                          html.Div(
-                                              children=dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                                 id="scatter3d"),
-                                              className="row", id="aligned-proteins"), html.Br()],
-                                         className="container", style=box_style),
-                                html.Br(),
-                                html.Div(
-                                    [html.Br(), html.Div([html.Div(
-                                        [html.H3("Feature alignment", className="row", style={"text-align": "center"}),
-                                         html.P(
-                                             feature_alignment_text,
-                                             className="row"),
-                                         dcc.Dropdown(placeholder="Choose a feature", id="feature-selection",
-                                                      className="six columns"),
-                                         html.Button("Display feature alignment", id="feature-button",
-                                                     className="six columns")], className="row"),
-                                        html.Div(
-                                            [html.Div([html.Button("Export feature", id="export"),
-                                                       html.Button("Export all features",
-                                                                   id="export-all")], id="exporter"),
-                                             html.Div(html.P(""), id="link-field"),
-                                             html.Br()])], className="container"),
-
-                                     html.Div(
-                                         html.Div(dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                            id="feature-line"),
-                                                  id="feature-plot1"),
-                                         className="row"),
-                                     html.Div(
-                                         html.Div(dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                            id="heatmap"), id="feature-plot2"),
-                                         className="row")],
-                                    className="container", style=box_style),
-                                html.Br(), html.Br(), html.Div(id="testi")])
-
-
-@app.callback(dash.dependencies.Output('fasta-download-link', 'children'),
-              [dash.dependencies.Input('alignment-download', 'n_clicks')],
-              [dash.dependencies.State("alignment-data", "children"),
-               dash.dependencies.State("pfam-class", "children")])
-def download_alignment(clicked, data, pfam_data):
-    if clicked and data and pfam_data:
-        alignment = decompress_object(data)
-        if not alignment:
-            return ""
-        pfam_class = decompress_object(pfam_data)
-        fasta = pfam_class.to_fasta_str(alignment)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.fasta"
-        with open(fname, "w") as f:
-            f.write(fasta)
-        return html.A(f"Download %s here" % ("alignment" + ".fasta"), href="/caretta/%s" % fname)
-    else:
-        return ""
-
-
-@app.callback(dash.dependencies.Output('pdb-download-link', 'children'),
-              [dash.dependencies.Input('pdb', 'n_clicks')],
-              [dash.dependencies.State("alignment-data", "children"),
-               dash.dependencies.State("pfam-class", "children")])
-def download_pdb(clicked, data, pfam_data):
-    if clicked and data and pfam_data:
-        alignment = decompress_object(data)
-        if not alignment:
-            return ""
-        pfam_class = decompress_object(pfam_data)
-
-        fnum = np.random.randint(0, 1000000000)
-        pfam_class.msa.output_files.pdb_folder = Path(f"static/{fnum}_pdb")
-        pfam_class.msa.write_files(write_pdb=True, write_fasta=False, write_class=False, write_features=False)
-        pdb_zip_file = ZipFile(f"static/{fnum}_pdb.zip", mode="w")
-        for pdb_file in Path(pfam_class.msa.output_files.pdb_folder).glob("*.pdb"):
-            pdb_zip_file.write(str(pdb_file))
-        return html.A(f"Download %s here" % ("pdbs" + ".zip"), href="/caretta/%s" % f"static/{fnum}_pdb.zip")
-    else:
-        return ""
-
-
-@app.callback(dash.dependencies.Output('structure-selection', 'options'),
-              [dash.dependencies.Input('load-button', 'n_clicks')],
-              [dash.dependencies.State("pfam-class", "children"),
-               dash.dependencies.State("pfam-ids", "value")])
-def show_selected_atoms(clicked, pfam_class, pfam_id):
-    if clicked and pfam_class and pfam_id:
-        pfam_class = decompress_object(pfam_class)
-        pfam_structures = pfam_class.get_entries_for_pfam(pfam_id)
-        return [{"label": x.PDB_ID, "value": compress_object(x)} for x in pfam_structures]
-    else:
-        return [{"label": "no selection", "value": "None"}]
-
-
-@app.callback([dash.dependencies.Output("output-1", "children"),
-               dash.dependencies.Output("alignment", "children"),
-               dash.dependencies.Output("aligned-proteins", "children"),
-               dash.dependencies.Output("feature-data", "children"),
-               dash.dependencies.Output("feature-selection", "options"),
-               dash.dependencies.Output("alignment-data", "children"),
-               dash.dependencies.Output("pfam-class", "children")],
-              [dash.dependencies.Input("align", "n_clicks")],
-              [dash.dependencies.State("structure-selection", "value"),
-               dash.dependencies.State("pfam-class", "children"),
-               dash.dependencies.State("gap-open", "value"),
-               dash.dependencies.State("gap-extend", "value")])
-def align_structures(clicked, pdb_entries, pfam_class, gap_open, gap_extend):
-    if clicked and pdb_entries and pfam_class:
-        pfam_class = decompress_object(pfam_class)
-        pdb_entries = [decompress_object(x) for x in pdb_entries]
-        if gap_open and gap_extend:
-            alignment, pdbs, features = pfam_class.multiple_structure_alignment_from_pfam(pdb_entries,
-                                                                                          gap_open_penalty=gap_open,
-                                                                                          gap_extend_penalty=gap_extend)
-        else:
-            alignment, pdbs, features = pfam_class.multiple_structure_alignment_from_pfam(pdb_entries)
-        pfam_class.msa.superpose(alignment)
-        fasta = pfam_class.to_fasta_str(alignment)
-        component = dashbio.AlignmentChart(
-            id='my-dashbio-alignmentchart',
-            data=fasta, showconsensus=False, showconservation=False,
-            overview=None, height=300,
-            colorscale="hydrophobicity"
-        )
-        return "", component, dcc.Graph(figure=structure_plot({s.name: s.coords for s in pfam_class.msa.structures}),
-                                        id="scatter3d"), compress_object(
-            features), [{"label": x, "value": x} for x in features[list(features.keys())[0]]], compress_object(
-            alignment), compress_object(pfam_class)
-    else:
-        return "", "", "", compress_object(np.zeros(0)), [
-            {"label": "no alignment present", "value": "no alignment"}], pdb_entries, pfam_class
-
-
-@app.callback([dash.dependencies.Output("feature-plot1", "children"),
-               dash.dependencies.Output("feature-plot2", "children")],
-              [dash.dependencies.Input("feature-button", "n_clicks")],
-              [dash.dependencies.State("feature-selection", "value"),
-               dash.dependencies.State("feature-data", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def display_feature(clicked, chosen_feature, feature_dict, aln):
-    if clicked and chosen_feature and feature_dict:
-        alignment = decompress_object(aln)
-        feature_dict = decompress_object(feature_dict)
-        chosen_feature_dict = {x: feature_dict[x][chosen_feature] for x in feature_dict}
-        aln_np = {k: helper.aligned_string_to_array(alignment[k]) for k in alignment}
-        chosen_feature_aln_dict = {x: helper.get_aligned_string_data(aln_np[x], chosen_feature_dict[x]) for x in
-                                   feature_dict.keys()}
-        z = get_feature_z(chosen_feature_aln_dict, alignment)
-        component1 = dcc.Graph(figure=feature_heatmap(z), id="heatmap")
-        component2 = dcc.Graph(figure=feature_line(z, alignment), id="feature-line")
-        return component2, component1
-    else:
-        return dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                         id="feature-line", style={"display": "none"}), dcc.Graph(
-            figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-            id="heatmap", style={"display": "none"})
-
-
-@app.callback([dash.dependencies.Output("link-field", "children"),
-               dash.dependencies.Output("exporter", "children")],
-              [dash.dependencies.Input("export", "n_clicks"),
-               dash.dependencies.Input("export-all", "n_clicks")],
-              [dash.dependencies.State("feature-selection", "value"),
-               dash.dependencies.State("feature-data", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def write_output(clicked, clicked_all, chosen_feature, feature_dict, aln):
-    if (clicked and chosen_feature and feature_dict and aln) and not clicked_all:
-        alignment = decompress_object(aln)
-        feature_dict = decompress_object(feature_dict)
-        chosen_feature_dict = {x: feature_dict[x][chosen_feature] for x in feature_dict}
-        aln_np = {k: helper.aligned_string_to_array(alignment[k]) for k in alignment}
-        chosen_feature_aln_dict = {x: helper.get_aligned_string_data(aln_np[x], chosen_feature_dict[x]) for x in
-                                   feature_dict.keys()}
-        z = get_feature_z(chosen_feature_aln_dict, alignment)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.csv"
-        write_as_csv(z, fname)
-        return html.A(f"Download %s here" % (str(fnum) + ".csv"), href="/caretta/%s" % fname), [
-            html.Button("Export feature", id="export"),
-            html.Button("Export all features", id="export-all")]
-    elif (clicked_all and feature_dict and aln) and not clicked:
-        feature_dict = decompress_object(feature_dict)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.csv"
-        write_as_csv_all_features(feature_dict, fname)
-        return html.A(f"Download %s here" % (str(fnum) + ".csv"), href="/caretta/%s" % fname), [
-            html.Button("Export feature", id="export"),
-            html.Button("Export all features", id="export-all")]
-    else:
-        return "", [html.Button("Export feature", id="export"),
-                    html.Button("Export all features", id="export-all")]
-
+[the git repository](https://github.com/TurtleTools/caretta) and run it locally to use it on as many proteins as you'd like. Also, while this
+demo server only uses structures from the PDB, a self-hosted version can be used to align and compare homology models as well.\n\n 
+The resulting alignment can be used in machine learning applications aimed at predicting a certain aspect of a protein family, such as 
+substrate specificity, interaction specificity, catalytic activity etc. For this purpose, caretta also outputs matrices of extracted 
+structural features. These represent different attributes of each residue in each protein, such as residue depth, electrostatic energy, 
+bond angles etc. Using these feature matrices as input to a supervised machine learning algorithm can pinpoint residue positions or structural 
+regions which correlate with a given property.\n
+All the generated data can further be exported for downstream use."""
 
-@app.server.route('/caretta/static/<path:path>')
-def download_file(path):
-    root_dir = os.getcwd()
-    return flask.send_from_directory(
-        os.path.join(root_dir, 'static'), path)
+input_text = "Enter a Pfam ID and click on Load Structures."
+placeholder_text = "Enter a Pfam ID (e.g. PF04851.14)"
+selection_text = """Use the dropdown box to select which PDB IDs to align."""
 
+PFAM_TO_PDB = app_helper.PfamToPDB(from_file=False, limit=100)
+app.layout = app_layout.get_layout(
+    introduction_text, input_text, placeholder_text, selection_text, SUITE, pfam_class=PFAM_TO_PDB
+)
 
-@app.callback([dash.dependencies.Output("feature-line", "figure"),
-               dash.dependencies.Output("scatter3d", "figure"),
-               dash.dependencies.Output("button1", "children"),
-               dash.dependencies.Output("button2", "children")],
-              [dash.dependencies.Input("scatter3d", "clickData"),
-               dash.dependencies.Input("feature-line", "clickData")],
-              [dash.dependencies.State("feature-line", "figure"),
-               dash.dependencies.State("scatter3d", "figure"),
-               dash.dependencies.State("button1", "children"),
-               dash.dependencies.State("button2", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def update_features(clickdata_3d, clickdata_feature, feature_data, scatter3d_data, button1, button2, alignment_data):
-    if feature_data and scatter3d_data and clickdata_feature and compress_object(
-            (clickdata_feature["points"][0]["pointNumber"], clickdata_feature["points"][0]["curveNumber"])) != button1:
-        alignment = decompress_object(alignment_data)
-        number_of_structures = len(alignment)
-        clickdata = clickdata_feature
-        idx = clickdata["points"][0]["pointNumber"]
-        protein_index = clickdata["points"][0]["curveNumber"]
-        aln_positions = aln_index_to_protein(idx, alignment)
-        button1 = compress_object((idx, protein_index))
-        x, y = clickdata["points"][0]["x"], clickdata["points"][0]["y"]
-        try:
-            maxim, minim = np.max(feature_data["data"][0]["y"]), np.min(feature_data["data"][0]["y"])
-        except KeyError:
-            return feature_data, scatter3d_data, button1, button2
-        if len(feature_data["data"]) > 2:
-            feature_data["data"] = feature_data["data"][:-1]
-        feature_data["data"] += [dict(y=[minim, maxim], x=[idx, idx], type="scatter", mode="lines",
-                                      name="selected residue")]
-        if len(scatter3d_data["data"]) > number_of_structures:
-            scatter3d_data["data"] = scatter3d_data["data"][:-1]
-        to_add = []
-        for i in range(len(scatter3d_data["data"])):
-            d = scatter3d_data["data"][i]
-            k = d["name"]
-            p = aln_positions[k]
-            if p is not None:
-                x, y, z = d["x"][p], d["y"][p], d["z"][p]
-                to_add.append((x, y, z))
-            else:
-                continue
-        scatter3d_data["data"] += [dict(x=[x[0] for x in to_add],
-                                        y=[y[1] for y in to_add],
-                                        z=[z[2] for z in to_add], type="scatter3d", mode="markers",
-                                        name="selected residues")]
-        return feature_data, scatter3d_data, button1, button2
-    if feature_data and scatter3d_data and clickdata_3d and compress_object(
-            (clickdata_3d["points"][0]["pointNumber"], clickdata_3d["points"][0]["curveNumber"])) != button2:
-        alignment = decompress_object(alignment_data)
-        number_of_structures = len(alignment)
-        clickdata = clickdata_3d
-        idx = clickdata["points"][0]["pointNumber"]
-        protein_index = clickdata["points"][0]["curveNumber"]
-        button2 = compress_object((idx, protein_index))
-        gapped_sequence = list(alignment.values())[protein_index]
-        aln_index = protein_to_aln_index(idx, gapped_sequence)
-        x, y, z = clickdata["points"][0]["x"], clickdata["points"][0]["y"], clickdata["points"][0]["z"]
-        try:
-            maxim, minim = np.max(feature_data["data"][0]["y"]), np.min(feature_data["data"][0]["y"])
-        except KeyError:
-            return feature_data, scatter3d_data, button1, button2
-        if len(feature_data["data"]) > 2:
-            feature_data["data"] = feature_data["data"][:-1]
-        feature_data["data"] += [dict(y=[minim, maxim], x=[aln_index, aln_index], type="scatter", mode="lines",
-                                      name="selected_residue")]
-        if len(scatter3d_data["data"]) > number_of_structures:
-            scatter3d_data["data"] = scatter3d_data["data"][:-1]
-        scatter3d_data["data"] += [dict(y=[y], x=[x], z=[z], type="scatter3d", mode="markers",
-                                        name="selected residue")]
-        return feature_data, scatter3d_data, button1, button2
 
-    elif feature_data and scatter3d_data:
-        return feature_data, scatter3d_data, button1, button2
+app_callbacks.register_callbacks(app, PFAM_TO_PDB.get_entries_for_pfam, SUITE)
 
 
-def run_server(host="0.0.0.0", port=8888):
+def run(
+    host: str = typer.Argument("0.0.0.0", help="host IP to serve the app"),
+    port: int = typer.Argument(8888, help="port"),
+):
     """
     caretta-app is the GUI of caretta, capable of aligning and visualising multiple protein structures
     and allowing extraction of aligned features such as bond angles, residue depths and fluctuations.
-    ----------
-    host
-        host ip (string)
-    port
-        port
     """
+    multiple_alignment.trigger_numba_compilation()
     app.run_server(host=host, port=port)
 
 
-if __name__ == '__main__':
-    fire.Fire(run_server)
+if __name__ == "__main__":
+    typer.run(run)
diff --git a/bin/caretta-app-demo3 b/bin/caretta-app-demo3
deleted file mode 100644
index 8f80e9fb66c41d2ed891903626e8bfc0b05272c7..0000000000000000000000000000000000000000
--- a/bin/caretta-app-demo3
+++ /dev/null
@@ -1,593 +0,0 @@
-#!/usr/bin/env python3
-
-import base64
-import os
-import pickle
-from zipfile import ZipFile
-import datetime
-
-import dash
-import dash_bio as dashbio
-import dash_core_components as dcc
-import dash_html_components as html
-import fire
-import flask
-import numpy as np
-from cryptography.fernet import Fernet
-from pathlib import Path
-
-from caretta import helper
-from caretta.pfam import PfamToPDB
-
-key = Fernet.generate_key()
-suite = Fernet(key)
-
-if not Path("static").exists():
-    Path("static").mkdir()
-
-external_stylesheets = ["https://cdnjs.cloudflare.com/ajax/libs/skeleton/2.0.4/skeleton.css"]
-
-app = dash.Dash(__name__, external_stylesheets=external_stylesheets, url_base_pathname="/caretta/")
-
-
-def feature_heatmap(data, zeros=False):
-    if zeros:
-        length = 2
-        z = np.zeros((length, length))
-    else:
-        keys = list(data.keys())
-        length = len(data[keys[0]])
-        z = np.zeros((len(data), length))
-        for i in range(len(data)):
-            for j in range(length):
-                z[i, j] = data[keys[i]][j]
-    return dict(data=[dict(z=z, type="heatmap", showscale=False)], layout=dict(margin=dict(l=25, r=25, t=25, b=25)))
-
-
-def feature_line(features, alignment):
-    length = len(features[list(features.keys())[0]])
-    z = np.zeros((len(features), length))
-    keys = list(features.keys())
-    for i in range(len(features)):
-        for j in range(length):
-            if alignment[keys[i]][j] is not "-":
-                z[i, j] = features[keys[i]][j]
-            else:
-                z[i, j] = np.NaN
-    y = np.array([np.nanmean(z[:, x]) for x in range(z.shape[1])])
-    y_se = np.array([np.nanstd(z[:, x]) / np.sqrt(z.shape[1]) for x in range(z.shape[1])])
-
-    data = [dict(y=list(y + y_se) + list(y - y_se)[::-1], x=list(range(length)) + list(range(length))[::-1],
-                 fillcolor="lightred", fill="toself", type="scatter", mode="lines", name="Mean +/- standard error",
-                 line=dict(color='red')),
-            dict(y=y, x=np.arange(length), type="scatter", mode="lines", name="Mean",
-                 line=dict(color='blue'))]
-    return dict(data=data, layout=dict(legend=dict(x=0.5, y=1.2), margin=dict(l=25, r=25, t=25, b=25)))
-
-
-def structure_plot(coord_dict):
-    data = []
-    for k, v in coord_dict.items():
-        x, y, z = v[:, 0], v[:, 1], v[:, 2]
-        data.append(dict(
-            x=x,
-            y=y,
-            z=z,
-            mode='lines',
-            type='scatter3d',
-            text=None,
-            name=str(k),
-            line=dict(
-                width=3,
-                opacity=0.8)))
-    layout = dict(margin=dict(l=20, r=20, t=20, b=20), clickmode='event+select',
-                  scene=dict(xaxis=dict(visible=False, showgrid=False, showline=False),
-                             yaxis=dict(visible=False, showgrid=False, showline=False),
-                             zaxis=dict(visible=False, showgrid=False, showline=False)))
-    return dict(data=data, layout=layout)
-
-
-def check_gap(sequences, i):
-    for seq in sequences:
-        if seq[i] == "-":
-            return True
-    return False
-
-
-def get_feature_z(features, alignments):
-    core_indices = []
-    sequences = list(alignments.values())
-    for i in range(len(sequences[0])):
-        if not check_gap(sequences, i):
-            core_indices.append(i)
-        else:
-            continue
-    return {x: features[x][np.arange(len(sequences[0]))] for x in features}
-
-
-def write_as_csv(feature_dict, file_name):
-    with open(file_name, "w") as f:
-        for protein_name, features in feature_dict.items():
-            f.write(";".join([protein_name] + [str(x) for x in list(features)]) + "\n")
-
-
-def write_as_csv_all_features(feature_dict, file_name):
-    with open(file_name, "w") as f:
-        for protein_name, feature_dict in feature_dict.items():
-            for feature_name, feature_values in feature_dict.items():
-                f.write(";".join([protein_name, feature_name] + [str(x) for x in list(feature_values)]) + "\n")
-
-
-box_style = {"box-shadow": "1px 3px 20px -4px rgba(0,0,0,0.75)",
-             "border-radius": "5px", "background-color": "#f9f7f7"}
-
-box_style_lg = {"top-margin": 25,
-                "border-style": "solid",
-                "border-color": "rgb(187, 187, 187)",
-                "border-width": "1px",
-                "border-radius": "5px",
-                "background-color": "#edfdff"}
-
-box_style_lr = {"top-margin": 25,
-                "border-style": "solid",
-                "border-color": "rgb(187, 187, 187)",
-                "border-width": "1px",
-                "border-radius": "5px",
-                "background-color": "#ffbaba"}
-
-
-def compress_object(raw_object):
-    return base64.b64encode(suite.encrypt(pickle.dumps(raw_object, protocol=4))).decode("utf-8")
-
-
-def decompress_object(compressed_object):
-    return pickle.loads(suite.decrypt(base64.b64decode(compressed_object)))
-
-
-def protein_to_aln_index(protein_index, aln_seq):
-    n = 0
-    for i in range(len(aln_seq)):
-        if protein_index == n:
-            return i
-        elif aln_seq[i] == "-":
-            pass
-        else:
-            n += 1
-
-
-def aln_index_to_protein(alignment_index, alignment):
-    res = dict()
-    for k, v in alignment.items():
-        if v[alignment_index] == "-":
-            res[k] = None
-        else:
-            res[k] = alignment_index - v[:alignment_index].count("-")
-    return res
-
-
-pfam_start = PfamToPDB(from_file=False, limit=100)
-pfam_start = list(pfam_start.pfam_to_pdb_ids.keys())
-pfam_start = [{"label": x, "value": x} for x in pfam_start]
-
-introduction_text_web = dcc.Markdown("""This is a demo webserver for *caretta*. It generates multiple structure alignments for proteins from a selected 
-Pfam domain and displays the alignment, the superposed proteins, and aligned structural features. 
-
-While the server is restricted to a maximum of 50 proteins and 100 Pfam domains, you can download this GUI and command-line tool from 
-[the git repository](https://git.wageningenur.nl/durai001/caretta) and run it locally to use it on as many proteins as you'd like. Also, while this
-demo server only uses structures from the PDB, a self-hosted version can be used to align and compare homology models as well. 
-
-The resulting alignment can be used in machine learning applications aimed at predicting a certain aspect of a protein family, such as 
-substrate specificity, interaction specificity, catalytic activity etc. For this purpose, caretta also outputs matrices of extracted 
-structural features. These represent different attributes of each residue in each protein, such as residue depth, electrostatic energy, 
-bond angles etc. Using these feature matrices as input to a supervised machine learning algorithm can pinpoint residue positions or structural 
-regions which correlate with a given property.
-
-All the generated data can further be exported for downstream use.""")
-
-pfam_selection_text_web = """Choose a Pfam ID and click on Load Structures. 
-Then use the dropdown box to select which PDB IDs to align. The PDB IDs also include the chain ID and the retrieved residue range that contain the selected pfam domain."""
-structure_alignment_text = """Click on a residue to see its position on the feature alignment in the next section."""
-feature_alignment_text = """Click on a position in the feature alignment to see the corresponding residues in the previous section."""
-
-app.layout = html.Div(children=[html.Div(html.Div([html.H1("Caretta",
-                                                           style={"text-align": "center"}),
-                                                   html.H3(
-                                                       "a multiple protein structure alignment and feature extraction suite",
-                                                       style={"text-align": "center"}),
-                                                   html.P(introduction_text_web, style={"text-align": "left"})],
-                                                  className="row"),
-                                         className="container"),
-                                html.Div([html.Br(), html.P(children=compress_object(PfamToPDB(from_file=False,
-                                                                                               limit=100)),
-                                                            id="pfam-class",
-                                                            style={"display": "none"}),
-                                          html.P(children="", id="feature-data",
-                                                 style={"display": "none"}),
-                                          html.P(children=compress_object(0), id="button1",
-                                                 style={"display": "none"}),
-                                          html.P(children=compress_object(0), id="button2",
-                                                 style={"display": "none"}),
-                                          html.P(children="", id="alignment-data",
-                                                 style={"display": "none"}),
-                                          html.Div([html.H3("Choose Structures", className="row",
-                                                            style={"text-align": "center"}),
-                                                    html.P(pfam_selection_text_web, className="row"),
-                                                    html.Div([html.Div(dcc.Dropdown(placeholder="Choose Pfam ID",
-                                                                                    options=pfam_start, id="pfam-ids"),
-                                                                       className="four columns"),
-                                                              html.Button("Load Structures", className="four columns",
-                                                                          id="load-button"),
-                                                              ], className="row"),
-                                                    html.Div(
-                                                        [html.Div(dcc.Dropdown(placeholder="Gap open penalty (1.0)",
-                                                                               options=[
-                                                                                   {"label": np.round(x, decimals=2),
-                                                                                    "value": x} for x in
-                                                                                   np.arange(0, 5, 0.1)],
-                                                                               id="gap-open"),
-                                                                  className="four columns"),
-                                                         html.Div(dcc.Dropdown(multi=True, id="structure-selection",
-                                                                               placeholder="Select PDB IDs to align"),
-                                                                  className="four columns"),
-                                                         html.Div(dcc.Dropdown(placeholder="Gap extend penalty (0.01)",
-                                                                               options=[
-                                                                                   {"label": np.round(x, decimals=3),
-                                                                                    "value": x} for x in
-                                                                                   np.arange(0, 1, 0.002)],
-                                                                               id="gap-extend"),
-                                                                  className="four columns")], className="row"),
-                                                    html.Br(),
-                                                    html.Div(html.Button("Align Structures", className="twelve columns",
-                                                                         id="align"),
-                                                             className="row"),
-                                                    dcc.Loading(id="loading-1", children=[
-                                                        html.Div(id="output-1", style={"text-align": "center"})],
-                                                                type="default"),
-                                                    html.P(id="time-estimate", style={"text-align": "center"},
-                                                           children="", className="row"),
-                                                    html.Div(id="alignment-done", style={"display": "none"},
-                                                             children=[False]),
-                                                    ],
-                                                   className="container"),
-                                          html.Br()], className="container", style=box_style),
-                                html.Br(),
-                                html.Div(children=[html.Br(),
-                                                   html.H3("Sequence alignment", className="row",
-                                                           style={"text-align": "center"}),
-                                                   html.Div(
-                                                       html.P("", className="row"),
-                                                       className="container"),
-                                                   html.Div([html.Button("Download alignment", className="row",
-                                                                         id="alignment-download"),
-                                                             html.Div(children="", className="row",
-                                                                      id="fasta-download-link")],
-                                                            className="container"),
-                                                   html.Div(html.P(id="alignment", className="twelve columns"),
-                                                            className="row")],
-                                         className="container", style=box_style),
-                                html.Br(),
-                                html.Div([html.Br(),
-                                          html.H3("Structural alignment", className="row",
-                                                  style={"text-align": "center"}),
-                                          html.Div(html.P(structure_alignment_text,
-                                                          className="row"), className="container"),
-                                          html.Div([html.Button("Download PDB", className="row", id="pdb",
-                                                                style={"align": "center"}),
-                                                    html.Div(children="", className="row",
-                                                             id="pdb-download-link")],
-                                                   className="container"),
-                                          html.Div(
-                                              children=dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                                 id="scatter3d"),
-                                              className="row", id="aligned-proteins"), html.Br()],
-                                         className="container", style=box_style),
-                                html.Br(),
-                                html.Div(
-                                    [html.Br(), html.Div([html.Div(
-                                        [html.H3("Feature alignment", className="row", style={"text-align": "center"}),
-                                         html.P(
-                                             feature_alignment_text,
-                                             className="row"),
-                                         dcc.Dropdown(placeholder="Choose a feature", id="feature-selection",
-                                                      className="six columns"),
-                                         html.Button("Display feature alignment", id="feature-button",
-                                                     className="six columns")], className="row"),
-                                        html.Div(
-                                            [html.Div([html.Button("Export feature", id="export"),
-                                                       html.Button("Export all features",
-                                                                   id="export-all")], id="exporter"),
-                                             html.Div(html.P(""), id="link-field"),
-                                             html.Br()])], className="container"),
-
-                                     html.Div(
-                                         html.Div(dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                            id="feature-line"),
-                                                  id="feature-plot1"),
-                                         className="row"),
-                                     html.Div(
-                                         html.Div(dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                                                            id="heatmap"), id="feature-plot2"),
-                                         className="row")],
-                                    className="container", style=box_style),
-                                html.Br(), html.Br(), html.Div(id="testi")])
-
-
-@app.callback(dash.dependencies.Output('fasta-download-link', 'children'),
-              [dash.dependencies.Input('alignment-download', 'n_clicks')],
-              [dash.dependencies.State("alignment-data", "children"),
-               dash.dependencies.State("pfam-class", "children")])
-def download_alignment(clicked, data, pfam_data):
-    if clicked and data and pfam_data:
-        alignment = decompress_object(data)
-        if not alignment:
-            return ""
-        pfam_class = decompress_object(pfam_data)
-        fasta = pfam_class.to_fasta_str(alignment)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.fasta"
-        with open(fname, "w") as f:
-            f.write(fasta)
-        return html.A(f"Download %s here" % ("alignment" + ".fasta"), href="/caretta/%s" % fname)
-    else:
-        return ""
-
-
-@app.callback(dash.dependencies.Output('pdb-download-link', 'children'),
-              [dash.dependencies.Input('pdb', 'n_clicks')],
-              [dash.dependencies.State("alignment-data", "children"),
-               dash.dependencies.State("pfam-class", "children")])
-def download_pdb(clicked, data, pfam_data):
-    if clicked and data and pfam_data:
-        alignment = decompress_object(data)
-        if not alignment:
-            return ""
-        pfam_class = decompress_object(pfam_data)
-
-        fnum = np.random.randint(0, 1000000000)
-        pfam_class.msa.output_files.pdb_folder = Path(f"static/{fnum}_pdb")
-        pfam_class.msa.write_files(write_pdb=True, write_fasta=False, write_class=False, write_features=False)
-        pdb_zip_file = ZipFile(f"static/{fnum}_pdb.zip", mode="w")
-        for pdb_file in Path(pfam_class.msa.output_files.pdb_folder).glob("*.pdb"):
-            pdb_zip_file.write(str(pdb_file))
-        return html.A(f"Download %s here" % ("pdbs" + ".zip"), href="/caretta/%s" % f"static/{fnum}_pdb.zip")
-    else:
-        return ""
-
-
-@app.callback(dash.dependencies.Output('structure-selection', 'options'),
-              [dash.dependencies.Input('load-button', 'n_clicks')],
-              [dash.dependencies.State("pfam-class", "children"),
-               dash.dependencies.State("pfam-ids", "value")])
-def show_selected_atoms(clicked, pfam_class, pfam_id):
-    if clicked and pfam_class and pfam_id:
-        pfam_class = decompress_object(pfam_class)
-        pfam_structures = pfam_class.get_entries_for_pfam(pfam_id)
-        return [{"label": f"{x.PDB_ID}.{x.CHAIN_ID} ({x.PdbResNumStart}-{x.PdbResNumEnd})", "value": compress_object(x)}
-                for x in pfam_structures]  # TODO: test if working
-    else:
-        return [{"label": "no selection", "value": "None"}]
-
-
-def get_estimated_time(pdb_entries):
-    n = len(pdb_entries)
-    l = max(p.PdbResNumEnd - p.PdbResNumStart for p in pdb_entries)
-    func = lambda x, r: (x[0] ** 2 * r * x[1] ** 2)
-    return str(datetime.timedelta(seconds=int(func((l, n), 9.14726052e-06))))
-
-
-@app.callback(dash.dependencies.Output("time-estimate", "children"),
-              [dash.dependencies.Input("align", "n_clicks")],
-              [dash.dependencies.State("structure-selection", "value")])
-def get_time_estimate(clicked, pdb_entries):
-    if clicked and pdb_entries:
-        pdb_entries = [decompress_object(x) for x in pdb_entries]
-        time = get_estimated_time(pdb_entries)
-        return f"Estimated time: {time}"
-    else:
-        return ""
-
-
-@app.callback(
-    [dash.dependencies.Output("output-1", "children"),
-     dash.dependencies.Output("alignment", "children"),
-     dash.dependencies.Output("aligned-proteins", "children"),
-     dash.dependencies.Output("feature-data", "children"),
-     dash.dependencies.Output("feature-selection", "options"),
-     dash.dependencies.Output("alignment-data", "children"),
-     dash.dependencies.Output("pfam-class", "children"),
-     dash.dependencies.Output("alignment-done", "children")
-     ],
-
-    [dash.dependencies.Input("align", "n_clicks")],
-
-    [dash.dependencies.State("structure-selection", "value"),
-     dash.dependencies.State("pfam-class", "children"),
-     dash.dependencies.State("gap-open", "value"),
-     dash.dependencies.State("gap-extend", "value")])
-def align_structures(clicked, pdb_entries, pfam_class, gap_open, gap_extend):
-    if clicked and pdb_entries and pfam_class:
-        pfam_class = decompress_object(pfam_class)
-        pdb_entries = [decompress_object(x) for x in pdb_entries]
-
-        if gap_open and gap_extend:
-            # TODO: add try except here to fix zero division error
-            alignment, pdbs, features = pfam_class.multiple_structure_alignment_from_pfam(pdb_entries,
-                                                                                          gap_open_penalty=gap_open,
-                                                                                          gap_extend_penalty=gap_extend)
-        else:
-            alignment, pdbs, features = pfam_class.multiple_structure_alignment_from_pfam(pdb_entries)
-        pfam_class.msa.superpose(alignment)
-        fasta = pfam_class.to_fasta_str(alignment)
-        component = dashbio.AlignmentChart(
-            id='my-dashbio-alignmentchart',
-            data=fasta, showconsensus=False, showconservation=False,
-            overview=None, height=300,
-            colorscale="hydrophobicity"
-        )
-        return "", component, dcc.Graph(figure=structure_plot({s.name: s.coords for s in pfam_class.msa.structures}),
-                                        id="scatter3d"), compress_object(
-            features), [{"label": x, "value": x} for x in features[list(features.keys())[0]]], compress_object(
-            alignment), compress_object(pfam_class), [True]  # TODO: Test the features != secondary part
-    else:
-        return "", "", "", compress_object(np.zeros(0)), [
-            {"label": "no alignment present", "value": "no alignment"}], pdb_entries, pfam_class, [False]
-
-
-@app.callback([dash.dependencies.Output("feature-plot1", "children"),
-               dash.dependencies.Output("feature-plot2", "children")],
-              [dash.dependencies.Input("feature-button", "n_clicks")],
-              [dash.dependencies.State("feature-selection", "value"),
-               dash.dependencies.State("feature-data", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def display_feature(clicked, chosen_feature, feature_dict, aln):
-    if clicked and chosen_feature and feature_dict:
-        alignment = decompress_object(aln)
-        feature_dict = decompress_object(feature_dict)
-        chosen_feature_dict = {x: feature_dict[x][chosen_feature] for x in feature_dict}
-        aln_np = {k: helper.aligned_string_to_array(alignment[k]) for k in alignment}
-        chosen_feature_aln_dict = {x: helper.get_aligned_string_data(aln_np[x], chosen_feature_dict[x]) for x in
-                                   feature_dict.keys()}
-        z = get_feature_z(chosen_feature_aln_dict, alignment)
-        component1 = dcc.Graph(figure=feature_heatmap(z), id="heatmap")
-        component2 = dcc.Graph(figure=feature_line(z, alignment), id="feature-line")
-        return component2, component1
-    else:
-        return dcc.Graph(figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-                         id="feature-line", style={"display": "none"}), dcc.Graph(
-            figure=feature_heatmap([[0, 0], [0, 0]], zeros=True),
-            id="heatmap", style={"display": "none"})
-
-
-@app.callback([dash.dependencies.Output("link-field", "children"),
-               dash.dependencies.Output("exporter", "children")],
-              [dash.dependencies.Input("export", "n_clicks"),
-               dash.dependencies.Input("export-all", "n_clicks")],
-              [dash.dependencies.State("feature-selection", "value"),
-               dash.dependencies.State("feature-data", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def write_output(clicked, clicked_all, chosen_feature, feature_dict, aln):
-    if (clicked and chosen_feature and feature_dict and aln) and not clicked_all:
-        alignment = decompress_object(aln)
-        feature_dict = decompress_object(feature_dict)
-        chosen_feature_dict = {x: feature_dict[x][chosen_feature] for x in feature_dict}
-        aln_np = {k: helper.aligned_string_to_array(alignment[k]) for k in alignment}
-        chosen_feature_aln_dict = {x: helper.get_aligned_string_data(aln_np[x], chosen_feature_dict[x]) for x in
-                                   feature_dict.keys()}
-        z = get_feature_z(chosen_feature_aln_dict, alignment)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.csv"
-        write_as_csv(z, fname)
-        return html.A(f"Download %s here" % (str(fnum) + ".csv"), href="/caretta/%s" % fname), [
-            html.Button("Export feature", id="export"),
-            html.Button("Export all features", id="export-all")]
-    elif (clicked_all and feature_dict and aln) and not clicked:
-        feature_dict = decompress_object(feature_dict)
-        fnum = np.random.randint(0, 1000000000)
-        fname = f"static/{fnum}.csv"
-        write_as_csv_all_features(feature_dict, fname)
-        return html.A(f"Download %s here" % (str(fnum) + ".csv"), href="/caretta/%s" % fname), [
-            html.Button("Export feature", id="export"),
-            html.Button("Export all features", id="export-all")]
-    else:
-        return "", [html.Button("Export feature", id="export"),
-                    html.Button("Export all features", id="export-all")]
-
-
-@app.server.route('/caretta/static/<path:path>')
-def download_file(path):
-    root_dir = os.getcwd()
-    return flask.send_from_directory(
-        os.path.join(root_dir, 'static'), path)
-
-
-@app.callback([dash.dependencies.Output("feature-line", "figure"),
-               dash.dependencies.Output("scatter3d", "figure"),
-               dash.dependencies.Output("button1", "children"),
-               dash.dependencies.Output("button2", "children")],
-              [dash.dependencies.Input("scatter3d", "clickData"),
-               dash.dependencies.Input("feature-line", "clickData")],
-              [dash.dependencies.State("feature-line", "figure"),
-               dash.dependencies.State("scatter3d", "figure"),
-               dash.dependencies.State("button1", "children"),
-               dash.dependencies.State("button2", "children"),
-               dash.dependencies.State("alignment-data", "children")])
-def update_features(clickdata_3d, clickdata_feature, feature_data, scatter3d_data, button1, button2, alignment_data):
-    if feature_data and scatter3d_data and clickdata_feature and compress_object(
-            (clickdata_feature["points"][0]["pointNumber"], clickdata_feature["points"][0]["curveNumber"])) != button1:
-        alignment = decompress_object(alignment_data)
-        number_of_structures = len(alignment)
-        clickdata = clickdata_feature
-        idx = clickdata["points"][0]["pointNumber"]
-        protein_index = clickdata["points"][0]["curveNumber"]
-        aln_positions = aln_index_to_protein(idx, alignment)
-        button1 = compress_object((idx, protein_index))
-        x, y = clickdata["points"][0]["x"], clickdata["points"][0]["y"]
-        try:
-            maxim, minim = np.max(feature_data["data"][0]["y"]), np.min(feature_data["data"][0]["y"])
-        except KeyError:
-            return feature_data, scatter3d_data, button1, button2
-        if len(feature_data["data"]) > 2:
-            feature_data["data"] = feature_data["data"][:-1]
-        feature_data["data"] += [dict(y=[minim, maxim], x=[idx, idx], type="scatter", mode="lines",
-                                      name="selected residue")]
-        if len(scatter3d_data["data"]) > number_of_structures:
-            scatter3d_data["data"] = scatter3d_data["data"][:-1]
-        to_add = []
-        for i in range(len(scatter3d_data["data"])):
-            d = scatter3d_data["data"][i]
-            k = d["name"]
-            p = aln_positions[k]
-            if p is not None:
-                x, y, z = d["x"][p], d["y"][p], d["z"][p]
-                to_add.append((x, y, z))
-            else:
-                continue
-        scatter3d_data["data"] += [dict(x=[x[0] for x in to_add],
-                                        y=[y[1] for y in to_add],
-                                        z=[z[2] for z in to_add], type="scatter3d", mode="markers",
-                                        name="selected residues")]
-        return feature_data, scatter3d_data, button1, button2
-    if feature_data and scatter3d_data and clickdata_3d and compress_object(
-            (clickdata_3d["points"][0]["pointNumber"], clickdata_3d["points"][0]["curveNumber"])) != button2:
-        alignment = decompress_object(alignment_data)
-        number_of_structures = len(alignment)
-        clickdata = clickdata_3d
-        idx = clickdata["points"][0]["pointNumber"]
-        protein_index = clickdata["points"][0]["curveNumber"]
-        button2 = compress_object((idx, protein_index))
-        gapped_sequence = list(alignment.values())[protein_index]
-        aln_index = protein_to_aln_index(idx, gapped_sequence)
-        x, y, z = clickdata["points"][0]["x"], clickdata["points"][0]["y"], clickdata["points"][0]["z"]
-        try:
-            maxim, minim = np.max(feature_data["data"][0]["y"]), np.min(feature_data["data"][0]["y"])
-        except KeyError:
-            return feature_data, scatter3d_data, button1, button2
-        if len(feature_data["data"]) > 2:
-            feature_data["data"] = feature_data["data"][:-1]
-        feature_data["data"] += [dict(y=[minim, maxim], x=[aln_index, aln_index], type="scatter", mode="lines",
-                                      name="selected_residue")]
-        if len(scatter3d_data["data"]) > number_of_structures:
-            scatter3d_data["data"] = scatter3d_data["data"][:-1]
-        scatter3d_data["data"] += [dict(y=[y], x=[x], z=[z], type="scatter3d", mode="markers",
-                                        name="selected residue")]
-        return feature_data, scatter3d_data, button1, button2
-
-    elif feature_data and scatter3d_data:
-        return feature_data, scatter3d_data, button1, button2
-
-
-def run_server(host="0.0.0.0", port=3003):
-    """
-    caretta-app is the GUI of caretta, capable of aligning and visualising multiple protein structures
-    and allowing extraction of aligned features such as bond angles, residue depths and fluctuations.
-    ----------
-    host
-        host ip (string)
-    port
-        port
-    """
-    app.run_server(host=host, port=port, debug=False)
-
-
-if __name__ == '__main__':
-    fire.Fire(run_server)
diff --git a/bin/caretta-cli b/bin/caretta-cli
index 7ed61a36b74fb0fdb31827254ad622b930cb8da3..d9d0c97f8f4e4b0fb7503eaf2a987b37833daf86 100755
--- a/bin/caretta-cli
+++ b/bin/caretta-cli
@@ -1,71 +1,103 @@
 #!/usr/bin/env python3
-
+from caretta import multiple_alignment
 from pathlib import Path
+import typer
+
+app = typer.Typer()
+
+
+def input_folder_callback(folder: Path) -> Path:
+    if folder.exists():  # typer callback: hand an existing path straight back
+        return folder
+    raise typer.BadParameter(f"Folder {folder} does not exist")
 
-import fire
 
-from caretta import msa_numba
+def output_folder_callback(folder: Path) -> Path:
+    if not folder.exists():  # only a path that does not exist yet is accepted
+        return folder
+    raise typer.BadParameter(
+        f"Folder {folder} already exists, cowardly refusing to overwrite. Please delete it and try again"
+    )
 
 
-def align(input_pdb,
-          dssp_dir="caretta_tmp", num_threads=20, extract_all_features=False,
-          gap_open_penalty=1., gap_extend_penalty=0.01, consensus_weight=1.,
-          write_fasta=True, output_fasta_filename=Path("./result.fasta"),
-          write_pdb=True, output_pdb_folder=Path("./result_pdb/"),
-          write_features=True, output_feature_filename=Path("./result_features.pkl"),
-          write_class=True, output_class_filename=Path("./result_class.pkl"),
-          overwrite_dssp=False):
+def positive_penalty(value: float) -> float:
+    if not value < 0.0:  # zero passes too: only negative values are rejected
+        return value
+    raise typer.BadParameter(f"Value {value} must be positive")
+
+
+@app.command()
+def align(
+    input_pdb: Path = typer.Argument(
+        ..., help="A folder with input protein files", callback=input_folder_callback
+    ),
+    gap_open_penalty: float = typer.Option(
+        1.0, "-p", help="gap open penalty", callback=positive_penalty
+    ),
+    gap_extend_penalty: float = typer.Option(
+        0.01, "-e", help="gap extend penalty", callback=positive_penalty
+    ),
+    consensus_weight: float = typer.Option(
+        1.0,
+        "--consensus-weight",
+        "-c",
+        help="weight well-aligned segments to reduce gaps in these areas",
+        callback=positive_penalty,
+    ),
+    full: bool = typer.Option(
+        False,
+        "--full",
+        "-f",
+        help="Use all vs. all pairwise alignment for distance matrix calculation (much slower)",
+    ),
+    output: Path = typer.Option(
+        Path("caretta_results"),
+        "--output",
+        "-o",
+        help="folder to store output files",
+        callback=output_folder_callback,
+    ),
+    fasta: bool = typer.Option(True, help="write alignment in FASTA file format"),
+    pdb: bool = typer.Option(
+        True, help="write PDB files superposed according to alignment"
+    ),
+    threads: int = typer.Option(
+        4, "--threads", "-t", help="number of threads to use for feature extraction"
+    ),
+    features: bool = typer.Option(
+        False,
+        "--features",
+        help="extract and write aligned features as a dictionary of NumPy arrays into a pickle file",
+    ),
+    write_class: bool = typer.Option(
+        False,
+        "--class",
+        help="write StructureMultiple class with intermediate structures and tree to pickle file",
+    ),
+):
     """
-    Caretta aligns protein structures and returns a sequence alignment, superposed PDB files, a set of aligned feature matrices, and
-    a class with intermediate structures made during progressive alignment.
-    Parameters
-    ----------
-    input_pdb
-        Can be \n
-        A folder with input protein files \n
-        A file which lists PDB filenames on each line \n
-        A file which lists PDB IDs on each line \n
-    dssp_dir
-        Folder to store temp DSSP files (default caretta_tmp)
-    num_threads
-        Number of threads to use for feature extraction
-    extract_all_features
-        True => obtains all features (default True) \n
-        False => only DSSP features (faster)
-    gap_open_penalty
-        default 1
-    gap_extend_penalty
-        default 0.01
-    consensus_weight
-        default 1
-    write_fasta
-        True => writes alignment as fasta file (default True)
-    output_fasta_filename
-        Fasta file of alignment (default result.fasta)
-    write_pdb
-        True => writes all protein PDB files superposed by alignment (default True)
-    output_pdb_folder
-        Folder to write superposed PDB files (default result_pdb)
-    write_features
-        True => writes aligned features a s a dictionary of numpy arrays into a pickle file (default True)
-    output_feature_filename
-        Pickle file to write aligned features (default result_features.pkl)
-    write_class
-        True => writes StructureMultiple class with intermediate structures and tree to pickle file (default True)
-    output_class_filename
-        Pickle file to write StructureMultiple class (default result_class.pkl)
-    overwrite_dssp
-        Forces DSSP to rerun (default False)
+    Align protein structures using Caretta.
+
+    Writes the resulting sequence alignment and superposed PDB files to the output folder ("caretta_results" by default).
+    Optionally also outputs a set of aligned feature matrices, or the python class with intermediate structures made during progressive alignment.
     """
-    msa_numba.StructureMultiple.align_from_pdb_files(input_pdb,
-                                                     dssp_dir, num_threads, extract_all_features,
-                                                     gap_open_penalty, gap_extend_penalty, consensus_weight,
-                                                     write_fasta, output_fasta_filename,
-                                                     write_pdb, output_pdb_folder,
-                                                     write_features, output_feature_filename,
-                                                     write_class, output_class_filename,
-                                                     overwrite_dssp)
+    input_pdb = input_folder_callback(input_pdb)
+    output = output_folder_callback(output)
+    multiple_alignment.trigger_numba_compilation()
+    multiple_alignment.StructureMultiple.align_from_pdb_files(
+        input_pdb,
+        gap_open_penalty,
+        gap_extend_penalty,
+        consensus_weight,
+        full,
+        output,
+        threads,
+        fasta,
+        pdb,
+        features,
+        write_class,
+    )
 
 
-if __name__ == '__main__':
-    fire.Fire(align)
+if __name__ == "__main__":
+    app()
diff --git a/caretta/app/__init__.py b/caretta/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/caretta/app/app_callbacks.py b/caretta/app/app_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f768a26f4a5e2fecbf5c819a070d005f419e3918
--- /dev/null
+++ b/caretta/app/app_callbacks.py
@@ -0,0 +1,590 @@
+from pathlib import Path
+from zipfile import ZipFile
+
+import dash
+import dash_bio
+import dash_core_components as dcc
+import numpy as np
+
+from caretta import multiple_alignment
+from caretta.app import app_helper, app_layout
+from flask import send_from_directory, abort
+
+
+def register_callbacks(app, get_pdb_entries, suite):
+    """
+    Register callbacks to dash app
+
+    Parameters
+    ----------
+    app
+        Dash app
+    get_pdb_entries
+        function that takes a single user inputted string and returns a list of PdbEntry objects
+    suite
+        Fernet object
+    """
+    # Function called when user inputs folder / Pfam ID and clicks "Load Structures"
+    @app.callback(
+        # Output
+        # PDB files / PDB IDs found in user-inputted folder / Pfam ID
+        dash.dependencies.Output("proteins-selection-dropdown", "options"),
+        # Input
+        [
+            # Whether the load-structures-button has been clicked
+            dash.dependencies.Input("load-structures-button", "n_clicks")
+        ],
+        # State
+        [
+            # User input (PDB folder / Pfam ID)
+            dash.dependencies.State("user-input", "value"),
+        ],
+    )
+    def load_structures(load_structures_button, user_input):
+        # Nothing to list until the button was clicked with some input present
+        if not (load_structures_button and user_input):
+            return [{"label": "no selection", "value": "None"}]
+        options = []
+        for entry in get_pdb_entries(user_input):
+            # Label format: PDB_ID[.CHAIN_ID][ start-end]
+            text = entry.PDB_ID
+            if entry.CHAIN_ID != "none":
+                text += f".{entry.CHAIN_ID}"
+            if not (entry.PdbResNumStart == -1 or entry.PdbResNumEnd == -1):
+                text += f" {entry.PdbResNumStart}-{entry.PdbResNumEnd}"
+            # Each option value carries the whole entry in compressed form
+            options.append(
+                {"label": text, "value": app_helper.compress_object(entry, suite)}
+            )
+        return options
+
+    # Function called when user selects IDs to align and clicks Align
+    @app.callback(
+        # Outputs
+        [
+            # hidden Divs
+            # MSA class
+            dash.dependencies.Output("caretta-class", "children"),
+            # aligned sequences dict
+            dash.dependencies.Output("sequence-alignment-data", "children"),
+            # aligned feature dict
+            dash.dependencies.Output("feature-alignment-data", "children"),
+            # sequence alignment panel
+            dash.dependencies.Output("sequence-alignment", "children"),
+            # structure alignment panel
+            dash.dependencies.Output("structure-alignment", "children"),
+            # feature panel selection box
+            dash.dependencies.Output("feature-selection-dropdown", "options"),
+            # stop loading indicator
+            dash.dependencies.Output("loading-indicator-output", "children"),
+        ],
+        # Inputs
+        [
+            # Whether the Align button was clicked
+            dash.dependencies.Input("align-button", "n_clicks")
+        ],
+        # States
+        [
+            # User input (PDB folder / Pfam ID)
+            dash.dependencies.State("user-input", "value"),
+            # Selected PDB IDs
+            dash.dependencies.State("proteins-selection-dropdown", "value"),
+            # Penalties
+            dash.dependencies.State("gap-open-dropdown", "value"),
+            dash.dependencies.State("gap-extend-dropdown", "value"),
+            # Unique ID for setting output folder
+            dash.dependencies.State("unique-id", "children"),
+        ],
+    )
+    def align_structures(
+        align_button,
+        user_input,
+        proteins_selection_dropdown,
+        gap_open_dropdown,
+        gap_extend_dropdown,
+        unique_id,
+    ):
+        if align_button and user_input and proteins_selection_dropdown:
+            pdb_entries = [  # dropdown values are entries compressed by load_structures
+                app_helper.decompress_object(x, suite)
+                for x in proteins_selection_dropdown
+            ]
+            if not gap_open_dropdown:  # nothing (or a falsy value) selected: use 1
+                gap_open_dropdown = 1
+            if not gap_extend_dropdown:  # nothing (or a falsy value) selected: use 0.01
+                gap_extend_dropdown = 0.01
+            pdb_files = []
+            for p in pdb_entries:
+                try:
+                    pdb_files.append(p.get_pdb()[1])
+                except (OSError, AttributeError):  # best effort: entries that fail here are left out
+                    continue
+            msa_class = multiple_alignment.StructureMultiple.from_pdb_files(
+                pdb_files,
+                multiple_alignment.DEFAULT_SUPERPOSITION_PARAMETERS,
+                output_folder=f"static/results_{app_helper.decompress_object(unique_id, suite)}",
+            )
+            if len(msa_class.structures) > 2:  # a pairwise matrix is only built for 3+ structures
+                pw_matrix = msa_class.make_pairwise_shape_matrix()
+                sequence_alignment = msa_class.align(
+                    pw_matrix,
+                    gap_open_penalty=gap_open_dropdown,
+                    gap_extend_penalty=gap_extend_dropdown,
+                )
+            else:  # two or fewer structures: align without a pairwise matrix
+                sequence_alignment = msa_class.align(
+                    pw_matrix=None,
+                    gap_open_penalty=gap_open_dropdown,
+                    gap_extend_penalty=gap_extend_dropdown,
+                )
+            msa_class.superpose()
+            fasta = app_helper.to_fasta_str(sequence_alignment)
+            dssp_dir = msa_class.output_folder / ".caretta_tmp"
+            if not dssp_dir.exists():
+                dssp_dir.mkdir()
+            features = msa_class.get_aligned_features(dssp_dir, 4)  # NOTE(review): 4 is presumably a thread count -- confirm
+            caretta_class = app_helper.compress_object(msa_class, suite)  # compressed copies go to the hidden divs
+            sequence_alignment_data = app_helper.compress_object(
+                sequence_alignment, suite
+            )
+            feature_alignment_data = app_helper.compress_object(features, suite)
+
+            sequence_alignment_component = dash_bio.AlignmentChart(
+                id="sequence-alignment-graph",
+                data=fasta,
+                showconsensus=False,
+                showconservation=False,
+                overview=None,
+                height=300,
+                colorscale="hydrophobicity",
+            )
+            structure_alignment_component = dcc.Graph(
+                figure=app_helper.scatter3D(
+                    {s.name: s.coordinates for s in msa_class.structures}
+                ),
+                id="scatter-plot",
+            )
+            feature_selection_dropdown = [{"label": x, "value": x} for x in features]
+            loading_indicator_output = ""  # empty string for the loading indicator output
+            return (  # order must match the Output list of the decorator
+                caretta_class,
+                sequence_alignment_data,
+                feature_alignment_data,
+                sequence_alignment_component,
+                structure_alignment_component,
+                feature_selection_dropdown,
+                loading_indicator_output,
+            )
+        else:  # not clicked / nothing selected: empty placeholders for every output
+            return (
+                app_helper.empty_object(suite),
+                app_helper.empty_object(suite),
+                app_helper.empty_object(suite),
+                "",
+                "",
+                [{"label": "no alignment present", "value": "no alignment"}],
+                "",
+            )
+
+    # Function that displays mean +/- stdev and heatmap of user-selected feature
+    @app.callback(
+        # Outputs
+        [
+            # Feature line panel
+            dash.dependencies.Output("feature-line", "children"),
+            # Feature heatmap panel
+            dash.dependencies.Output("feature-heatmap", "children"),
+        ],
+        # Inputs
+        [
+            # Whether the display feature button has been clicked
+            dash.dependencies.Input("display-feature-button", "n_clicks")
+        ],
+        # States
+        [
+            # Dropdown of feature names
+            dash.dependencies.State("feature-selection-dropdown", "value"),
+            # Aligned features dict
+            dash.dependencies.State("feature-alignment-data", "children"),
+        ],
+    )
+    def display_feature(
+        display_feature_button_clicked,
+        feature_selection_dropdown_value,
+        feature_alignment_data,
+    ):
+        ready = (
+            display_feature_button_clicked
+            and feature_selection_dropdown_value
+            and feature_alignment_data
+        )
+        if not ready:
+            # Nothing to plot yet: hand back hidden placeholder graphs under the same ids
+            return (
+                dcc.Graph(
+                    figure=app_helper.empty_dict(),
+                    id="feature-line-graph",
+                    style={"display": "none"},
+                ),
+                dcc.Graph(
+                    figure=app_helper.empty_dict(),
+                    id="feature-heatmap-graph",
+                    style={"display": "none"},
+                ),
+            )
+        # Pull the chosen feature's values out of the stored (compressed) feature dict
+        values = app_helper.decompress_object(feature_alignment_data, suite)[
+            feature_selection_dropdown_value
+        ]
+        # Line panel: figure built by app_helper.line
+        line_graph = dcc.Graph(
+            figure=app_helper.line(values), id="feature-line-graph"
+        )
+        # Heatmap panel: figure built by app_helper.heatmap
+        heatmap_graph = dcc.Graph(
+            figure=app_helper.heatmap(values), id="feature-heatmap-graph"
+        )
+        return line_graph, heatmap_graph
+
+    # Function that updates the selected residues in the structure alignment panel and the feature alignment panel
+    @app.callback(
+        # Outputs
+        [
+            # hidden Divs
+            # Residue position selected in structure alignment panel
+            dash.dependencies.Output(
+                "structure-alignment-selected-residue", "children"
+            ),
+            # Residue position selected in feature alignment panel
+            dash.dependencies.Output("feature-alignment-selected-residue", "children"),
+            # Feature line component
+            dash.dependencies.Output("feature-line-graph", "figure"),
+            # 3D scatter component
+            dash.dependencies.Output("scatter-plot", "figure"),
+        ],
+        # Inputs
+        [
+            # Clicked indices in 3D scatter component
+            dash.dependencies.Input("scatter-plot", "clickData"),
+            # Clicked indices in feature line component
+            dash.dependencies.Input("feature-line-graph", "clickData"),
+        ],
+        # States
+        [
+            # Feature line component
+            dash.dependencies.State("feature-line-graph", "figure"),
+            # 3D scatter component
+            dash.dependencies.State("scatter-plot", "figure"),
+            # Residue position selected in structure alignment panel
+            dash.dependencies.State("structure-alignment-selected-residue", "children"),
+            # Residue position selected in feature alignment panel
+            dash.dependencies.State("feature-alignment-selected-residue", "children"),
+            # Aligned sequences dict
+            dash.dependencies.State("sequence-alignment-data", "children"),
+        ],
+    )
+    def update_interactive_panels(
+        scatter_plot_clickdata,
+        feature_line_clickdata,
+        feature_line_graph,
+        scatter_plot,
+        structure_alignment_selected_residue,
+        feature_alignment_selected_residue,
+        sequence_alignment_data,
+    ):
+        if feature_line_graph and scatter_plot:  # both figures must already exist
+            changed = None  # NOTE(review): the != checks below compare a freshly compressed tuple with the
+            clickdata = None  # stored token; if compress_object is randomised (e.g. Fernet) they never match -- confirm
+            if (  # the feature line panel is checked first and wins when both conditions hold
+                feature_line_clickdata
+                and app_helper.compress_object(
+                    (
+                        feature_line_clickdata["points"][0]["pointNumber"],
+                        feature_line_clickdata["points"][0]["curveNumber"],
+                    ),
+                    suite,
+                )
+                != feature_alignment_selected_residue
+            ):
+                clickdata = feature_line_clickdata
+                changed = "feature-panel"
+            elif (
+                scatter_plot_clickdata
+                and app_helper.compress_object(
+                    (
+                        scatter_plot_clickdata["points"][0]["pointNumber"],
+                        scatter_plot_clickdata["points"][0]["curveNumber"],
+                    ),
+                    suite,
+                )
+                != structure_alignment_selected_residue
+            ):
+                clickdata = scatter_plot_clickdata
+                changed = "structure-panel"
+            if changed is not None and clickdata is not None:
+                # Save new clicked index
+                aln_index = clickdata["points"][0]["pointNumber"]  # for a structure-panel click this is converted further down
+                protein_index = clickdata["points"][0]["curveNumber"]
+                if changed == "feature-panel":
+                    feature_alignment_selected_residue = app_helper.compress_object(
+                        (aln_index, protein_index), suite
+                    )
+                elif changed == "structure-panel":
+                    structure_alignment_selected_residue = app_helper.compress_object(
+                        (aln_index, protein_index), suite
+                    )
+
+                sequence_alignment = app_helper.decompress_object(
+                    sequence_alignment_data, suite
+                )
+                number_of_structures = len(sequence_alignment)
+
+                try:  # vertical extent of the selection line = range of the first trace's y values
+                    maxim, minim = (
+                        np.max(feature_line_graph["data"][0]["y"]),
+                        np.min(feature_line_graph["data"][0]["y"]),
+                    )
+                except KeyError:  # expected keys missing from the figure: leave both figures untouched
+                    return (
+                        structure_alignment_selected_residue,
+                        feature_alignment_selected_residue,
+                        feature_line_graph,
+                        scatter_plot,
+                    )
+                if len(feature_line_graph["data"]) > 2:  # more than two traces: drop the last one
+                    feature_line_graph["data"] = feature_line_graph["data"][:-1]
+                if len(scatter_plot["data"]) > number_of_structures:  # more traces than aligned sequences: drop the last one
+                    scatter_plot["data"] = scatter_plot["data"][:-1]
+
+                if changed == "feature-panel":
+                    aln_positions = app_helper.aln_index_to_protein(  # indexed by trace name below; entries may be None
+                        aln_index, sequence_alignment
+                    )
+                    feature_line_graph["data"] += [  # vertical line at the clicked x position
+                        dict(
+                            y=[minim, maxim],
+                            x=[aln_index, aln_index],
+                            type="scatter",
+                            mode="lines",
+                            name="selected residue",
+                        )
+                    ]
+
+                    to_add = []  # (x, y, z) of the matching point in every trace that has one
+                    for i in range(len(scatter_plot["data"])):
+                        p = aln_positions[scatter_plot["data"][i]["name"]]
+                        if p is not None:
+                            x, y, z = (
+                                scatter_plot["data"][i]["x"][p],
+                                scatter_plot["data"][i]["y"][p],
+                                scatter_plot["data"][i]["z"][p],
+                            )
+                            to_add.append((x, y, z))
+                        else:  # no position for this trace: nothing to mark
+                            continue
+                    scatter_plot["data"] += [  # one extra marker trace holding all collected points
+                        dict(
+                            x=[x[0] for x in to_add],
+                            y=[y[1] for y in to_add],
+                            z=[z[2] for z in to_add],
+                            type="scatter3d",
+                            mode="markers",
+                            name="selected residues",
+                        )
+                    ]
+                elif changed == "structure-panel":
+                    aligned_sequence = list(sequence_alignment.values())[protein_index]  # sequence picked by the clicked curve number
+                    aln_index = app_helper.protein_to_aln_index(  # clicked point number -> x position for the line panel
+                        aln_index, aligned_sequence
+                    )
+                    x, y, z = (
+                        clickdata["points"][0]["x"],
+                        clickdata["points"][0]["y"],
+                        clickdata["points"][0]["z"],
+                    )
+                    feature_line_graph["data"] += [
+                        dict(
+                            y=[minim, maxim],
+                            x=[aln_index, aln_index],
+                            type="scatter",
+                            mode="lines",
+                            name="selected_residue",
+                        )
+                    ]
+                    scatter_plot["data"] += [  # single marker at the clicked coordinates
+                        dict(
+                            y=[y],
+                            x=[x],
+                            z=[z],
+                            type="scatter3d",
+                            mode="markers",
+                            name="selected residue",
+                        )
+                    ]
+        return (  # order must match the Output list of the decorator
+            structure_alignment_selected_residue,
+            feature_alignment_selected_residue,
+            feature_line_graph,
+            scatter_plot,
+        )
+
+    # Function to export FASTA alignment file
+    @app.callback(
+        # Outputs
+        # Link to download file
+        dash.dependencies.Output("fasta-download-link", "children"),
+        # Inputs
+        [
+            # Whether the FASTA download button has been clicked
+            dash.dependencies.Input("fasta-download-button", "n_clicks")
+        ],
+        # States
+        [
+            # Aligned sequences dict
+            dash.dependencies.State("sequence-alignment-data", "children"),
+            # MSA class
+            dash.dependencies.State("caretta-class", "children"),
+        ],
+    )
+    def download_alignment(
+        fasta_download_button_clicked, sequence_alignment_data, caretta_class
+    ):
+        # No link until the button was clicked and both hidden states are filled
+        if not (
+            fasta_download_button_clicked and sequence_alignment_data and caretta_class
+        ):
+            return ""
+        if not app_helper.decompress_object(sequence_alignment_data, suite):
+            return ""
+        # Ask the stored alignment object to write only its FASTA output
+        msa_class = app_helper.decompress_object(caretta_class, suite)
+        msa_class.write_files(
+            write_pdb=False, write_fasta=True, write_class=False, write_features=False
+        )
+        # NOTE(review): this links to static/result.fasta, whereas download_pdb links
+        # into msa_class.output_folder -- confirm where write_files puts the FASTA
+        fasta_path = Path("static") / "result.fasta"
+        return app_layout.get_download_string(str(fasta_path))
+
+    # Function to export superposed PDB files
+    @app.callback(
+        # Output
+        # Link to download files
+        dash.dependencies.Output("pdb-download-link", "children"),
+        # Inputs
+        [
+            # Whether the PDB download button has been clicked
+            dash.dependencies.Input("pdb-download-button", "n_clicks")
+        ],
+        # States
+        [
+            # Aligned sequences dict
+            dash.dependencies.State("sequence-alignment-data", "children"),
+            # MSA class
+            dash.dependencies.State("caretta-class", "children"),
+        ],
+    )
+    def download_pdb(
+        pdb_download_button_clicked, sequence_alignment_data, caretta_class
+    ):
+        if pdb_download_button_clicked and sequence_alignment_data and caretta_class:
+            sequence_alignment = app_helper.decompress_object(
+                sequence_alignment_data, suite
+            )
+            if not sequence_alignment:
+                return ""
+            msa_class = app_helper.decompress_object(caretta_class, suite)
+            # Write the superposed structures to disk, then bundle them for download
+            msa_class.write_files(
+                write_pdb=True,
+                write_fasta=False,
+                write_class=False,
+                write_features=False,
+            )
+            output_filename = f"{msa_class.output_folder}/superposed_pdbs.zip"
+            pdb_folder = Path(msa_class.output_folder) / "superposed_pdbs"
+            # The context manager closes the archive (which finalises it on disk)
+            # before the download link is handed back
+            with ZipFile(output_filename, mode="w") as pdb_zip_file:
+                for pdb_file in pdb_folder.glob("*.pdb"):
+                    # pdb_file.name is stem + suffix: entries carry no parent folders
+                    pdb_zip_file.write(str(pdb_file), arcname=pdb_file.name)
+            return app_layout.get_download_string(output_filename)
+        else:
+            return ""
+
+    # Function to export feature files
+    @app.callback(
+        # Outputs
+        [
+            # Link to download feature files
+            dash.dependencies.Output("feature-download-link", "children"),
+            # Div with export buttons
+            dash.dependencies.Output("feature-exporter", "children"),
+        ],
+        # Inputs
+        [
+            # Whether the export feature button has been clicked
+            dash.dependencies.Input("export-feature-button", "n_clicks"),
+            # Whether the export all features button has been clicked
+            dash.dependencies.Input("export-all-features-button", "n_clicks"),
+        ],
+        # States
+        [
+            # Selected value in the dropdown of feature names
+            dash.dependencies.State("feature-selection-dropdown", "value"),
+            # Aligned feature dict
+            dash.dependencies.State("feature-alignment-data", "children"),
+            # MSA class
+            dash.dependencies.State("caretta-class", "children"),
+        ],
+    )
+    def download_features(
+        export_feature_button_clicked,
+        export_all_features_button_clicked,
+        feature_selection_dropdown_value,
+        feature_alignment_data,
+        caretta_class,
+    ):
+        output_string = ""  # stays empty unless exactly one export button was clicked
+        if feature_alignment_data and caretta_class:
+            feature_alignment_dict = app_helper.decompress_object(
+                feature_alignment_data, suite
+            )
+            msa_class = app_helper.decompress_object(caretta_class, suite)
+            protein_names = [s.name for s in msa_class.structures]
+            if (
+                export_feature_button_clicked and feature_selection_dropdown_value
+            ) and not export_all_features_button_clicked:
+                output_filename = f"{msa_class.output_folder}/{'-'.join(feature_selection_dropdown_value.split())}.csv"
+                app_helper.write_feature_as_tsv(
+                    feature_alignment_dict[feature_selection_dropdown_value],
+                    protein_names,
+                    output_filename,
+                )
+                output_string = app_layout.get_download_string(output_filename)
+            elif (
+                export_all_features_button_clicked and not export_feature_button_clicked
+            ):
+                output_filename = f"{msa_class.output_folder}/features.zip"
+                # The context manager closes the archive (which finalises it on disk)
+                # before the download link is handed back
+                with ZipFile(output_filename, mode="w") as features_zip_file:
+                    for feature in feature_alignment_dict:
+                        # Whitespace in the feature name becomes dashes in the file name
+                        feature_name = f"{'-'.join(feature.split())}.csv"
+                        feature_file = f"{msa_class.output_folder}/{feature_name}"
+                        app_helper.write_feature_as_tsv(
+                            feature_alignment_dict[feature], protein_names, feature_file
+                        )
+                        features_zip_file.write(feature_file, arcname=feature_name)
+                output_string = app_layout.get_download_string(output_filename)
+        return output_string, app_layout.get_export_feature_buttons()
+
+    @app.server.route("/caretta/static/<path:path>")
+    def download(path):
+        """Serve a file from the "static" folder of the current working directory; abort with 404 on FileNotFoundError."""
+        try:
+            return send_from_directory(str(Path.cwd() / "static"), path)
+        except FileNotFoundError:
+            abort(404)
diff --git a/caretta/pfam.py b/caretta/app/app_helper.py
similarity index 52%
rename from caretta/pfam.py
rename to caretta/app/app_helper.py
index 708851d2a15fda4983df3140e6a26fbfcddf4500..a74bad0060baed2dd3784e955f6e0d8c0b3f2efc 100644
--- a/caretta/pfam.py
+++ b/caretta/app/app_helper.py
@@ -1,10 +1,148 @@
+import base64
+import datetime
+import pickle
 from dataclasses import dataclass
 from pathlib import Path
-
+import numpy as np
 import prody as pd
 import requests as rq
 
-from caretta import msa_numba
+import typing
+
+from caretta import multiple_alignment
+
+
+def heatmap(data):
+    """Wrap `data` in a plotly heatmap figure dict (colour scale hidden, 25px margins)."""
+    trace = dict(z=data, type="heatmap", showscale=False)
+    margin = dict(l=25, r=25, t=25, b=25)
+    return dict(data=[trace], layout=dict(margin=margin))
+
+
+def empty_dict():
+    """Placeholder figure dict: a 2x2 all-zero heatmap with 25px margins."""
+    blank = dict(z=np.zeros((2, 2)), type="heatmap", showscale=False)
+    return dict(data=[blank], layout=dict(margin=dict(l=25, r=25, t=25, b=25)))
+
+
+def empty_object(suite):  # placeholder payload: an empty numpy array passed through compress_object
+    return compress_object(np.zeros(0), suite)
+
+
+def get_estimated_time(msa_class: multiple_alignment.StructureMultiple):
+    """Return a timedelta string for 9.14726052e-06 * longest_length^2 * count^2 seconds."""
+    count = len(msa_class.structures)
+    longest = max(s.length for s in msa_class.structures)
+    return str(datetime.timedelta(seconds=int(longest ** 2 * 9.14726052e-06 * count ** 2)))
+
+
+def line(data):
+    """Plotly figure dict: per-column nan-mean of 2-D `data` plus a standard-error band."""
+    y = np.nanmean(data, axis=0)
+    # SE = std / sqrt(n); n is the non-NaN count per column, not the number of columns.
+    y_se = np.nanstd(data, axis=0) / np.sqrt(np.count_nonzero(~np.isnan(data), axis=0))
+
+    traces = [
+        dict(
+            y=list(y + y_se) + list(y - y_se)[::-1],  # upper edge, then lower edge reversed
+            x=list(range(data.shape[1])) + list(range(data.shape[1]))[::-1],
+            fillcolor="lightblue",
+            fill="toself",
+            type="scatter",
+            mode="lines",
+            name="Standard error",
+            line=dict(color="lightblue"),
+        ),
+        dict(
+            y=y,
+            x=np.arange(data.shape[1]),
+            type="scatter",
+            mode="lines",
+            name="Mean",
+            line=dict(color="blue"),
+        ),
+    ]
+    return dict(
+        data=traces,
+        layout=dict(legend=dict(x=0.5, y=1.2), margin=dict(l=25, r=25, t=25, b=25)),
+    )
+
+
+def scatter3D(coordinates_dict):
+    """Build a plotly 3-D line figure dict with one trace per entry of `coordinates_dict`.
+
+    Each value is indexed as `v[:, 0]`, `v[:, 1]`, `v[:, 2]` for the x, y and z series;
+    the key, converted with `str`, becomes the trace name. All three axes are hidden.
+    """
+    axis_off = dict(visible=False, showgrid=False, showline=False)
+    traces = [
+        dict(
+            x=coords[:, 0],
+            y=coords[:, 1],
+            z=coords[:, 2],
+            mode="lines",
+            type="scatter3d",
+            text=None,
+            name=str(name),
+            line=dict(width=3, opacity=0.8),
+        )
+        for name, coords in coordinates_dict.items()
+    ]
+    layout = dict(
+        margin=dict(l=20, r=20, t=20, b=20),
+        clickmode="event+select",
+        scene=dict(xaxis=dict(axis_off), yaxis=dict(axis_off), zaxis=dict(axis_off)),
+    )
+    return dict(data=traces, layout=layout)
+
+
+def write_feature_as_tsv(
+    feature_data: np.ndarray, keys: typing.List[str], file_name: typing.Union[Path, str]
+):
+    """Write one tab-separated line per row of `feature_data`: `keys[i]`, then the row values."""
+    with open(file_name, "w") as f:
+        for i in range(feature_data.shape[0]):
+            fields = [keys[i]] + [str(x) for x in list(feature_data[i])]
+            f.write("\t".join(fields) + "\n")
+
+
+def compress_object(raw_object, suite):
+    """Pickle (protocol 4), encrypt with `suite`, base64-encode; returns the UTF-8 text."""
+    token = suite.encrypt(pickle.dumps(raw_object, protocol=4))
+    return base64.b64encode(token).decode("utf-8")
+
+
+def decompress_object(compressed_object, suite):  # inverse of compress_object: base64-decode, decrypt with `suite`, unpickle
+    return pickle.loads(suite.decrypt(base64.b64decode(compressed_object)))  # NOTE(review): pickle.loads is unsafe on untrusted bytes; this relies on suite.decrypt rejecting tampered input — confirm
+
+
+def protein_to_aln_index(protein_index, aln_seq):
+    """Return the alignment column of residue `protein_index` (0-based); None if not found."""
+    n = 0  # residues (non-gap characters) seen so far
+    for i in range(len(aln_seq)):
+        if aln_seq[i] == "-":
+            continue  # a gap column can never be the answer (the old check order returned it)
+        if protein_index == n:
+            return i
+        n += 1
+
+
+def aln_index_to_protein(alignment_index, alignment):
+    """Map each sequence name to its residue index at `alignment_index`, or None at a gap."""
+    return {
+        name: None
+        if seq[alignment_index] == "-"
+        else alignment_index - seq[:alignment_index].count("-")
+        for name, seq in alignment.items()
+    }
+
+
+def to_fasta_str(alignment):
+    """Render `alignment` (name -> sequence) as FASTA text: a '>name' line, then the sequence."""
+    lines = []
+    for name, sequence in alignment.items():
+        lines += [f">{name}", sequence]
+    return "\n".join(lines)
 
 
 @dataclass
@@ -51,8 +189,10 @@ class PdbEntry:
         )
 
     @classmethod
-    def from_user_input(cls, pdb_path, chain_id="A"):
-        return cls(pdb_path, chain_id, -1, -1, "none", "none", "none", 0.0, pdb_path)
+    def from_user_input(cls, pdb_path, chain_id="none"):  # alternate constructor for a user-supplied PDB file path
+        return cls(
+            Path(pdb_path).stem, chain_id, -1, -1, "none", "none", "none", 0.0, pdb_path  # first field is the file stem; -1 / "none" / 0.0 are presumably placeholders for the remaining fields — confirm against the dataclass
+        )
 
     def get_pdb(self, from_atm_file=None):
         if from_atm_file is not None:
@@ -98,14 +238,6 @@ class PdbEntry:
         return f"{self.PDB_ID}_{self.CHAIN_ID}_{self.PdbResNumStart}"
 
 
-def get_pdbs_from_folder(path):
-    file_names = Path(path).glob("*.pdb")
-    res = []
-    for f in file_names:
-        res.append(PdbEntry.from_user_input(str(f)))
-    return res
-
-
 class PfamToPDB:
     def __init__(
         self,
@@ -125,6 +257,7 @@ class PfamToPDB:
         data_lines = data_lines[1:]
         self.pfam_to_pdb_ids = dict()
         self._initiate_pfam_to_pdbids(data_lines, limit=limit)
+        self.pdb_entries = None
         self.msa = None
         self.caretta_alignment = None
 
@@ -152,66 +285,10 @@ class PfamToPDB:
                     n += 1
             self.pfam_to_pdb_ids = new
 
-    def multiple_structure_alignment_from_pfam(
-        self, pdb_entries, gap_open_penalty=0.1, gap_extend_penalty=0.001
-    ):
-        self.msa = PfamStructures.from_pdb_files([p.get_pdb()[1] for p in pdb_entries])
-        self.caretta_alignment = self.msa.align(
-            gap_open_penalty=gap_open_penalty, gap_extend_penalty=gap_extend_penalty
-        )
-        return (
-            self.caretta_alignment,
-            {s.name: pd.parsePDB(s.pdb_file) for s in self.msa.structures},
-            {s.name: s.features for s in self.msa.structures},
-        )
-
     def get_entries_for_pfam(
-        self, pfam_id, limit_by_score=1.0, limit_by_protein_number=50, gross_limit=1000
+        self, pfam_id, limit_by_score=1.0, limit_by_protein_number=50
     ):
         pdb_entries = list(
             filter(lambda x: (x.eValue < limit_by_score), self.pfam_to_pdb_ids[pfam_id])
         )[:limit_by_protein_number]
         return pdb_entries
-
-    def alignment_from_folder(self):
-        pass
-
-    def to_fasta_str(self, alignment):
-        res = []
-        for k, v in alignment.items():
-            res.append(f">{k}")
-            res.append(v)
-        return "\n".join(res)
-
-
-class PfamStructures(msa_numba.StructureMultiple):
-    def __init__(
-        self,
-        pdb_entries,
-        dssp_dir="caretta_tmp",
-        num_threads=20,
-        extract_all_features=True,
-        consensus_weight=1.0,
-        write_fasta=True,
-        output_fasta_filename=Path("./result.fasta"),
-        write_pdb=True,
-        output_pdb_folder=Path("./result_pdb/"),
-        write_features=True,
-        output_feature_filename=Path("./result_features.pkl"),
-        write_class=True,
-        output_class_filename=Path("./result_class.pkl"),
-        overwrite_dssp=False,
-    ):
-        self.pdb_entries = pdb_entries
-        super(PfamStructures, self).from_pdb_files(
-            [p.get_pdb()[1] for p in self.pdb_entries],
-            dssp_dir,
-            num_threads,
-            extract_all_features,
-            consensus_weight,
-            output_fasta_filename,
-            output_pdb_folder,
-            output_feature_filename,
-            output_class_filename,
-            overwrite_dssp,
-        )
diff --git a/caretta/app/app_layout.py b/caretta/app/app_layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..262e710d0e557ddd8a10aad9bde722504b245152
--- /dev/null
+++ b/caretta/app/app_layout.py
@@ -0,0 +1,366 @@
+import dash_core_components as dcc
+import dash_html_components as html
+import numpy as np
+from caretta.app import app_helper
+from pathlib import Path
+
+box_style = {  # inline CSS shared by the panel containers below (passed as style=box_style)
+    "box-shadow": "1px 3px 20px -4px rgba(0,0,0,0.75)",
+    "border-radius": "5px",
+    "background-color": "#f9f7f7",
+}
+
+
+def get_layout(  # top-level page layout: stacks every panel defined in this module, in display order
+    introduction_text,
+    input_text,
+    placeholder_text,
+    selection_text,
+    suite,
+    pfam_class=None,
+):
+    return html.Div(
+        children=[
+            get_introduction_panel(introduction_text),
+            html.Br(),
+            get_input_panel_layout(
+                input_text, placeholder_text, selection_text, pfam_class
+            ),
+            html.Br(),
+            get_hidden_variables_layout(suite),  # invisible <p> elements (display: none)
+            get_sequence_alignment_layout(),
+            html.Br(),
+            get_structure_alignment_layout(),
+            html.Br(),
+            get_feature_alignment_panel(),
+            html.Br(),
+        ]
+    )
+
+
+def get_introduction_panel(introduction_text: str):  # centred title and subtitle, then the intro text rendered as Markdown
+    return html.Div(
+        html.Div(
+            [
+                html.H1("Caretta", style={"text-align": "center"}),
+                html.H3(
+                    "a multiple protein structure alignment and feature extraction suite",
+                    style={"text-align": "center"},
+                ),
+                html.P(dcc.Markdown(introduction_text), style={"text-align": "left"}),
+            ],
+            className="row",
+        ),
+        className="container",
+    )
+
+
+def get_hidden_variables_layout(suite):  # invisible <p> elements (display: none) holding string payloads under fixed ids
+    return html.Div(
+        [
+            html.P(children="", id="proteins-list", style={"display": "none"}),
+            html.P(children="", id="feature-alignment-data", style={"display": "none"}),
+            html.P(
+                children=app_helper.compress_object(0, suite),  # initial value 0, serialised with compress_object
+                id="structure-alignment-selected-residue",
+                style={"display": "none"},
+            ),
+            html.P(
+                children=app_helper.compress_object(0, suite),  # initial value 0, serialised with compress_object
+                id="feature-alignment-selected-residue",
+                style={"display": "none"},
+            ),
+            html.P(
+                children="", id="sequence-alignment-data", style={"display": "none"}
+            ),
+            html.P(children="", id="caretta-class", style={"display": "none"}),
+            html.P(
+                children=app_helper.compress_object(
+                    np.random.randint(0, 1000000000), suite  # NOTE(review): drawn when this function runs; distinct per page load only if the layout is rebuilt each time — confirm
+                ),
+                id="unique-id",
+                style={"display": "none"},
+            ),
+        ]
+    )
+
+
+def get_input_panel_layout(
+    input_text: str,
+    placeholder_text: str,
+    selection_text: str,
+    pfam_class: app_helper.PfamToPDB = None,
+):
+    """Build the "Choose Structures" panel: input widget, buttons and gap-penalty dropdowns."""
+    # The input widget is a Pfam ID dropdown when `pfam_class` is given, else a text area.
+    if pfam_class is not None:
+        user_input = dcc.Dropdown(
+            placeholder="Choose Pfam ID",
+            options=[{"label": x, "value": x} for x in pfam_class.pfam_to_pdb_ids],
+            id="user-input",
+        )
+    else:
+        # A bare component, not a 1-tuple (stray trailing comma), to match the branch above.
+        user_input = dcc.Textarea(
+            placeholder=placeholder_text, value="", id="user-input", required=True,
+        )
+    return html.Div(
+        [
+            html.Div(
+                [
+                    html.Br(),
+                    html.H3(
+                        "Choose Structures",
+                        className="row",
+                        style={"text-align": "center"},
+                    ),
+                    html.P(input_text, className="row",),
+                    html.Div(
+                        [
+                            html.Div(user_input, className="four columns",),
+                            html.P(
+                                dcc.Markdown(selection_text), className="four columns"
+                            ),
+                            html.Button(
+                                "Load Structures",
+                                className="four columns",
+                                id="load-structures-button",
+                            ),
+                        ],
+                        className="row",
+                    ),
+                    html.Div(
+                        [
+                            html.Div(
+                                dcc.Dropdown(
+                                    placeholder="Gap open penalty (1.0)",
+                                    options=[
+                                        {"label": np.round(x, decimals=2), "value": x,}
+                                        for x in np.arange(0, 5, 0.1)
+                                    ],
+                                    id="gap-open-dropdown",
+                                ),
+                                className="four columns",
+                            ),
+                            html.Div(
+                                dcc.Dropdown(
+                                    multi=True,
+                                    id="proteins-selection-dropdown",
+                                    placeholder="Select PDB IDs to align",
+                                ),
+                                className="four columns",
+                            ),
+                            html.Div(
+                                dcc.Dropdown(
+                                    placeholder="Gap extend penalty (0.01)",
+                                    options=[
+                                        {"label": np.round(x, decimals=3), "value": x,}
+                                        for x in np.arange(0, 1, 0.002)
+                                    ],
+                                    id="gap-extend-dropdown",
+                                ),
+                                className="four columns",
+                            ),
+                        ],
+                        className="row",
+                    ),
+                    html.Br(),
+                    html.Div(
+                        html.Button(
+                            "Align Structures",
+                            className="twelve columns",
+                            id="align-button",
+                        ),
+                        className="row",
+                    ),
+                    dcc.Loading(
+                        id="loading-indicator",
+                        children=[
+                            html.Div(
+                                id="loading-indicator-output",
+                                style={"text-align": "center"},
+                            )
+                        ],
+                        type="default",
+                    ),
+                    html.P(
+                        id="time-estimate",
+                        style={"text-align": "center"},
+                        children="",
+                        className="row",
+                    ),
+                ],
+                className="container",
+            ),
+            html.Br(),
+        ],
+        className="container",
+        style=box_style,
+    )
+
+
+def get_sequence_alignment_layout():  # boxed panel: heading, download button with link slot, and the alignment display
+    return html.Div(
+        children=[
+            html.Br(),
+            html.H3(
+                "Sequence alignment", className="row", style={"text-align": "center"},
+            ),
+            html.Div(html.P("", className="row"), className="container"),
+            html.Div(
+                [
+                    html.Button(
+                        "Download sequence alignment",
+                        className="twelve columns",
+                        id="fasta-download-button",
+                    ),
+                    html.Div(children="", className="row", id="fasta-download-link"),  # starts empty
+                ],
+                className="container",
+            ),
+            html.Div(
+                html.P(id="sequence-alignment", className="twelve columns"),  # created without children
+                className="row",
+            ),
+        ],
+        className="container",
+        style=box_style,
+    )
+
+
+def get_structure_alignment_layout():  # boxed panel: heading, hint text, PDB download button with link slot, and a graph
+    return html.Div(
+        [
+            html.Br(),
+            html.H3(
+                "Structural alignment", className="row", style={"text-align": "center"},
+            ),
+            html.Div(
+                html.P(
+                    "Click on a residue to see its position on the feature alignment in the next section.",
+                    className="row",
+                ),
+                className="container",
+            ),
+            html.Div(
+                [
+                    html.Button(
+                        "Download superposed PDBs",
+                        id="pdb-download-button",
+                        className="twelve columns",
+                    ),
+                    html.Div(children="", className="row", id="pdb-download-link"),  # starts empty
+                ],
+                className="container",
+            ),
+            html.Div(
+                children=dcc.Graph(figure=app_helper.empty_dict(), id="scatter-plot",),  # initial figure comes from app_helper.empty_dict()
+                className="row",
+                id="structure-alignment",
+            ),
+            html.Br(),
+        ],
+        className="container",
+        style=box_style,
+    )
+
+
+def get_feature_alignment_panel():  # boxed panel: feature chooser, export buttons, then a line graph and a heatmap graph
+    return html.Div(
+        [
+            html.Br(),
+            html.Div(
+                [
+                    html.Div(
+                        [
+                            html.H3(
+                                "Feature alignment",
+                                className="row",
+                                style={"text-align": "center"},
+                            ),
+                            html.P(
+                                "Click on a position in the feature alignment to see the corresponding residues in the previous section.",
+                                className="row",
+                            ),
+                            dcc.Dropdown(
+                                placeholder="Choose a feature",
+                                id="feature-selection-dropdown",  # created without options
+                                className="six columns",
+                            ),
+                            html.Button(
+                                "Display feature alignment",
+                                id="display-feature-button",
+                                className="six columns",
+                            ),
+                        ],
+                        className="row",
+                    ),
+                    html.Br(),
+                    html.Div(
+                        [
+                            html.Div(
+                                get_export_feature_buttons(), id="feature-exporter",
+                            ),
+                            html.Div(html.P(""), id="feature-download-link"),  # starts as an empty paragraph
+                            html.Br(),
+                        ]
+                    ),
+                ],
+                className="container",
+            ),
+            html.Div(
+                html.Div(
+                    dcc.Graph(figure=app_helper.empty_dict(), id="feature-line-graph",),  # initial figure comes from app_helper.empty_dict()
+                    id="feature-line",
+                    className="twelve columns",
+                ),
+                className="row",
+            ),
+            html.Div(
+                html.Div(
+                    dcc.Graph(
+                        figure=app_helper.empty_dict(), id="feature-heatmap-graph",
+                    ),
+                    id="feature-heatmap",
+                    className="twelve columns",
+                ),
+                className="row",
+            ),
+        ],
+        className="container",
+        style=box_style,
+    )
+
+
+def get_export_feature_buttons():  # returns a list of children: two full-width export buttons, each followed by a line break
+    return [
+        html.Div(
+            html.Button(
+                "Download feature as tab-separated file",
+                id="export-feature-button",
+                className="twelve columns",
+            ),
+            className="row",
+        ),
+        html.Br(),
+        html.Div(
+            html.Button(
+                "Download all features",
+                id="export-all-features-button",
+                className="twelve columns",
+            ),
+            className="row",
+        ),
+        html.Br(),
+    ]
+
+
+def get_download_string(filename):
+    """Return a container Div holding a download link that points at `/caretta/<filename>`."""
+    # The link text shows only the final path component (stem plus suffix).
+    link = html.A(
+        f"Download {Path(filename).stem}{Path(filename).suffix} here",
+        href=f"/caretta/{filename}",
+        className="twelve columns",
+    )
+    return html.Div(link, className="container")
diff --git a/caretta/feature_extraction.py b/caretta/feature_extraction.py
index 92d26ca897ef5ea38e3f5473774c79a5af54b1bb..580651bb0c45cd4e56f62213891b4ef0a08cdae1 100644
--- a/caretta/feature_extraction.py
+++ b/caretta/feature_extraction.py
@@ -5,7 +5,7 @@ import numpy as np
 import prody as pd
 import Bio.PDB
 from Bio.PDB.ResidueDepth import get_surface, residue_depth, ca_depth, min_dist
-from geometricus import protein_utility
+from caretta import helper
 
 
 def read_pdb(input_file, name: str = None, chain: str = None) -> tuple:
@@ -94,8 +94,8 @@ def get_fluctuations(protein: pd.AtomGroup, n_modes: int = 50):
     dict of anm_ca, anm_cb, gnm_ca, gnm_cb
     """
     data = {}
-    beta_indices = protein_utility.get_beta_indices(protein)
-    alpha_indices = protein_utility.get_alpha_indices(protein)
+    beta_indices = helper.get_beta_indices(protein)
+    alpha_indices = helper.get_alpha_indices(protein)
     data["anm_cb"] = get_anm_fluctuations(protein[beta_indices], n_modes)
     data["gnm_cb"] = get_gnm_fluctuations(protein[beta_indices], n_modes)
     data["anm_ca"] = get_anm_fluctuations(protein[alpha_indices], n_modes)
@@ -120,7 +120,7 @@ def get_gnm_fluctuations(protein: pd.AtomGroup, n_modes: int = 50):
 
 
 def get_features_multiple(
-    pdb_files, dssp_dir, num_threads=20, only_dssp=True, force_overwrite=False
+    pdb_files, dssp_dir, num_threads=20, only_dssp=True, force_overwrite=True
 ):
     """
     Extract features for a list of pdb_files in parallel
@@ -132,7 +132,7 @@ def get_features_multiple(
         directory to store tmp dssp files
     num_threads
     only_dssp
-        extract only dssp features (use if not interested in features)
+        extract only dssp features
     force_overwrite
         force rerun DSSP
 
@@ -228,18 +228,18 @@ def get_dssp_features(protein_dssp):
         if label.startswith("dssp") and label not in dssp_ignore
     ]
     data = {}
-    alpha_indices = protein_utility.get_alpha_indices(protein_dssp)
+    alpha_indices = helper.get_alpha_indices(protein_dssp)
     indices = [protein_dssp[x].getData("dssp_resnum") for x in alpha_indices]
-    for label in dssp_labels:
+    assert len(alpha_indices) == len(indices)
+    for label in dssp_labels + ["secondary"]:
         label_to_index = {
             i - 1: protein_dssp[x].getData(label)
             for i, x in zip(indices, alpha_indices)
         }
-        data[f"{label}"] = np.array(
+        data[label] = np.array(
             [
                 label_to_index[i] if i in label_to_index else 0
                 for i in range(len(alpha_indices))
             ]
         )
-    data["secondary"] = protein_dssp.getData("secondary")[alpha_indices]
     return data
diff --git a/caretta/helper.py b/caretta/helper.py
index 562f2d7ed9b209fd9abef004eecc6119bb09f903..5f59141fa0d1a846be267b1c19b0c0600911cb32 100644
--- a/caretta/helper.py
+++ b/caretta/helper.py
@@ -1,7 +1,6 @@
 import subprocess
-import typing
-from pathlib import Path
-
+from typing import List, Union, Tuple
+from pathlib import Path, PosixPath
 import Bio.PDB
 import numba as nb
 import numpy as np
@@ -12,40 +11,15 @@ def secondary_to_array(secondary):
     return np.array(secondary, dtype="S1").view(np.int8)
 
 
-def aligned_string_to_array(aln: str) -> np.ndarray:
-    """
-    Aligned sequence to array of indices with gaps as -1
-
-    Parameters
-    ----------
-    aln
-
-    Returns
-    -------
-    indices
-    """
-    aln_array = np.zeros(len(aln), dtype=np.int64)
-    i = 0
-    for j in range(len(aln)):
-        if aln[j] != "-":
-            aln_array[j] = i
-            i += 1
-        else:
-            aln_array[j] = -1
-    return aln_array
-
-
 @nb.njit
-# @numba_cc.export('get_common_positions', '(i64[:], i64[:], i64)')
-def get_common_positions(aln_array_1, aln_array_2, gap=-1):
+def get_common_positions(aln_array_1, aln_array_2):
     """
-    Return positions where neither alignment has a gap
+    Return positions where neither alignment has a gap (-1)
 
     Parameters
     ----------
     aln_array_1
     aln_array_2
-    gap
 
     Returns
     -------
@@ -55,7 +29,7 @@ def get_common_positions(aln_array_1, aln_array_2, gap=-1):
         [
             aln_array_1[i]
             for i in range(len(aln_array_1))
-            if aln_array_1[i] != gap and aln_array_2[i] != gap
+            if aln_array_1[i] != -1 and aln_array_2[i] != -1
         ],
         dtype=np.int64,
     )
@@ -63,7 +37,7 @@ def get_common_positions(aln_array_1, aln_array_2, gap=-1):
         [
             aln_array_2[i]
             for i in range(len(aln_array_2))
-            if aln_array_1[i] != gap and aln_array_2[i] != gap
+            if aln_array_1[i] != -1 and aln_array_2[i] != -1
         ],
         dtype=np.int64,
     )
@@ -71,43 +45,23 @@ def get_common_positions(aln_array_1, aln_array_2, gap=-1):
 
 
 @nb.njit
-# @numba_cc.export('get_aligned_data', '(i64[:], f64[:], i64)')
-def get_aligned_data(aln_array: np.ndarray, data: np.ndarray, gap=-1):
+def nb_mean_axis_0(array: np.ndarray) -> np.ndarray:
     """
-    Fills coordinates according to an alignment
-    gaps (-1) in the sequence correspond to NaNs in the aligned coordinates
-
-    Parameters
-    ----------
-    aln_array
-        sequence (with gaps)
-    data
-        data to align
-    gap
-        character that represents gaps
-    Returns
-    -------
-    aligned coordinates
+    Same as np.mean(array, axis=0) but njitted
     """
-    pos = np.array([i for i in range(len(aln_array)) if aln_array[i] != gap])
-    assert len(pos) == data.shape[0]
-    aln_coords = np.zeros((len(aln_array), data.shape[1]))
-    aln_coords[:] = np.nan
-    aln_coords[pos] = data
-    return aln_coords
+    mean_array = np.zeros(array.shape[1])
+    for i in range(array.shape[1]):
+        mean_array[i] = np.mean(array[:, i])
+    return mean_array
 
 
 @nb.njit
-# @numba_cc.export('get_aligned_string_data', '(i64[:], i8[:], i64)')
-def get_aligned_string_data(aln_array, data, gap=-1):
-    pos = np.array([i for i in range(len(aln_array)) if aln_array[i] != gap])
-    assert len(pos) == data.shape[0]
-    aln_coords = np.zeros(aln_array.shape[0], dtype=data.dtype)
-    aln_coords[pos] = data
-    return aln_coords
+def normalize(numbers):  # min-max scaling: (x - min) / (max - min)
+    minv, maxv = np.min(numbers), np.max(numbers)
+    return (numbers - minv) / (maxv - minv)  # NOTE(review): the denominator is zero when all values are equal — confirm callers never pass constant input
 
 
-def get_file_parts(input_filename: typing.Union[str, Path]) -> tuple:
+def get_file_parts(input_filename: Union[str, Path]) -> Tuple[str, str, str]:
     """
     Gets directory path, name, and extension from a filename
     Parameters
@@ -125,14 +79,14 @@ def get_file_parts(input_filename: typing.Union[str, Path]) -> tuple:
     return path, name, extension
 
 
-def get_alpha_indices(protein):
+def get_alpha_indices(protein: pd.AtomGroup) -> List[int]:
     """
     Get indices of alpha carbons of pd AtomGroup object
     """
     return [a.getIndex() for a in protein.iterAtoms() if a.getName() == "CA"]
 
 
-def get_beta_indices(protein: pd.AtomGroup) -> list:
+def get_beta_indices(protein: pd.AtomGroup) -> List[int]:
     """
     Get indices of beta carbons of pd AtomGroup object
     (If beta carbon doesn't exist, alpha carbon index is returned)
@@ -157,7 +111,7 @@ def get_beta_indices(protein: pd.AtomGroup) -> list:
     return indices
 
 
-def group_indices(input_list: list) -> list:
+def group_indices(input_list: List[int]) -> List[List[int]]:
     """
     [1, 1, 1, 2, 2, 3, 3, 3, 4] -> [[0, 1, 2], [3, 4], [5, 6, 7], [8]]
     Parameters
@@ -252,7 +206,7 @@ def clustal_msa_from_sequences(
 
 
 def get_sequences_from_fasta(
-    fasta_file: typing.Union[str, Path], prune_headers: bool = True
+    fasta_file: Union[str, Path], prune_headers: bool = True
 ) -> dict:
     """
     Returns dict of accession to sequence from fasta file
@@ -338,11 +292,11 @@ def get_beta_coordinates(residue) -> np.ndarray:
     return np.array(residue["CB"].get_coord())
 
 
-def parse_pdb_files(input_pdb, extension=".pdb"):
-    if type(input_pdb) == str or type(input_pdb) == Path:
+def parse_pdb_files(input_pdb):
+    if type(input_pdb) == str or type(input_pdb) == PosixPath:
         input_pdb = Path(input_pdb)
         if input_pdb.is_dir():
-            pdb_files = list(input_pdb.glob(f"*{extension}"))
+            pdb_files = list(input_pdb.glob("*.pdb"))
         elif input_pdb.is_file():
             with open(input_pdb) as f:
                 pdb_files = f.read().strip().split("\n")
@@ -352,18 +306,15 @@ def parse_pdb_files(input_pdb, extension=".pdb"):
         pdb_files = list(input_pdb)
         if not Path(pdb_files[0]).is_file():
             pdb_files = [pd.fetchPDB(pdb_name) for pdb_name in pdb_files]
-    print(f"Found {len(pdb_files)} PDB files")
     return pdb_files
 
 
 def parse_pdb_files_and_clean(
-    input_pdb: str,
-    extension=".pdb",
-    output_pdb: typing.Union[str, Path] = "./cleaned_pdb",
-) -> typing.List[typing.Union[str, Path]]:
+    input_pdb: str, output_pdb: Union[str, Path] = "./cleaned_pdb",
+) -> List[Union[str, Path]]:
     if not Path(output_pdb).exists():
         Path(output_pdb).mkdir()
-    pdb_files = parse_pdb_files(input_pdb, extension)
+    pdb_files = parse_pdb_files(input_pdb)
     output_pdb_files = []
     for pdb_file in pdb_files:
         pdb = pd.parsePDB(pdb_file).select("protein")
diff --git a/caretta/msa_numba.py b/caretta/msa_numba.py
deleted file mode 100644
index d2f37c788760ba6c7d77ebd487f9cd8591e4d758..0000000000000000000000000000000000000000
--- a/caretta/msa_numba.py
+++ /dev/null
@@ -1,1003 +0,0 @@
-import pickle
-import typing
-from dataclasses import dataclass, field
-from pathlib import Path
-
-import numba as nb
-import numpy as np
-import prody as pd
-
-from caretta import feature_extraction
-from caretta import neighbor_joining as nj
-from caretta import psa_numba as psa
-from caretta import rmsd_calculations, helper
-
-
-@nb.njit
-def get_common_coordinates(coords_1, coords_2, aln_1, aln_2, gap=-1):
-    assert aln_1.shape == aln_2.shape
-    pos_1, pos_2 = helper.get_common_positions(aln_1, aln_2, gap)
-    return coords_1[pos_1], coords_2[pos_2]
-
-
-@nb.njit
-def make_pairwise_dtw_score_matrix(
-    coords_array,
-    secondary_array,
-    lengths_array,
-    gamma,
-    gap_open_penalty: float,
-    gap_extend_penalty: float,
-    gap_open_sec,
-    gap_extend_sec,
-):
-    pairwise_matrix = np.zeros((coords_array.shape[0], coords_array.shape[0]))
-    for i in range(pairwise_matrix.shape[0] - 1):
-        for j in range(i + 1, pairwise_matrix.shape[1]):
-            dtw_aln_1, dtw_aln_2, score = psa.get_pairwise_alignment(
-                coords_array[i, : lengths_array[i]],
-                coords_array[j, : lengths_array[j]],
-                secondary_array[i, : lengths_array[i]],
-                secondary_array[j, : lengths_array[j]],
-                gamma,
-                gap_open_sec=gap_open_sec,
-                gap_extend_sec=gap_extend_sec,
-                gap_open_penalty=gap_open_penalty,
-                gap_extend_penalty=gap_extend_penalty,
-            )
-            common_coords_1, common_coords_2 = get_common_coordinates(
-                coords_array[i, : lengths_array[i]],
-                coords_array[j, : lengths_array[j]],
-                dtw_aln_1,
-                dtw_aln_2,
-            )
-            rotation_matrix, translation_matrix = rmsd_calculations.svd_superimpose(
-                common_coords_1[:, :3], common_coords_2[:, :3]
-            )
-            common_coords_2[:, :3] = rmsd_calculations.apply_rotran(
-                common_coords_2[:, :3], rotation_matrix, translation_matrix
-            )
-            score = rmsd_calculations.get_caretta_score(
-                common_coords_1, common_coords_2, gamma, True
-            )
-            pairwise_matrix[i, j] = -score
-    pairwise_matrix += pairwise_matrix.T
-    return pairwise_matrix
-
-
-@nb.njit(parallel=True)
-def make_pairwise_rmsd_score_matrix(
-    coords_array,
-    secondary_array,
-    lengths_array,
-    gamma,
-    gap_open_penalty: float,
-    gap_extend_penalty: float,
-    gap_open_sec,
-    gap_extend_sec,
-):
-    pairwise_matrix = np.zeros((coords_array.shape[0], coords_array.shape[0]))
-    for i in nb.prange(pairwise_matrix.shape[0] - 1):
-        for j in range(i + 1, pairwise_matrix.shape[1]):
-            dtw_aln_1, dtw_aln_2, score = psa.get_pairwise_alignment(
-                coords_array[i, : lengths_array[i]],
-                coords_array[j, : lengths_array[j]],
-                secondary_array[i, : lengths_array[i]],
-                secondary_array[j, : lengths_array[j]],
-                gamma,
-                gap_open_sec=gap_open_sec,
-                gap_extend_sec=gap_extend_sec,
-                gap_open_penalty=gap_open_penalty,
-                gap_extend_penalty=gap_extend_penalty,
-            )
-            common_coords_1, common_coords_2 = get_common_coordinates(
-                coords_array[i, : lengths_array[i]],
-                coords_array[j, : lengths_array[j]],
-                dtw_aln_1,
-                dtw_aln_2,
-            )
-            rotation_matrix, translation_matrix = rmsd_calculations.svd_superimpose(
-                common_coords_1[:, :3], common_coords_2[:, :3]
-            )
-            common_coords_2[:, :3] = rmsd_calculations.apply_rotran(
-                common_coords_2[:, :3], rotation_matrix, translation_matrix
-            )
-            score = rmsd_calculations.get_rmsd(common_coords_1, common_coords_2)
-            pairwise_matrix[i, j] = score
-    pairwise_matrix += pairwise_matrix.T
-    return pairwise_matrix
-
-
-@nb.njit
-def _get_alignment_data(
-    coords_1,
-    coords_2,
-    secondary_1,
-    secondary_2,
-    gamma,
-    gap_open_sec,
-    gap_extend_sec,
-    gap_open_penalty: float,
-    gap_extend_penalty: float,
-):
-    dtw_aln_1, dtw_aln_2, _ = psa.get_pairwise_alignment(
-        coords_1,
-        coords_2,
-        secondary_1,
-        secondary_2,
-        gamma,
-        gap_open_sec=gap_open_sec,
-        gap_extend_sec=gap_extend_sec,
-        gap_open_penalty=gap_open_penalty,
-        gap_extend_penalty=gap_extend_penalty,
-    )
-    pos_1, pos_2 = helper.get_common_positions(dtw_aln_1, dtw_aln_2, -1)
-    coords_1[:, :3], coords_2[:, :3], _ = rmsd_calculations.superpose_with_pos(
-        coords_1[:, :3], coords_2[:, :3], coords_1[pos_1][:, :3], coords_2[pos_2][:, :3]
-    )
-    aln_coords_1 = helper.get_aligned_data(dtw_aln_1, coords_1, -1)
-    aln_coords_2 = helper.get_aligned_data(dtw_aln_2, coords_2, -1)
-    aln_sec_1 = helper.get_aligned_string_data(dtw_aln_1, secondary_1, -1)
-    aln_sec_2 = helper.get_aligned_string_data(dtw_aln_2, secondary_2, -1)
-    return aln_coords_1, aln_coords_2, aln_sec_1, aln_sec_2, dtw_aln_1, dtw_aln_2
-
-
-@nb.njit
-def get_mean_coords_extra(
-    aln_coords_1: np.ndarray, aln_coords_2: np.ndarray
-) -> np.ndarray:
-    """
-    Mean of two coordinate sets (of the same shape)
-
-    Parameters
-    ----------
-    aln_coords_1
-    aln_coords_2
-
-    Returns
-    -------
-    mean_coords
-    """
-    mean_coords = np.zeros(aln_coords_1.shape)
-    for i in range(aln_coords_1.shape[0]):
-        mean_coords[i, :-1] = np.array(
-            [
-                np.nanmean(np.array([aln_coords_1[i, x], aln_coords_2[i, x]]))
-                for x in range(aln_coords_1.shape[1] - 1)
-            ]
-        )
-        if not np.isnan(aln_coords_1[i, 0]):
-            mean_coords[i, -1] += aln_coords_1[i, -1]
-        if not np.isnan(aln_coords_2[i, 0]):
-            mean_coords[i, -1] += aln_coords_2[i, -1]
-    return mean_coords
-
-
-@nb.njit
-def get_mean_secondary(
-    aln_sec_1: np.ndarray, aln_sec_2: np.ndarray, gap=0
-) -> np.ndarray:
-    """
-    Mean of two coordinate sets (of the same shape)
-
-    Parameters
-    ----------
-    aln_sec_1
-    aln_sec_2
-    gap
-
-    Returns
-    -------
-    mean_sec
-    """
-    mean_sec = np.zeros(aln_sec_1.shape, dtype=aln_sec_1.dtype)
-    for i in range(aln_sec_1.shape[0]):
-        if aln_sec_1[i] == aln_sec_2[i]:
-            mean_sec[i] = aln_sec_1[i]
-        else:
-            if aln_sec_1[i] != gap:
-                mean_sec[i] = aln_sec_1[i]
-            elif aln_sec_2[i] != gap:
-                mean_sec[i] = aln_sec_2[i]
-    return mean_sec
-
-
-@dataclass(eq=False)
-class Structure:
-    name: str
-    pdb_file: typing.Union[str, Path, None]
-    sequence: typing.Union[str, None]
-    secondary: np.ndarray = field(repr=False)
-    features: typing.Union[np.ndarray, None] = field(repr=False)
-    coords: np.ndarray = field(repr=False)
-
-    @classmethod
-    def from_pdb_file(
-        cls,
-        pdb_file: typing.Union[str, Path],
-        dssp_dir="caretta_tmp",
-        extract_all_features=True,
-        force_overwrite=False,
-    ):
-        pdb_name = helper.get_file_parts(pdb_file)[1]
-        pdb = pd.parsePDB(str(pdb_file)).select("protein")
-
-        alpha_indices = helper.get_alpha_indices(pdb)
-        sequence = pdb[alpha_indices].getSequence()
-        coordinates = pdb[alpha_indices].getCoords().astype(np.float64)
-        only_dssp = not extract_all_features
-        features = feature_extraction.get_features(
-            str(pdb_file),
-            str(dssp_dir),
-            only_dssp=only_dssp,
-            force_overwrite=force_overwrite,
-        )
-        return cls(
-            pdb_name,
-            pdb_file,
-            sequence,
-            helper.secondary_to_array(features["secondary"]),
-            features,
-            coordinates,
-        )
-
-    @classmethod
-    def mutate_from_pdb_file(
-        cls,
-        pdb_file: typing.Union[str, Path],
-        dssp_dir="caretta_tmp",
-        extract_all_features=True,
-        force_overwrite=False,
-        randomness=0.1,
-    ):
-        pdb_name = helper.get_file_parts(pdb_file)[1]
-        pdb = pd.parsePDB(str(pdb_file)).select("protein")
-
-        alpha_indices = helper.get_alpha_indices(pdb)
-        sequence = pdb[alpha_indices].getSequence()
-        coordinates = pdb[alpha_indices].getCoords().astype(np.float64)
-        coordinates += np.random.normal(
-            coordinates.mean(), randomness, coordinates.shape
-        )
-        only_dssp = not extract_all_features
-        features = feature_extraction.get_features(
-            str(pdb_file),
-            str(dssp_dir),
-            only_dssp=only_dssp,
-            force_overwrite=force_overwrite,
-        )
-        return cls(
-            pdb_name,
-            pdb_file,
-            sequence,
-            helper.secondary_to_array(features["secondary"]),
-            features,
-            coordinates,
-        )
-
-    @classmethod
-    def from_pdb_id(
-        cls,
-        pdb_name: str,
-        chain: str,
-        dssp_dir="caretta_tmp",
-        extract_all_features=True,
-        force_overwrite=False,
-    ):
-        pdb = pd.parsePDB(pdb_name).select("protein").select(f"chain {chain}")
-        pdb_file = pd.writePDB(pdb_name, pdb)
-        alpha_indices = helper.get_alpha_indices(pdb)
-        sequence = pdb[alpha_indices].getSequence()
-        coordinates = pdb[alpha_indices].getCoords().astype(np.float64)
-        only_dssp = not extract_all_features
-        features = feature_extraction.get_features(
-            str(Path(pdb_file)),
-            str(dssp_dir),
-            only_dssp=only_dssp,
-            force_overwrite=force_overwrite,
-        )
-        return cls(
-            pdb_name,
-            pdb_file,
-            sequence,
-            helper.secondary_to_array(features["secondary"]),
-            features,
-            coordinates,
-        )
-
-
-@dataclass
-class OutputFiles:
-    fasta_file: Path = Path("./result.fasta")
-    pdb_folder: Path = Path("./result_pdb/")
-    cleaned_pdb_folder: Path = Path("./cleaned_pdb")
-    feature_file: Path = Path("./result_features.pkl")
-    class_file: Path = Path("./result_class.pkl")
-
-
-def parse_pdb_files(
-    input_pdb: str, output_pdb: typing.Union[str, Path] = "./cleaned_pdb"
-) -> typing.List[typing.Union[str, Path]]:
-    if not Path(output_pdb).exists():
-        Path(output_pdb).mkdir()
-    if type(input_pdb) == str:
-        input_pdb = Path(input_pdb)
-        if input_pdb.is_dir():
-            pdb_files = list(input_pdb.glob("*.pdb"))
-        elif input_pdb.is_file():
-            with open(input_pdb) as f:
-                pdb_files = f.read().strip().split("\n")
-        else:
-            pdb_files = str(input_pdb).split("\n")
-    else:
-        pdb_files = list(input_pdb)
-        if not Path(pdb_files[0]).is_file():
-            pdb_files = [pd.fetchPDB(pdb_name) for pdb_name in pdb_files]
-    output_pdb_files = []
-    for pdb_file in pdb_files:
-        pdb = pd.parsePDB(pdb_file).select("protein")
-        chains = pdb.getChids()
-        if len(chains) and len(chains[0].strip()):
-            pdb = pdb.select(f"chain {chains[0]}")
-        output_pdb_file = str(
-            Path(output_pdb) / f"{helper.get_file_parts(pdb_file)[1]}.pdb"
-        )
-        pd.writePDB(output_pdb_file, pdb)
-        output_pdb_files.append(output_pdb_file)
-    print(f"Found {len(output_pdb_files)} PDB files")
-    return output_pdb_files
-
-
-@dataclass
-class StructureMultiple:
-    """
-    Class for multiple structure alignment
-    """
-
-    structures: typing.Union[None, typing.List[Structure]] = None
-    lengths_array: typing.Union[None, np.ndarray] = None
-    max_length: int = 0
-    coords_array: typing.Union[None, np.ndarray] = None
-    secondary_array: typing.Union[None, np.ndarray] = None
-    final_structures: typing.Union[None, typing.List[Structure]] = None
-    tree: typing.Union[None, np.ndarray] = None
-    branch_lengths: typing.Union[None, np.ndarray] = None
-    alignment: typing.Union[None, dict] = None
-    output_files: OutputFiles = OutputFiles()
-
-    @classmethod
-    def from_structures(
-        cls,
-        structures: typing.List[Structure],
-        output_fasta_filename=Path("./result.fasta"),
-        output_pdb_folder=Path("./result_pdb/"),
-        output_feature_filename=Path("./result_features.pkl"),
-        output_class_filename=Path("./result_class.pkl"),
-    ):
-        lengths_array = np.array([len(s.sequence) for s in structures])
-        # print(lengths_array)
-        max_length = np.max(lengths_array)
-        coords_array = np.zeros(
-            (len(structures), max_length, structures[0].coords.shape[1])
-        )
-        secondary_array = np.zeros((len(structures), max_length))
-        for i in range(len(structures)):
-            # print(i, structures[i].secondary.shape)
-            coords_array[i, : lengths_array[i]] = structures[i].coords
-            secondary_array[i, : lengths_array[i]] = structures[i].secondary
-        if not Path(output_pdb_folder).exists():
-            Path(output_pdb_folder).mkdir()
-        cleaned_pdb_folder = (
-            Path(helper.get_file_parts(output_pdb_folder)[1]) / "cleaned_pdb"
-        )
-        if not cleaned_pdb_folder.exists():
-            cleaned_pdb_folder.mkdir()
-        output_files = OutputFiles(
-            output_fasta_filename,
-            output_pdb_folder,
-            cleaned_pdb_folder,
-            output_feature_filename,
-            output_class_filename,
-        )
-        return cls(
-            structures,
-            lengths_array,
-            max_length,
-            coords_array,
-            secondary_array,
-            output_files=output_files,
-        )
-
-    @classmethod
-    def from_pdb_files(
-        cls,
-        input_pdb,
-        dssp_dir=Path("./caretta_tmp/"),
-        num_threads=20,
-        extract_all_features=True,
-        consensus_weight=1.0,
-        output_fasta_filename=Path("./result.fasta"),
-        output_pdb_folder=Path("./result_pdb/"),
-        output_feature_filename=Path("./result_features.pkl"),
-        output_class_filename=Path("./result_class.pkl"),
-        overwrite_dssp=False,
-    ):
-        if not Path(output_pdb_folder).exists():
-            Path(output_pdb_folder).mkdir()
-        cleaned_pdb_folder = (
-            Path(helper.get_file_parts(output_pdb_folder)[0]) / "cleaned_pdb"
-        )
-        if not cleaned_pdb_folder.exists():
-            cleaned_pdb_folder.mkdir()
-        pdb_files = parse_pdb_files(input_pdb, cleaned_pdb_folder)
-        if not Path(dssp_dir).exists():
-            Path(dssp_dir).mkdir()
-        pdbs = [pd.parsePDB(filename).select("protein") for filename in pdb_files]
-        alpha_indices = [helper.get_alpha_indices(pdb) for pdb in pdbs]
-        sequences = [pdbs[i][alpha_indices[i]].getSequence() for i in range(len(pdbs))]
-        coordinates = [
-            np.hstack(
-                (
-                    pdbs[i][alpha_indices[i]].getCoords().astype(np.float64),
-                    np.zeros((len(sequences[i]), 1)) + consensus_weight,
-                )
-            )
-            for i in range(len(pdbs))
-        ]
-        only_dssp = not extract_all_features
-        print("Extracting features...")
-        features = feature_extraction.get_features_multiple(
-            pdb_files,
-            str(dssp_dir),
-            num_threads=num_threads,
-            only_dssp=only_dssp,
-            force_overwrite=overwrite_dssp,
-        )
-        structures = []
-        for i in range(len(pdbs)):
-            pdb_name = helper.get_file_parts(pdb_files[i])[1]
-            structures.append(
-                Structure(
-                    pdb_name,
-                    pdb_files[i],
-                    sequences[i],
-                    helper.secondary_to_array(features[i]["secondary"]),
-                    features[i],
-                    coordinates[i],
-                )
-            )
-        msa_class = StructureMultiple.from_structures(
-            structures,
-            output_fasta_filename,
-            output_pdb_folder,
-            output_feature_filename,
-            output_class_filename,
-        )
-        return msa_class
-
-    @classmethod
-    def mutate_multiple_from_single_pdb_file(
-        cls,
-        input_pdb,
-        filename,
-        dssp_dir=Path("./caretta_tmp/"),
-        num_threads=20,
-        extract_all_features=False,
-        consensus_weight=1.0,
-        output_fasta_filename=Path("./result.fasta"),
-        output_pdb_folder=Path("./result_pdb/"),
-        output_feature_filename=Path("./result_features.pkl"),
-        output_class_filename=Path("./result_class.pkl"),
-        overwrite_dssp=False,
-        number=10,
-        noise_level=0.1,
-    ):
-        from shutil import copyfile
-
-        if not Path(output_pdb_folder).exists():
-            Path(output_pdb_folder).mkdir()
-        cleaned_pdb_folder = (
-            Path(helper.get_file_parts(output_pdb_folder)[0]) / "cleaned_pdb"
-        )
-        if not cleaned_pdb_folder.exists():
-            cleaned_pdb_folder.mkdir()
-        for i in range(number):
-            copyfile(filename, str(cleaned_pdb_folder / f"n_{i}.pdb"))
-        if not Path(dssp_dir).exists():
-            Path(dssp_dir).mkdir()
-        pdbs = [
-            pd.parsePDB(cleaned_pdb_folder / f"n_{i}.pdb").select("protein")
-            for i in range(number)
-        ]
-        alpha_indices = [helper.get_alpha_indices(pdb) for pdb in pdbs]
-        sequences = [pdbs[i][alpha_indices[i]].getSequence() for i in range(len(pdbs))]
-        coordinates = [
-            np.hstack(
-                (
-                    pdbs[i][alpha_indices[i]].getCoords().astype(np.float64)
-                    + np.random.normal(
-                        0,
-                        np.mean(
-                            pdbs[i][alpha_indices[i]].getCoords().astype(np.float64)
-                        )
-                        * noise_level,
-                        pdbs[i][alpha_indices[i]].getCoords().astype(np.float64).shape,
-                    ),
-                    np.zeros((len(sequences[i]), 1)) + consensus_weight,
-                )
-            )
-            for i in range(len(pdbs))
-        ]
-        only_dssp = not extract_all_features
-        print("Extracting features...")
-        features = feature_extraction.get_features_multiple(
-            [cleaned_pdb_folder / f"n_{i}.pdb" for i in range(number)],
-            str(dssp_dir),
-            num_threads=num_threads,
-            only_dssp=only_dssp,
-            force_overwrite=True,
-        )
-        structures = []
-        name = "a"
-        for i in range(len(pdbs)):
-            name += name
-            pdb_name = str(name)  # helper.get_file_parts(pdb_files[i])[1]
-            # print(len(sequences[i]))
-            # print(len(features[i]["secondary"]))
-            structures.append(
-                Structure(
-                    pdb_name,
-                    None,
-                    sequences[i],
-                    helper.secondary_to_array(features[i]["secondary"]),
-                    features[i],
-                    coordinates[i],
-                )
-            )
-        msa_class = StructureMultiple.from_structures(
-            structures,
-            output_fasta_filename,
-            output_pdb_folder,
-            output_feature_filename,
-            output_class_filename,
-        )
-        return msa_class
-
-    def get_pairwise_matrix(
-        self,
-        gap_open_penalty,
-        gap_extend_penalty,
-        gamma=0.03,
-        gap_open_sec=1.0,
-        gap_extend_sec=0.1,
-    ):
-        return make_pairwise_rmsd_score_matrix(
-            self.coords_array,
-            self.secondary_array,
-            self.lengths_array,
-            gamma,
-            gap_open_sec,
-            gap_extend_sec,
-            gap_open_penalty,
-            gap_extend_penalty,
-        )
-
-    @staticmethod
-    def align_from_pdb_files(
-        input_pdb,
-        dssp_dir="caretta_tmp",
-        num_threads=20,
-        extract_all_features=True,
-        gap_open_penalty=1.0,
-        gap_extend_penalty=0.01,
-        consensus_weight=1.0,
-        write_fasta=False,
-        output_fasta_filename=Path("./result.fasta"),
-        write_pdb=False,
-        output_pdb_folder=Path("./result_pdb/"),
-        write_features=False,
-        output_feature_filename=Path("./result_features.pkl"),
-        write_class=False,
-        output_class_filename=Path("./result_class.pkl"),
-        overwrite_dssp=False,
-    ):
-        """
-        Caretta aligns protein structures and returns a sequence alignment, a set of aligned feature matrices, superposed PDB files, and
-        a class with intermediate structures made during progressive alignment.
-        Parameters
-        ----------
-        input_pdb
-            Can be \n
-            A list of PDB files
-            A list of PDB IDs
-            A folder with input protein files
-            A file which lists PDB filenames on each line
-            A file which lists PDB IDs on each line
-        dssp_dir
-            Folder to store temp DSSP files (default caretta_tmp)
-        num_threads
-            Number of threads to use for feature extraction
-        extract_all_features
-            True => obtains all features (default True) \n
-            False => only DSSP features (faster)
-        gap_open_penalty
-            default 1
-        gap_extend_penalty
-            default 0.01
-        consensus_weight
-            default 1
-        write_fasta
-            True => writes alignment as fasta file (default True)
-        output_fasta_filename
-            Fasta file of alignment (default result.fasta)
-        write_pdb
-            True => writes all protein PDB files superposed by alignment (default True)
-        output_pdb_folder
-            Folder to write superposed PDB files (default result_pdb)
-        write_features
-            True => writes aligned features as a dictionary of numpy arrays into a pickle file (default True)
-        output_feature_filename
-            Pickle file to write aligned features (default result_features.pkl)
-        write_class
-            True => writes StructureMultiple class with intermediate structures and tree to pickle file (default True)
-        output_class_filename
-            Pickle file to write StructureMultiple class (default result_class.pkl)
-        overwrite_dssp
-            Forces DSSP to rerun (default False)
-
-        Returns
-        -------
-        StructureMultiple class
-        """
-        msa_class = StructureMultiple.from_pdb_files(
-            input_pdb,
-            dssp_dir,
-            num_threads,
-            extract_all_features,
-            consensus_weight,
-            output_fasta_filename,
-            output_pdb_folder,
-            output_feature_filename,
-            output_class_filename,
-            overwrite_dssp,
-        )
-        msa_class.align(gap_open_penalty, gap_extend_penalty)
-        msa_class.write_files(write_fasta, write_pdb, write_features, write_class)
-        return msa_class
-
-    def align(
-        self,
-        gap_open_penalty,
-        gap_extend_penalty,
-        pw_matrix=None,
-        gamma=0.03,
-        gap_open_sec=1.0,
-        gap_extend_sec=0.1,
-    ) -> dict:
-        print("Aligning...")
-        if len(self.structures) == 2:
-            dtw_1, dtw_2, _ = psa.get_pairwise_alignment(
-                self.coords_array[0, : self.lengths_array[0]],
-                self.coords_array[1, : self.lengths_array[1]],
-                self.secondary_array[0, : self.lengths_array[0]],
-                self.secondary_array[1, : self.lengths_array[1]],
-                gamma,
-                gap_open_sec=gap_open_sec,
-                gap_extend_sec=gap_extend_sec,
-                gap_open_penalty=gap_open_penalty,
-                gap_extend_penalty=gap_extend_penalty,
-            )
-            self.alignment = {
-                self.structures[0].name: "".join(
-                    [self.structures[0].sequence[i] if i != -1 else "-" for i in dtw_1]
-                ),
-                self.structures[1].name: "".join(
-                    [self.structures[1].sequence[i] if i != -1 else "-" for i in dtw_2]
-                ),
-            }
-            return self.alignment
-        print("start pairwise")
-        if pw_matrix is None:
-            pw_matrix = make_pairwise_dtw_score_matrix(
-                self.coords_array,
-                self.secondary_array,
-                self.lengths_array,
-                gamma,
-                gap_open_sec,
-                gap_extend_sec,
-                gap_open_penalty,
-                gap_extend_penalty,
-            )
-
-        print("Pairwise score matrix calculation done")
-
-        tree, branch_lengths = nj.neighbor_joining(pw_matrix)
-        self.tree = tree
-        self.branch_lengths = branch_lengths
-        self.final_structures = [s for s in self.structures]
-        msa_alignments = {s.name: {s.name: s.sequence} for s in self.structures}
-
-        print("Neighbor joining tree constructed")
-
-        def make_intermediate_node(n1, n2, n_int):
-            name_1, name_2 = (
-                self.final_structures[n1].name,
-                self.final_structures[n2].name,
-            )
-            name_int = f"int-{n_int}"
-            n1_coords = self.final_structures[n1].coords
-            n1_coords[:, -1] *= len(msa_alignments[name_2])
-            n1_coords[:, -1] /= 2.0
-            n2_coords = self.final_structures[n2].coords
-            n2_coords[:, -1] *= len(msa_alignments[name_1])
-            n2_coords[:, -1] /= 2.0
-            (
-                aln_coords_1,
-                aln_coords_2,
-                aln_sec_1,
-                aln_sec_2,
-                dtw_aln_1,
-                dtw_aln_2,
-            ) = _get_alignment_data(
-                n1_coords,
-                n2_coords,
-                self.final_structures[n1].secondary,
-                self.final_structures[n2].secondary,
-                gamma,
-                gap_open_sec=gap_open_sec,
-                gap_extend_sec=gap_extend_sec,
-                gap_open_penalty=gap_open_penalty,
-                gap_extend_penalty=gap_extend_penalty,
-            )
-            aln_coords_1[:, -1] *= 2.0 / len(msa_alignments[name_2])
-            aln_coords_2[:, -1] *= 2.0 / len(msa_alignments[name_1])
-            msa_alignments[name_1] = {
-                name: "".join([sequence[i] if i != -1 else "-" for i in dtw_aln_1])
-                for name, sequence in msa_alignments[name_1].items()
-            }
-            msa_alignments[name_2] = {
-                name: "".join([sequence[i] if i != -1 else "-" for i in dtw_aln_2])
-                for name, sequence in msa_alignments[name_2].items()
-            }
-            msa_alignments[name_int] = {
-                **msa_alignments[name_1],
-                **msa_alignments[name_2],
-            }
-
-            mean_coords = get_mean_coords_extra(aln_coords_1, aln_coords_2)
-            mean_sec = get_mean_secondary(aln_sec_1, aln_sec_2, 0)
-            self.final_structures.append(
-                Structure(name_int, None, None, mean_sec, None, mean_coords)
-            )
-
-        for x in range(0, self.tree.shape[0] - 1, 2):
-            node_1, node_2, node_int = (
-                self.tree[x, 0],
-                self.tree[x + 1, 0],
-                self.tree[x, 1],
-            )
-            assert self.tree[x + 1, 1] == node_int
-            make_intermediate_node(node_1, node_2, node_int)
-
-        node_1, node_2 = self.tree[-1, 0], self.tree[-1, 1]
-        make_intermediate_node(node_1, node_2, "final")
-        alignment = {
-            **msa_alignments[self.final_structures[node_1].name],
-            **msa_alignments[self.final_structures[node_2].name],
-        }
-        self.alignment = alignment
-        return alignment
-
-    def write_files(
-        self, write_fasta=True, write_pdb=True, write_features=True, write_class=True
-    ):
-        if any((write_fasta, write_pdb, write_pdb, write_class)):
-            print("Writing files...")
-        if write_fasta:
-            self.write_alignment(self.output_files.fasta_file)
-        if write_pdb:
-            if not self.output_files.pdb_folder.exists():
-                self.output_files.pdb_folder.mkdir()
-            self.write_superposed_pdbs(self.output_files.pdb_folder)
-        if write_features:
-            with open(str(self.output_files.feature_file), "wb") as f:
-                pickle.dump(self.get_aligned_features(), f)
-        if write_class:
-            with open(str(self.output_files.class_file), "wb") as f:
-                pickle.dump(self, f)
-
-    def superpose(self, alignments: dict = None):
-        """
-        Superpose structures according to alignment
-        """
-        if alignments is None:
-            alignments = self.alignment
-        reference_index = 0
-        reference_key = self.structures[reference_index].name
-        core_indices = np.array(
-            [
-                i
-                for i in range(len(alignments[reference_key]))
-                if "-" not in [alignments[n][i] for n in alignments]
-            ]
-        )
-        aln_ref = helper.aligned_string_to_array(alignments[reference_key])
-        ref_coords = self.structures[reference_index].coords[
-            np.array([aln_ref[c] for c in core_indices])
-        ][:, :3]
-        ref_centroid = rmsd_calculations.nb_mean_axis_0(ref_coords)
-        ref_coords -= ref_centroid
-        for i in range(len(self.structures)):
-            if i == reference_index:
-                self.structures[i].coords[:, :3] -= ref_centroid
-            else:
-                aln_c = helper.aligned_string_to_array(
-                    alignments[self.structures[i].name]
-                )
-                common_coords_2 = self.structures[i].coords[
-                    np.array([aln_c[c] for c in core_indices])
-                ][:, :3]
-                rotation_matrix, translation_matrix = rmsd_calculations.svd_superimpose(
-                    ref_coords, common_coords_2
-                )
-                self.structures[i].coords[:, :3] = rmsd_calculations.apply_rotran(
-                    self.structures[i].coords[:, :3],
-                    rotation_matrix,
-                    translation_matrix,
-                )
-
-    def make_pairwise_rmsd_coverage_matrix(
-        self, alignments: dict = None, superpose_first: bool = True
-    ):
-        """
-        Find RMSDs and coverages of the alignment of each pair of sequences
-
-        Parameters
-        ----------
-        alignments
-        superpose_first
-            if True then superposes all structures to first structure first
-
-        Returns
-        -------
-        RMSD matrix, coverage matrix
-        """
-        if alignments is None:
-            alignments = self.alignment
-        num = len(self.structures)
-        pairwise_rmsd_matrix = np.zeros((num, num))
-        pairwise_rmsd_matrix[:] = np.nan
-        pairwise_coverage = np.zeros((num, num))
-        pairwise_coverage[:] = np.nan
-        if superpose_first:
-            self.superpose(alignments)
-        for i in range(num):
-            for j in range(i + 1, num):
-                name_1, name_2 = self.structures[i].name, self.structures[j].name
-                if isinstance(alignments[name_1], str):
-                    aln_1 = helper.aligned_string_to_array(alignments[name_1])
-                    aln_2 = helper.aligned_string_to_array(alignments[name_2])
-                else:
-                    aln_1 = alignments[name_1]
-                    aln_2 = alignments[name_2]
-                common_coords_1, common_coords_2 = get_common_coordinates(
-                    self.structures[i].coords[:, :3],
-                    self.structures[j].coords[:, :3],
-                    aln_1,
-                    aln_2,
-                )
-                assert common_coords_1.shape[0] > 0
-                if not superpose_first:
-                    rot, tran = rmsd_calculations.svd_superimpose(
-                        common_coords_1, common_coords_2
-                    )
-                    common_coords_2 = rmsd_calculations.apply_rotran(
-                        common_coords_2, rot, tran
-                    )
-                pairwise_rmsd_matrix[i, j] = pairwise_rmsd_matrix[
-                    j, i
-                ] = rmsd_calculations.get_rmsd(common_coords_1, common_coords_2)
-                pairwise_coverage[i, j] = pairwise_coverage[
-                    j, i
-                ] = common_coords_1.shape[0] / len(aln_1)
-        return pairwise_rmsd_matrix, pairwise_coverage
-
-    def get_aligned_features(self, alignments: dict = None):
-        """
-        Get dict of aligned features
-        """
-        if alignments is None:
-            alignments = self.alignment
-        feature_names = list(self.structures[0].features.keys())
-        aligned_features = {}
-        alignment_length = len(alignments[self.structures[0].name])
-        for feature_name in feature_names:
-            if feature_name == "secondary":
-                continue
-            aligned_features[feature_name] = np.zeros(
-                (len(self.structures), alignment_length)
-            )
-            aligned_features[feature_name][:] = np.nan
-            for p in range(len(self.structures)):
-                farray = self.structures[p].features[feature_name]
-                if "gnm" in feature_name or "anm" in feature_name:
-                    farray = farray / np.nansum(farray ** 2) ** 0.5
-                indices = [
-                    i
-                    for i in range(alignment_length)
-                    if alignments[self.structures[p].name][i] != "-"
-                ]
-                aligned_features[feature_name][p, indices] = farray
-        return aligned_features
-
-    def write_alignment(self, filename, alignments: dict = None):
-        """
-        Writes alignment to a fasta file
-        """
-        if alignments is None:
-            alignments = self.alignment
-        with open(filename, "w") as f:
-            for key in alignments:
-                f.write(f">{key}\n{alignments[key]}\n")
-
-    def write_superposed_pdbs(self, output_pdb_folder, alignments: dict = None):
-        """
-        Superposes PDBs according to alignment and writes transformed PDBs to files
-        (View with Pymol)
-
-        Parameters
-        ----------
-        alignments
-        output_pdb_folder
-        """
-        if alignments is None:
-            alignments = self.alignment
-        output_pdb_folder = Path(output_pdb_folder)
-        if not output_pdb_folder.exists():
-            output_pdb_folder.mkdir()
-        reference_name = self.structures[0].name
-        reference_pdb = pd.parsePDB(str(self.structures[0].pdb_file))
-        core_indices = np.array(
-            [
-                i
-                for i in range(len(alignments[reference_name]))
-                if "-" not in [alignments[n][i] for n in alignments]
-            ]
-        )
-        aln_ref = helper.aligned_string_to_array(alignments[reference_name])
-        ref_coords_core = (
-            reference_pdb[helper.get_alpha_indices(reference_pdb)]
-            .getCoords()
-            .astype(np.float64)[np.array([aln_ref[c] for c in core_indices])]
-        )
-        ref_centroid = rmsd_calculations.nb_mean_axis_0(ref_coords_core)
-        ref_coords_core -= ref_centroid
-        transformation = pd.Transformation(np.eye(3), -ref_centroid)
-        reference_pdb = pd.applyTransformation(transformation, reference_pdb)
-        pd.writePDB(str(output_pdb_folder / f"{reference_name}.pdb"), reference_pdb)
-        for i in range(1, len(self.structures)):
-            name = self.structures[i].name
-            pdb = pd.parsePDB(str(self.structures[i].pdb_file))
-            aln_name = helper.aligned_string_to_array(alignments[name])
-            common_coords_2 = (
-                pdb[helper.get_alpha_indices(pdb)]
-                .getCoords()
-                .astype(np.float64)[np.array([aln_name[c] for c in core_indices])]
-            )
-            rotation_matrix, translation_matrix = rmsd_calculations.svd_superimpose(
-                ref_coords_core, common_coords_2
-            )
-            transformation = pd.Transformation(rotation_matrix.T, translation_matrix)
-            pdb = pd.applyTransformation(transformation, pdb)
-            pd.writePDB(str(output_pdb_folder / f"{name}.pdb"), pdb)
-
-
-if __name__ == "__main__":
-    from os import listdir
-    from sys import argv
-
-    print(listdir("./"))
-    pdb_folder = "/mnt/local_scratch/akdel001/caretta/test_data/"
-    pdb_file = "/mnt/local_scratch/akdel001/caretta/test_data/1kdu.pdb"
-    al = StructureMultiple.from_pdb_files(pdb_folder)
-    alignment = al.align(1, 0.01)
diff --git a/caretta/multiple_alignment.py b/caretta/multiple_alignment.py
index 9d3acd75bd01b0e216bff9ea1db73689ce242e07..b9ef655bdb5d51e6c36c095766a8f5389c8a5396 100644
--- a/caretta/multiple_alignment.py
+++ b/caretta/multiple_alignment.py
@@ -6,9 +6,9 @@ from pathlib import Path
 import numba as nb
 import numpy as np
 import prody as pd
+import typer
+from geometricus import Structure, MomentInvariants, SplitType, GeometricusEmbedding
 from scipy.spatial.distance import pdist, squareform
-from geometricus import geometricus, protein_utility
-from geometricus.moment_utility import nb_mean_axis_0
 
 from caretta import (
     dynamic_time_warping as dtw,
@@ -20,13 +20,15 @@ from caretta import (
 )
 
 
-@nb.njit(cache=False)
-def get_common_coordinates(coords_1, coords_2, aln_1, aln_2, gap=-1):
+@nb.njit
+def get_common_coordinates(
+    coords_1: np.ndarray, coords_2: np.ndarray, aln_1: np.ndarray, aln_2: np.ndarray
+) -> typing.Tuple[np.ndarray, np.ndarray]:
     """
     Return coordinate positions aligned in both coords_1 and coords_2
     """
     assert aln_1.shape == aln_2.shape
-    pos_1, pos_2 = score_functions.get_common_positions(aln_1, aln_2, gap)
+    pos_1, pos_2 = helper.get_common_positions(aln_1, aln_2)
     return coords_1[pos_1], coords_2[pos_2]
 
 
@@ -53,23 +55,14 @@ def get_mean_coords(
     return mean_coords
 
 
-@nb.njit
-def get_pairwise_braycurtis(fingerprints):
-    res = np.zeros((fingerprints.shape[0], fingerprints.shape[0]), dtype=np.float64)
-    for i in range(fingerprints.shape[0]):
-        for j in range(fingerprints.shape[0]):
-            res[i, j] = np.sum(np.abs(fingerprints[i] - fingerprints[j])) / np.sum(
-                np.abs(fingerprints[i] + fingerprints[j])
-            )
-    return res
-
-
-def get_mean_weights(weights_1, weights_2, aln_1, aln_2, gap=-1):
+def get_mean_weights(
+    weights_1: np.ndarray, weights_2: np.ndarray, aln_1: np.ndarray, aln_2: np.ndarray
+) -> np.ndarray:
     mean_weights = np.zeros(aln_1.shape[0])
     for i, (x, y) in enumerate(zip(aln_1, aln_2)):
-        if not x == gap:
+        if not x == -1:
             mean_weights[i] += weights_1[x]
-        if not y == gap:
+        if not y == -1:
             mean_weights[i] += weights_2[y]
     return mean_weights
 
@@ -83,6 +76,18 @@ class OutputFiles:
     class_file: Path = Path("./result_class.pkl")
 
 
+DEFAULT_SUPERPOSITION_PARAMETERS = {
+    "num_split_types": 2,
+    "split_type_0": "KMER",
+    "split_size_0": 30,
+    "split_type_1": "RADIUS",
+    "split_size_1": 16,
+    "gap_open_penalty": 0.01,
+    "gap_extend_penalty": 0.001,
+    "scale": True,
+}
+
+
 @dataclass
 class StructureMultiple:
     """
@@ -92,6 +97,8 @@ class StructureMultiple:
     ---------------------
     structures
         list of protein_utility.Structure objects
+    superposition_parameters
+        dictionary of parameters to pass to the superposition function
     superposition_function
         a function that takes two coordinate sets as input and superposes them
         returns a score, superposed_coords_1, superposed_coords_2
@@ -108,8 +115,9 @@ class StructureMultiple:
         indices of aligning residues from each structure, gaps are -1s
     """
 
-    structures: typing.List[protein_utility.Structure]
+    structures: typing.List[Structure]
     sequences: typing.Dict[str, str]
+    superposition_parameters: typing.Dict[str, typing.Any]
     superposition_function: typing.Callable[
         [
             np.ndarray,
@@ -122,50 +130,53 @@ class StructureMultiple:
     score_function: typing.Callable[
         [np.ndarray, np.ndarray], float
     ] = score_functions.get_caretta_score
-    gamma: float = 0.3
     mean_function: typing.Callable[
         [np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray
     ] = get_mean_coords
     consensus_weight: float = 1.0
-    final_structures: typing.Union[None, typing.List[protein_utility.Structure]] = None
+    final_structures: typing.Union[None, typing.List[Structure]] = None
     final_consensus_weights: typing.Union[None, typing.List[np.ndarray]] = None
     tree: typing.Union[None, np.ndarray] = None
     branch_lengths: typing.Union[None, np.ndarray] = None
-    alignment: typing.Union[None, dict] = None
+    alignment: typing.Union[dict, None] = None
     output_folder: Path = Path("./caretta_results")
+    features: typing.Union[dict, None] = None
 
     @staticmethod
     def align_from_pdb_files(
-        input_pdb,
-        gap_open_penalty=1.0,
-        gap_extend_penalty=0.01,
-        consensus_weight=1.0,
-        output_folder=Path("../caretta_results"),
-        num_threads=20,
-        write_fasta=False,
-        write_pdb=False,
-        write_features=False,
-        write_class=False,
-        overwrite_dssp=False,
+        input_pdb: typing.Union[typing.List[str], Path, str],
+        gap_open_penalty: float = 1.0,
+        gap_extend_penalty: float = 0.01,
+        consensus_weight: float = 1.0,
+        full: bool = False,
+        output_folder: typing.Union[str, Path] = Path("../caretta_results"),
+        num_threads: int = 20,
+        write_fasta: bool = False,
+        write_pdb: bool = False,
+        write_features: bool = False,
+        write_class: bool = False,
     ):
         """
-        Caretta aligns protein structures and returns a sequence alignment, a set of aligned feature matrices, superposed PDB files, and
-        a class with intermediate structures made during progressive alignment.
+        Caretta aligns protein structures and can output a sequence alignment, superposed PDB files,
+        a set of aligned feature matrices and a class with intermediate structures made during progressive alignment.
+
         Parameters
         ----------
         input_pdb
             Can be \n
-            A list of PDB files
-            A list of PDB IDs
-            A folder with input protein files
-            A file which lists PDB filenames on each line
-            A file which lists PDB IDs on each line
+            A list of PDB files,
+            A list of PDB IDs,
+            A folder with input protein files,
+            A file which lists PDB filenames on each line,
+            A file which lists PDB IDs on each line,
         gap_open_penalty
             default 1
         gap_extend_penalty
             default 0.01
         consensus_weight
             default 1
+        full
+            True =>  Uses all-vs-all pairwise Caretta alignment to make the distance matrix (much slower)
         output_folder
             default "caretta_results"
         num_threads
@@ -182,8 +193,6 @@ class StructureMultiple:
         write_class
             True => writes StructureMultiple class with intermediate structures and tree to pickle file (default True)
             writes to output_folder / result_class.pkl
-        overwrite_dssp
-            Forces DSSP to rerun even if files are already present (default False)
 
         Returns
         -------
@@ -191,19 +200,26 @@ class StructureMultiple:
         """
         msa_class = StructureMultiple.from_pdb_files(
             input_pdb,
-            superposition_function=superposition_functions.moment_svd_superpose_function,
+            superposition_parameters=DEFAULT_SUPERPOSITION_PARAMETERS,
+            superposition_function=superposition_functions.moment_multiple_svd_superpose_function,
             consensus_weight=consensus_weight,
             output_folder=output_folder,
         )
-        pw_matrix = msa_class.make_pairwise_dtw_matrix()
-        msa_class.align(pw_matrix, gap_open_penalty, gap_extend_penalty)
+        if len(msa_class.structures) > 2:
+            if full:
+                pw_matrix = msa_class.make_pairwise_dtw_matrix(
+                    gap_open_penalty, gap_extend_penalty
+                )
+            else:
+                pw_matrix = msa_class.make_pairwise_shape_matrix()
+            msa_class.align(pw_matrix, gap_open_penalty, gap_extend_penalty)
+        else:
+            msa_class.align(
+                gap_open_penalty=gap_open_penalty, gap_extend_penalty=gap_extend_penalty
+            )
+
         msa_class.write_files(
-            write_fasta,
-            write_pdb,
-            write_features,
-            write_class,
-            num_threads,
-            overwrite_dssp,
+            write_fasta, write_pdb, write_features, write_class, num_threads,
         )
         return msa_class
 
@@ -211,7 +227,8 @@ class StructureMultiple:
     def from_pdb_files(
         cls,
         input_pdb,
-        superposition_function,
+        superposition_parameters,
+        superposition_function=superposition_functions.moment_multiple_svd_superpose_function,
         score_function=score_functions.get_caretta_score,
         consensus_weight=1.0,
         output_folder=Path("./caretta_results"),
@@ -223,6 +240,8 @@ class StructureMultiple:
         ----------
         input_pdb
             list of pdb files/names or a folder containing pdb files
+        superposition_parameters
+            parameters to give to the superposition function
         superposition_function
             a function that takes two coordinate sets as input and superposes them
             returns a score, superposed_coords_1, superposed_coords_2
@@ -244,24 +263,20 @@ class StructureMultiple:
         if not cleaned_pdb_folder.exists():
             cleaned_pdb_folder.mkdir()
         pdb_files = helper.parse_pdb_files_and_clean(input_pdb, cleaned_pdb_folder)
+        typer.echo(f"Found {len(pdb_files)} PDB files")
 
         structures = []
         sequences = {}
         for pdb_file in pdb_files:
             pdb_name = Path(pdb_file).stem
-            protein = pd.parsePDB(str(pdb_file)).select("protein")
-            indices = [
-                i for i, a in enumerate(protein.iterAtoms()) if a.getName() == "CA"
-            ]
-            protein = protein[indices]
+            protein = pd.parsePDB(str(pdb_file)).select("protein and calpha")
             coordinates = protein.getCoords()
-            structures.append(
-                protein_utility.Structure(pdb_name, coordinates.shape[0], coordinates)
-            )
+            structures.append(Structure(pdb_name, coordinates.shape[0], coordinates))
             sequences[pdb_name] = protein.getSequence()
         msa_class = StructureMultiple(
             structures,
             sequences,
+            superposition_parameters,
             superposition_function,
             score_function=score_function,
             consensus_weight=consensus_weight,
@@ -273,7 +288,6 @@ class StructureMultiple:
         self,
         coords_1,
         coords_2,
-        parameters,
         gap_open_penalty: float,
         gap_extend_penalty: float,
         weight=False,
@@ -287,7 +301,6 @@ class StructureMultiple:
         ----------
         coords_1
         coords_2
-        parameters
         gap_open_penalty
         gap_extend_penalty
         weight
@@ -299,11 +312,8 @@ class StructureMultiple:
         -------
         alignment_1, alignment_2, score, superposed_coords_1, superposed_coords_2
         """
-        #  if exclude_last:
-        #      _, coords_1[:, :-1], coords_2[:, :-1] = self.superposition_function(coords_1[:, :-1], coords_2[:, :-1])
-        #  else:
         _, coords_1, coords_2 = self.superposition_function(
-            coords_1, coords_2, parameters
+            coords_1, coords_2, self.superposition_parameters
         )
         if weight:
             assert weights_1 is not None
@@ -323,9 +333,7 @@ class StructureMultiple:
             score_matrix, gap_open_penalty, gap_extend_penalty
         )
         for i in range(n_iter):
-            pos_1, pos_2 = score_functions.get_common_positions(
-                dtw_aln_array_1, dtw_aln_array_2
-            )
+            pos_1, pos_2 = helper.get_common_positions(dtw_aln_array_1, dtw_aln_array_2)
             common_coords_1, common_coords_2 = coords_1[pos_1], coords_2[pos_2]
             (
                 c1,
@@ -353,7 +361,11 @@ class StructureMultiple:
         return dtw_aln_array_1, dtw_aln_array_2, dtw_score, coords_1, coords_2
 
     def make_pairwise_shape_matrix(
-        self, resolution: np.ndarray, kmer_size=30, radius=10, metric="braycurtis"
+        self,
+        resolution: typing.Union[float, np.ndarray] = 2.0,
+        kmer_size: int = 30,
+        radius: int = 16,
+        metric="braycurtis",
     ):
         """
         Makes an all vs. all matrix of distance scores between all the structures.
@@ -370,49 +382,47 @@ class StructureMultiple:
         -------
         [n x n] distance matrix
         """
-        kmer_invariants = [
-            geometricus.MomentInvariants.from_coordinates(
+        typer.echo("Calculating pairwise distances...")
+        kmer_invariants = (
+            MomentInvariants.from_coordinates(
                 s.name,
                 s.coordinates,
                 None,
                 split_size=kmer_size,
-                split_type=geometricus.SplitType.KMER,
+                split_type=SplitType.KMER,
             )
             for s in self.structures
-        ]
-        radius_invariants = [
-            geometricus.MomentInvariants.from_coordinates(
+        )
+        radius_invariants = (
+            MomentInvariants.from_coordinates(
                 s.name,
                 s.coordinates,
                 None,
                 split_size=radius,
-                split_type=geometricus.SplitType.RADIUS,
+                split_type=SplitType.RADIUS,
             )
             for s in self.structures
-        ]
-        kmer_embedder = geometricus.GeometricusEmbedding.from_invariants(
+        )
+        kmer_embedder = GeometricusEmbedding.from_invariants(
             kmer_invariants,
             resolution=resolution,
             protein_keys=[s.name for s in self.structures],
         )
-        radius_embedder = geometricus.GeometricusEmbedding.from_invariants(
+        radius_embedder = GeometricusEmbedding.from_invariants(
             radius_invariants,
             resolution=resolution,
             protein_keys=[s.name for s in self.structures],
         )
-        return squareform(
+        distance_matrix = squareform(
             pdist(
                 np.hstack((kmer_embedder.embedding, radius_embedder.embedding)),
                 metric=metric,
             )
         )
+        return distance_matrix
 
     def make_pairwise_dtw_matrix(
-        self,
-        parameters: dict,
-        gap_open_penalty: float,
-        gap_extend_penalty: float,
-        invert=True,
+        self, gap_open_penalty: float, gap_extend_penalty: float, invert=True,
     ):
         """
         Makes an all vs. all matrix of distance (or similarity) scores between all the structures using pairwise alignment.
@@ -429,6 +439,7 @@ class StructureMultiple:
         -------
         [n x n] matrix
         """
+        typer.echo("Calculating pairwise distances...")
         pairwise_matrix = np.zeros((len(self.structures), len(self.structures)))
         for i in range(pairwise_matrix.shape[0] - 1):
             for j in range(i + 1, pairwise_matrix.shape[1]):
@@ -445,7 +456,6 @@ class StructureMultiple:
                 ) = self.get_pairwise_alignment(
                     coords_1,
                     coords_2,
-                    parameters,
                     gap_open_penalty=gap_open_penalty,
                     gap_extend_penalty=gap_extend_penalty,
                     weight=False,
@@ -461,9 +471,7 @@ class StructureMultiple:
         pairwise_matrix += pairwise_matrix.T
         return pairwise_matrix
 
-    def align(
-        self, pw_matrix, parameters, gap_open_penalty, gap_extend_penalty
-    ) -> dict:
+    def align(self, pw_matrix, gap_open_penalty, gap_extend_penalty) -> dict:
         """
         Makes a multiple structure alignment
 
@@ -478,7 +486,6 @@ class StructureMultiple:
         -------
         alignment = {name: indices of aligning residues with gaps as -1s}
         """
-        print("Aligning...")
         if len(self.structures) == 2:
             coords_1, coords_2 = (
                 self.structures[0].coordinates,
@@ -487,7 +494,6 @@ class StructureMultiple:
             dtw_1, dtw_2, _, _, _ = self.get_pairwise_alignment(
                 coords_1,
                 coords_2,
-                parameters,
                 gap_open_penalty=gap_open_penalty,
                 gap_extend_penalty=gap_extend_penalty,
                 weight=False,
@@ -496,13 +502,13 @@ class StructureMultiple:
                 self.structures[0].name: dtw_1,
                 self.structures[1].name: dtw_2,
             }
-            return self.alignment
+            return self.make_sequence_alignment()
         assert pw_matrix is not None
         assert pw_matrix.shape[0] == len(self.structures)
+        typer.echo("Constructing neighbor joining tree...")
         tree, branch_lengths = nj.neighbor_joining(pw_matrix)
         self.tree = tree
         self.branch_lengths = branch_lengths
-        print("Neighbor joining tree constructed")
         self.final_structures = [s for s in self.structures]
         self.final_consensus_weights = [
             np.full(
@@ -539,7 +545,6 @@ class StructureMultiple:
             ) = self.get_pairwise_alignment(
                 n1_coords,
                 n2_coords,
-                parameters,
                 gap_open_penalty=gap_open_penalty,
                 gap_extend_penalty=gap_extend_penalty,
                 weight=True,
@@ -566,18 +571,21 @@ class StructureMultiple:
                 n1_weights, n2_weights, dtw_aln_1, dtw_aln_2
             )
             self.final_structures.append(
-                protein_utility.Structure(name_int, mean_coords.shape[0], mean_coords)
+                Structure(name_int, mean_coords.shape[0], mean_coords)
             )
             self.final_consensus_weights.append(mean_weights)
 
-        for x in range(0, self.tree.shape[0] - 1, 2):
-            node_1, node_2, node_int = (
-                self.tree[x, 0],
-                self.tree[x + 1, 0],
-                self.tree[x, 1],
-            )
-            assert self.tree[x + 1, 1] == node_int
-            make_intermediate_node(node_1, node_2, node_int)
+        with typer.progressbar(
+            range(0, self.tree.shape[0] - 1, 2), label="Aligning"
+        ) as progress:
+            for x in progress:
+                node_1, node_2, node_int = (
+                    self.tree[x, 0],
+                    self.tree[x + 1, 0],
+                    self.tree[x, 1],
+                )
+                assert self.tree[x + 1, 1] == node_int
+                make_intermediate_node(node_1, node_2, node_int)
 
         node_1, node_2 = self.tree[-1, 0], self.tree[-1, 1]
         make_intermediate_node(node_1, node_2, "final")
@@ -586,37 +594,55 @@ class StructureMultiple:
             **msa_alignments[self.final_structures[node_2].name],
         }
         self.alignment = alignment
-        return alignment
+        return self.make_sequence_alignment(alignment)
+
+    def make_sequence_alignment(self, alignment=None):
+        sequence_alignment = {}
+        if alignment is None:
+            alignment = self.alignment
+        for s in self.structures:
+            sequence_alignment[s.name] = "".join(
+                self.sequences[s.name][i] if i != -1 else "-" for i in alignment[s.name]
+            )
+        return sequence_alignment
 
     def write_files(
-        self,
-        write_fasta,
-        write_pdb,
-        write_features,
-        write_class,
-        num_threads,
-        overwrite_dssp,
+        self, write_fasta, write_pdb, write_features, write_class, num_threads=4,
     ):
         if any((write_fasta, write_pdb, write_pdb, write_class)):
-            print("Writing files...")
+            typer.echo("Writing files...")
         if write_fasta:
-            self.write_alignment(self.output_folder / "result.fasta")
+            fasta_file = self.output_folder / "result.fasta"
+            self.write_alignment(fasta_file)
+            typer.echo(
+                f"FASTA file: {typer.style(str(fasta_file), fg=typer.colors.GREEN)}",
+            )
         if write_pdb:
-            pdb_folder = self.output_folder / "superposed_pdb"
+            pdb_folder = self.output_folder / "superposed_pdbs"
             if not pdb_folder.exists():
                 pdb_folder.mkdir()
             self.write_superposed_pdbs(pdb_folder)
+            typer.echo(
+                f"Superposed PDB files: {typer.style(str(pdb_folder), fg=typer.colors.GREEN)}"
+            )
         if write_features:
             dssp_dir = self.output_folder / ".caretta_tmp"
             if not dssp_dir.exists():
                 dssp_dir.mkdir()
-            with open(str(self.output_folder / "result_features.pkl"), "wb") as f:
-                pickle.dump(
-                    self.get_aligned_features(dssp_dir, num_threads, overwrite_dssp), f
-                )
+            feature_file = self.output_folder / "result_features.pkl"
+            self.features = self.get_aligned_features(str(dssp_dir), num_threads)
+            with open(feature_file, "wb") as f:
+                pickle.dump(self.features, f)
+            typer.echo(
+                f"Aligned features: {typer.style(str(feature_file), fg=typer.colors.GREEN)}"
+            )
         if write_class:
-            with open(str(self.output_folder / "result_class.pkl"), "wb") as f:
+            class_file = self.output_folder / "result_class.pkl"
+            with open(class_file, "wb") as f:
                 pickle.dump(self, f)
+            typer.echo(
+                f"Class file: {typer.style(str(class_file), fg=typer.colors.GREEN)}"
+            )
 
     def write_alignment(self, filename, alignments: dict = None):
         """
@@ -659,11 +685,11 @@ class StructureMultiple:
         )
         aln_ref = alignments[reference_name]
         ref_coords_core = (
-            reference_pdb[protein_utility.get_alpha_indices(reference_pdb)]
+            reference_pdb[helper.get_alpha_indices(reference_pdb)]
             .getCoords()
             .astype(np.float64)[np.array([aln_ref[c] for c in core_indices])]
         )
-        ref_centroid = nb_mean_axis_0(ref_coords_core)
+        ref_centroid = helper.nb_mean_axis_0(ref_coords_core)
         ref_coords_core -= ref_centroid
         transformation = pd.Transformation(np.eye(3), -ref_centroid)
         reference_pdb = pd.applyTransformation(transformation, reference_pdb)
@@ -675,7 +701,7 @@ class StructureMultiple:
             )
             aln_name = alignments[name]
             common_coords_2 = (
-                pdb[protein_utility.get_alpha_indices(pdb)]
+                pdb[helper.get_alpha_indices(pdb)]
                 .getCoords()
                 .astype(np.float64)[np.array([aln_name[c] for c in core_indices])]
             )
@@ -690,22 +716,24 @@ class StructureMultiple:
             pd.writePDB(str(output_pdb_folder / f"{name}.pdb"), pdb)
 
     def get_aligned_features(
-        self, dssp_dir, num_threads, force_overwrite, alignments: dict = None
-    ):
+        self, dssp_dir, num_threads, alignment: dict = None
+    ) -> typing.Dict[str, np.ndarray]:
         """
         Get dict of aligned features
         """
-        if alignments is None:
-            alignments = self.alignment
+        if alignment is None:
+            alignment = self.make_sequence_alignment()
+
+        pdb_files = [
+            self.output_folder / "cleaned_pdb" / f"{s.name}.pdb"
+            for s in self.structures
+        ]
         features = feature_extraction.get_features_multiple(
-            helper.parse_pdb_files(self.output_folder / "cleaned_pdb"),
-            str(dssp_dir),
-            num_threads=num_threads,
-            force_overwrite=force_overwrite,
+            pdb_files, str(dssp_dir), num_threads=num_threads, force_overwrite=True
         )
         feature_names = list(features[0].keys())
         aligned_features = {}
-        alignment_length = len(alignments[self.structures[0].name])
+        alignment_length = len(alignment[self.structures[0].name])
         for feature_name in feature_names:
             if feature_name == "secondary":
                 continue
@@ -720,7 +748,7 @@ class StructureMultiple:
                 indices = [
                     i
                     for i in range(alignment_length)
-                    if alignments[self.structures[p].name][i] != "-"
+                    if alignment[self.structures[p].name][i] != "-"
                 ]
                 aligned_features[feature_name][p, indices] = farray
         return aligned_features
@@ -736,7 +764,7 @@ class StructureMultiple:
         # core_indices = np.array([i for i in range(len(alignments[reference_key])) if '-' not in [alignments[n][i] for n in alignments]])
         aln_ref = alignments[reference_key]
         # ref_coords = self.structures[reference_index].coordinates[np.array([aln_ref[c] for c in core_indices])]
-        # ref_centroid = nb_mean_axis_0(ref_coords)
+        # ref_centroid = helper.nb_mean_axis_0(ref_coords)
         # ref_coords -= ref_centroid
         for i in range(len(self.structures)):
             # if i == reference_index:
@@ -813,3 +841,21 @@ class StructureMultiple:
                     j, i
                 ] = common_coords_1.shape[0] / len(aln_1)
         return pairwise_rmsd_matrix, pairwise_coverage
+
+
+def trigger_numba_compilation():
+    """
+    Run this at the beginning of a Caretta run to compile Numba functions
+    """
+    parameters = {"size": 1, "gap_open_penalty": 0.0, "gap_extend_penalty": 0.0}
+    coords_1 = np.zeros((2, 3))
+    coords_2 = np.zeros((2, 3))
+    superposition_functions.signal_svd_superpose_function(
+        coords_1, coords_2, parameters
+    )
+    distance_matrix = np.random.random((5, 5))
+    nj.neighbor_joining(distance_matrix)
+    aln_1 = np.array([0, -1, 1])
+    aln_2 = np.array([0, 1, -1])
+    get_common_coordinates(coords_1, coords_2, aln_1, aln_2)
+    get_mean_coords(aln_1, coords_1, aln_2, coords_2)
diff --git a/caretta/psa_numba.py b/caretta/psa_numba.py
deleted file mode 100644
index ebf7ca78cc4eed47c59126f8b1dc1c527260c922..0000000000000000000000000000000000000000
--- a/caretta/psa_numba.py
+++ /dev/null
@@ -1,188 +0,0 @@
-import numba as nb
-import numpy as np
-
-from caretta import dynamic_time_warping as dtw
-from caretta import rmsd_calculations, helper
-
-
-@nb.njit
-# @numba_cc.export('make_signal_index', 'f64[:](f64[:], i64)')
-def make_signal_index(coords, index):
-    centroid = coords[index]
-    distances = np.zeros(coords.shape[0])
-    for i in range(coords.shape[0]):
-        distances[i] = np.sqrt(np.sum((coords[i] - centroid) ** 2, axis=-1))
-    return distances
-
-
-@nb.njit
-# @numba_cc.export('dtw_signals_index', '(f64[:], f64[:], i64, i64, i64)')
-def dtw_signals_index(coords_1, coords_2, index, size=30, overlap=1):
-    signals_1 = np.zeros(((coords_1.shape[0] - size) // overlap, size))
-    signals_2 = np.zeros(((coords_2.shape[0] - size) // overlap, size))
-    middles_1 = np.zeros((signals_1.shape[0], coords_1.shape[1]))
-    middles_2 = np.zeros((signals_2.shape[0], coords_2.shape[1]))
-    if index == -1:
-        index = size - 1
-    for x, i in enumerate(range(0, signals_1.shape[0] * overlap, overlap)):
-        signals_1[x] = make_signal_index(coords_1[i : i + size], index)
-        middles_1[x] = coords_1[i + index]
-    for x, i in enumerate(range(0, signals_2.shape[0] * overlap, overlap)):
-        signals_2[x] = make_signal_index(coords_2[i : i + size], index)
-        middles_2[x] = coords_2[i + index]
-    distance_matrix = np.zeros((signals_1.shape[0], signals_2.shape[0]))
-    for i in range(signals_1.shape[0]):
-        for j in range(signals_2.shape[0]):
-            distance_matrix[i, j] = np.median(
-                np.exp(-0.1 * (signals_1[i] - signals_2[j]) ** 2)
-            )
-    dtw_1, dtw_2, _ = dtw.dtw_align(distance_matrix, 0.0, 0.0)
-    pos_1, pos_2 = helper.get_common_positions(dtw_1, dtw_2)
-    aln_coords_1 = np.zeros((len(pos_1), coords_1.shape[1]))
-    aln_coords_2 = np.zeros((len(pos_2), coords_2.shape[1]))
-    for i, (p1, p2) in enumerate(zip(pos_1, pos_2)):
-        aln_coords_1[i] = middles_1[p1]
-        aln_coords_2[i] = middles_2[p2]
-    coords_1, coords_2, _ = rmsd_calculations.superpose_with_pos(
-        coords_1, coords_2, aln_coords_1, aln_coords_2
-    )
-    return coords_1, coords_2
-
-
-@nb.njit
-# @numba_cc.export('get_dtw_signal_score_pos', '(f64[:], f64[:], f64, i64, f64, f64)')
-def get_dtw_signal_score_pos(
-    coords_1, coords_2, gamma, index, gap_open_penalty, gap_extend_penalty
-):
-    coords_1[:, :3], coords_2[:, :3] = dtw_signals_index(
-        coords_1[:, :3], coords_2[:, :3], index
-    )
-    distance_matrix = rmsd_calculations.make_distance_matrix(
-        coords_1[:, :3], coords_2[:, :3], gamma, normalized=False
-    )
-    dtw_aln_array_1, dtw_aln_array_2, _ = dtw.dtw_align(
-        distance_matrix, gap_open_penalty, gap_extend_penalty
-    )
-    pos_1, pos_2 = helper.get_common_positions(dtw_aln_array_1, dtw_aln_array_2)
-    common_coords_1, common_coords_2 = coords_1[pos_1], coords_2[pos_2]
-    rot, tran = rmsd_calculations.svd_superimpose(
-        common_coords_1[:, :3], common_coords_2[:, :3]
-    )
-    common_coords_2[:, :3] = rmsd_calculations.apply_rotran(
-        common_coords_2[:, :3], rot, tran
-    )
-    return (
-        rmsd_calculations.get_caretta_score(
-            common_coords_1, common_coords_2, gamma, False
-        ),
-        pos_1,
-        pos_2,
-    )
-
-
-@nb.njit
-# @numba_cc.export('get_secondary_distance_matrix', '(i8[:], i8[:], i8)')
-def get_secondary_distance_matrix(secondary_1, secondary_2, gap=0):
-    score_matrix = np.zeros((secondary_1.shape[0], secondary_2.shape[0]))
-    for i in range(secondary_1.shape[0]):
-        for j in range(secondary_2.shape[0]):
-            if secondary_1[i] == secondary_2[j]:
-                if secondary_1[i] != gap:
-                    score_matrix[i, j] = 1
-            else:
-                score_matrix[i, j] = -1
-    return score_matrix
-
-
-@nb.njit
-# @numba_cc.export('get_secondary_rmsd_pos', '(i8[:], i8[:], f64[:], f64[:], f64, f64, f64)')
-def get_secondary_rmsd_pos(
-    secondary_1, secondary_2, coords_1, coords_2, gamma, gap_open_sec, gap_extend_sec
-):
-    distance_matrix = get_secondary_distance_matrix(secondary_1, secondary_2)
-    dtw_aln_array_1, dtw_aln_array_2, _ = dtw.dtw_align(
-        distance_matrix, gap_open_sec, gap_extend_sec
-    )
-    pos_1, pos_2 = helper.get_common_positions(dtw_aln_array_1, dtw_aln_array_2)
-    common_coords_1, common_coords_2 = coords_1[pos_1][:, :3], coords_2[pos_2][:, :3]
-    rot, tran = rmsd_calculations.svd_superimpose(common_coords_1, common_coords_2)
-    common_coords_2 = rmsd_calculations.apply_rotran(common_coords_2, rot, tran)
-    return (
-        rmsd_calculations.get_caretta_score(
-            common_coords_1, common_coords_2, gamma, False
-        ),
-        pos_1,
-        pos_2,
-    )
-
-
-@nb.njit
-# @numba_cc.export('get_pairwise_alignment', '(f64[:], f64[:], i8[:], i8[:], f64, f64, f64, f64, f64, i64)')
-def get_pairwise_alignment(
-    coords_1,
-    coords_2,
-    secondary_1,
-    secondary_2,
-    gamma,
-    gap_open_sec,
-    gap_extend_sec,
-    gap_open_penalty,
-    gap_extend_penalty,
-    max_iter=3,
-):
-    rmsd_1, pos_1_1, pos_2_1 = get_dtw_signal_score_pos(
-        coords_1, coords_2, gamma, 0, gap_open_penalty, gap_extend_penalty
-    )
-    rmsd_2, pos_1_2, pos_2_2 = get_secondary_rmsd_pos(
-        secondary_1,
-        secondary_2,
-        coords_1[:, :3],
-        coords_2[:, :3],
-        gamma,
-        gap_open_sec,
-        gap_extend_sec,
-    )
-    rmsd_3, pos_1_3, pos_2_3 = get_dtw_signal_score_pos(
-        coords_1, coords_2, gamma, -1, gap_open_penalty, gap_extend_penalty
-    )
-    if rmsd_1 > rmsd_2:
-        if rmsd_3 > rmsd_1:
-            pos_1, pos_2 = pos_1_3, pos_2_3
-        else:
-            pos_1, pos_2 = pos_1_1, pos_2_1
-    else:
-        if rmsd_3 > rmsd_2:
-            pos_1, pos_2 = pos_1_3, pos_2_3
-        else:
-            pos_1, pos_2 = pos_1_2, pos_2_2
-    common_coords_1, common_coords_2 = coords_1[pos_1][:, :3], coords_2[pos_2][:, :3]
-    coords_1[:, :3], coords_2[:, :3], _ = rmsd_calculations.superpose_with_pos(
-        coords_1[:, :3], coords_2[:, :3], common_coords_1, common_coords_2
-    )
-    distance_matrix = rmsd_calculations.make_distance_matrix(
-        coords_1, coords_2, gamma, normalized=False
-    )
-    dtw_aln_array_1, dtw_aln_array_2, score = dtw.dtw_align(
-        distance_matrix, gap_open_penalty, gap_extend_penalty
-    )
-    for i in range(max_iter):
-        pos_1, pos_2 = helper.get_common_positions(dtw_aln_array_1, dtw_aln_array_2)
-        common_coords_1, common_coords_2 = (
-            coords_1[pos_1][:, :3],
-            coords_2[pos_2][:, :3],
-        )
-        coords_1[:, :3], coords_2[:, :3], _ = rmsd_calculations.superpose_with_pos(
-            coords_1[:, :3], coords_2[:, :3], common_coords_1, common_coords_2
-        )
-
-        distance_matrix = rmsd_calculations.make_distance_matrix(
-            coords_1, coords_2, gamma, normalized=False
-        )
-        dtw_1, dtw_2, new_score = dtw.dtw_align(
-            distance_matrix, gap_open_penalty, gap_extend_penalty
-        )
-        if int(new_score) > int(score):
-            dtw_aln_array_1, dtw_aln_array_2, score = dtw_1, dtw_2, new_score
-        else:
-            break
-    return dtw_aln_array_1, dtw_aln_array_2, score
diff --git a/caretta/rmsd_calculations.py b/caretta/rmsd_calculations.py
deleted file mode 100644
index f4ce3c63eebff90406e9fbbda797ce739d74462a..0000000000000000000000000000000000000000
--- a/caretta/rmsd_calculations.py
+++ /dev/null
@@ -1,165 +0,0 @@
-import numba as nb
-import numpy as np
-
-
-@nb.njit
-# @numba_cc.export('normalize', 'f64[:](f64[:])')
-def normalize(numbers):
-    minv, maxv = np.min(numbers), np.max(numbers)
-    return (numbers - minv) / (maxv - minv)
-
-
-@nb.njit
-# @numba_cc.export('nb_mean_axis_0', 'f64[:](f64[:])')
-def nb_mean_axis_0(array: np.ndarray) -> np.ndarray:
-    """
-    Same as np.mean(array, axis=0) but njitted
-    """
-    mean_array = np.zeros(array.shape[1])
-    for i in range(array.shape[1]):
-        mean_array[i] = np.mean(array[:, i])
-    return mean_array
-
-
-@nb.njit
-# @numba_cc.export('svd_superimpose', '(f64[:], f64[:])')
-def svd_superimpose(coords_1: np.ndarray, coords_2: np.ndarray):
-    """
-    Superimpose paired coordinates on each other using svd
-
-    Parameters
-    ----------
-    coords_1
-        numpy array of coordinate data for the first protein; shape = (n, 3)
-    coords_2
-        numpy array of corresponding coordinate data for the second protein; shape = (n, 3)
-
-    Returns
-    -------
-    rotation matrix, translation matrix for optimal superposition
-    """
-    centroid_1, centroid_2 = nb_mean_axis_0(coords_1), nb_mean_axis_0(coords_2)
-    coords_1_c, coords_2_c = coords_1 - centroid_1, coords_2 - centroid_2
-    correlation_matrix = np.dot(coords_2_c.T, coords_1_c)
-    u, s, v = np.linalg.svd(correlation_matrix)
-    reflect = np.linalg.det(u) * np.linalg.det(v) < 0
-    if reflect:
-        s[-1] = -s[-1]
-        u[:, -1] = -u[:, -1]
-    rotation_matrix = np.dot(u, v)
-    translation_matrix = centroid_1 - np.dot(centroid_2, rotation_matrix)
-    return rotation_matrix.astype(np.float64), translation_matrix.astype(np.float64)
-
-
-@nb.njit
-# @numba_cc.export('apply_rotran', '(f64[:], f64[:], f64[:])')
-def apply_rotran(
-    coords: np.ndarray, rotation_matrix: np.ndarray, translation_matrix: np.ndarray
-) -> np.ndarray:
-    """
-    Applies a rotation and translation matrix onto coordinates
-
-    Parameters
-    ----------
-    coords
-    rotation_matrix
-    translation_matrix
-
-    Returns
-    -------
-    transformed coordinates
-    """
-    return np.dot(coords, rotation_matrix) + translation_matrix
-
-
-# @numba_cc.export('superpose_with_pos', '(f64[:], f64[:], f64[:], f64[:])')
-@nb.njit
-def superpose_with_pos(coords_1, coords_2, common_coords_1, common_coords_2):
-    """
-    Superpose two sets of un-aligned coordinates using smaller subsets of aligned coordinates
-
-    Parameters
-    ----------
-    coords_1
-    coords_2
-    common_coords_1
-    common_coords_2
-
-    Returns
-    -------
-    superposed coord_1, superposed coords_2, superposed common_coords_2
-    """
-    rot, tran = svd_superimpose(common_coords_1, common_coords_2)
-    coords_1 = coords_1 - nb_mean_axis_0(common_coords_1)
-    coords_2 = np.dot(coords_2 - nb_mean_axis_0(common_coords_2), rot)
-    common_coords_2_rot = apply_rotran(common_coords_2, rot, tran)
-    return coords_1, coords_2, common_coords_2_rot
-
-
-@nb.njit
-# @numba_cc.export('make_distance_matrix', '(f64[:], f64[:], f64, b1)')
-def make_distance_matrix(
-    coords_1: np.ndarray, coords_2: np.ndarray, gamma, normalized=False
-) -> np.ndarray:
-    """
-    Makes matrix of euclidean distances of each coordinate in coords_1 to each coordinate in coords_2
-    TODO: probably faster to do upper triangle += transpose
-    Parameters
-    ----------
-    coords_1
-        shape = (n, 3)
-    coords_2
-        shape = (m, 3)
-    gamma
-    normalized
-    Returns
-    -------
-    matrix; shape = (n, m)
-    """
-    distance_matrix = np.zeros((coords_1.shape[0], coords_2.shape[0]))
-    for i in range(coords_1.shape[0]):
-        for j in range(coords_2.shape[0]):
-            distance_matrix[i, j] = np.exp(
-                -gamma * np.sum((coords_1[i] - coords_2[j]) ** 2, axis=-1)
-            )
-    if normalized:
-        return normalize(distance_matrix)
-    else:
-        return distance_matrix
-
-
-@nb.njit
-# @numba_cc.export('get_rmsd', '(f64[:], f64[:])')
-def get_rmsd(coords_1: np.ndarray, coords_2: np.ndarray) -> float:
-    """
-    RMSD of paired coordinates = normalized square-root of sum of squares of euclidean distances
-    """
-    return np.sqrt(np.sum((coords_1 - coords_2) ** 2) / coords_1.shape[0])
-
-
-@nb.njit
-# @numba_cc.export('get_caretta_score', '(f64[:], f64[:], f64, b1)')
-def get_caretta_score(
-    coords_1: np.ndarray, coords_2: np.ndarray, gamma, normalized
-) -> float:
-    """
-    Get caretta score for a a set of paired coordinates
-
-    Parameters
-    ----------
-    coords_1
-    coords_2
-    gamma
-    normalized
-
-    Returns
-    -------
-    Caretta score
-    """
-    score = 0
-    for i in range(coords_1.shape[0]):
-        score += np.exp(-gamma * np.sum((coords_1[i] - coords_2[i]) ** 2, axis=-1))
-    if normalized:
-        return score / coords_1.shape[0]
-    else:
-        return score
diff --git a/caretta/score_functions.py b/caretta/score_functions.py
index 1029055c8d7768dc6fc8c9208e0c4814cf1b7d90..65d4f15865262e86f35015127aad6a818fe77924 100644
--- a/caretta/score_functions.py
+++ b/caretta/score_functions.py
@@ -1,17 +1,6 @@
 import numba as nb
 import numpy as np
-
-
-@nb.njit
-def nan_normalize(numbers):
-    minv, maxv = np.nanmin(numbers), np.nanmax(numbers)
-    return (numbers - minv) / (maxv - minv)
-
-
-@nb.njit
-def normalize(numbers):
-    minv, maxv = np.min(numbers), np.max(numbers)
-    return (numbers - minv) / (maxv - minv)
+from caretta import helper
 
 
 @nb.njit
@@ -38,33 +27,6 @@ def get_rmsd(coords_1: np.ndarray, coords_2: np.ndarray) -> float:
     return np.sqrt(np.sum((coords_1 - coords_2) ** 2) / coords_1.shape[0])
 
 
-def make_score_matrix_python(
-    coords_1: np.ndarray, coords_2: np.ndarray, score_function, normalized=False
-) -> np.ndarray:
-    """
-    Makes matrix of scores of each coordinate in coords_1 to each coordinate in coords_2
-    Parameters
-    ----------
-    coords_1
-        shape = (n, 3)
-    coords_2
-        shape = (m, 3)
-    score_function
-    normalized
-    Returns
-    -------
-    matrix; shape = (n, m)
-    """
-    score_matrix = np.zeros((coords_1.shape[0], coords_2.shape[0]))
-    for i in range(coords_1.shape[0]):
-        for j in range(coords_2.shape[0]):
-            score_matrix[i, j] = score_function(coords_1[i], coords_2[j])
-    if normalized:
-        return normalize(score_matrix)
-    else:
-        return score_matrix
-
-
 @nb.njit
 def make_score_matrix(
     coords_1: np.ndarray, coords_2: np.ndarray, score_function, normalized=False
@@ -88,7 +50,7 @@ def make_score_matrix(
         for j in range(coords_2.shape[0]):
             score_matrix[i, j] = score_function(coords_1[i], coords_2[j])
     if normalized:
-        return normalize(score_matrix)
+        return helper.normalize(score_matrix)
     else:
         return score_matrix
 
@@ -119,38 +81,3 @@ def get_total_score(
         return score / coords_1.shape[0]
     else:
         return score
-
-
-@nb.njit
-# @numba_cc.export('get_common_positions', '(i64[:], i64[:], i64)')
-def get_common_positions(aln_array_1, aln_array_2, gap=-1):
-    """
-    Return positions where neither alignment has a gap
-
-    Parameters
-    ----------
-    aln_array_1
-    aln_array_2
-    gap
-
-    Returns
-    -------
-    common_positions_1, common_positions_2
-    """
-    pos_1 = np.array(
-        [
-            aln_array_1[i]
-            for i in range(len(aln_array_1))
-            if aln_array_1[i] != gap and aln_array_2[i] != gap
-        ],
-        dtype=np.int64,
-    )
-    pos_2 = np.array(
-        [
-            aln_array_2[i]
-            for i in range(len(aln_array_2))
-            if aln_array_1[i] != gap and aln_array_2[i] != gap
-        ],
-        dtype=np.int64,
-    )
-    return pos_1, pos_2
diff --git a/caretta/superposition_functions.py b/caretta/superposition_functions.py
index 08e999a7a32c9753c512165f09bf35fe0c1c4cab..1f31c2050d32852b4f3ea79f21209e9b702332f7 100644
--- a/caretta/superposition_functions.py
+++ b/caretta/superposition_functions.py
@@ -1,16 +1,12 @@
 import numba as nb
 import numpy as np
-from caretta import dynamic_time_warping as dtw, score_functions
-from geometricus import geometricus
-from geometricus.moment_utility import nb_mean_axis_0
-
-GAP_OPEN = 0
-GAP_EXTEND = 0
+from geometricus import MomentInvariants, SplitType
+from caretta import dynamic_time_warping as dtw, score_functions, helper
 
 """
 Provides pairwise superposition functions to use in Caretta
 Each takes coords_1, coords_2, parameters, score_function as input
-    parameters is a dict with gap_open_penalty, gap_extend_penalty, and other specific parameters as keys
+    parameters is a dict with gap_open_penalty, gap_extend_penalty, and other function-specific parameters as keys
 
 returns score, superposed_coords_1, superposed_coords_2
 """
@@ -50,7 +46,7 @@ def signal_superpose_function(
 ):
     """
     Makes initial superposition of coordinates using DTW alignment of overlapping signals
-    A signal is a vector of euclidean distances of first (or last) coordinate to all others in a 30-residue stretch
+    A signal is a vector of euclidean distances of first (or last) coordinate to all others in a "size"-residue stretch
     """
     score_first, c1_first, c2_first = _signal_superpose_index(
         0,
@@ -59,6 +55,7 @@ def signal_superpose_function(
         score_function,
         parameters["gap_open_penalty"],
         parameters["gap_extend_penalty"],
+        size=parameters["size"],
     )
     score_last, c1_last, c2_last = _signal_superpose_index(
         -1,
@@ -67,6 +64,7 @@ def signal_superpose_function(
         score_function,
         parameters["gap_open_penalty"],
         parameters["gap_extend_penalty"],
+        size=parameters["size"],
     )
     if score_first > score_last:
         return score_first, c1_first, c2_first
@@ -88,19 +86,19 @@ def moment_superpose_function(
     coords_1, coords_2, parameters, score_function=score_functions.get_caretta_score
 ):
     """
-    Uses 4 rotation/translation invariant moments for each 5-mer to run DTW
+    Uses 4 rotation/translation invariant moments for each "split_size"-mer to run DTW
     """
-    moments_1 = geometricus.MomentInvariants.from_coordinates(
+    moments_1 = MomentInvariants.from_coordinates(
         "name",
         coords_1,
-        split_type=geometricus.SplitType[parameters["split_type"]],
+        split_type=SplitType[parameters["split_type"]],
         split_size=parameters["split_size"],
         upsample_rate=parameters["upsample_rate"],
     ).moments
-    moments_2 = geometricus.MomentInvariants.from_coordinates(
+    moments_2 = MomentInvariants.from_coordinates(
         "name",
         coords_2,
-        split_type=geometricus.SplitType[parameters["split_type"]],
+        split_type=SplitType[parameters["split_type"]],
         split_size=parameters["split_size"],
         upsample_rate=parameters["upsample_rate"],
     ).moments
@@ -124,24 +122,24 @@ def moment_multiple_superpose_function(
     coords_1, coords_2, parameters, score_function=score_functions.get_caretta_score
 ):
     """
-    Uses 4 rotation/translation invariant moments for each 5-mer to run DTW
+    Uses 4 rotation/translation invariant moments for each "split_size"-mer with different fragmentation approaches to run DTW
     """
     moments_1 = []
     moments_2 = []
     for i in range(parameters["num_split_types"]):
         if "upsample_rate" not in parameters:
             parameters["upsample_rate"] = 10
-        moments_1_1 = geometricus.MomentInvariants.from_coordinates(
+        moments_1_1 = MomentInvariants.from_coordinates(
             "name",
             coords_1,
-            split_type=geometricus.SplitType[parameters[f"split_type_{i}"]],
+            split_type=SplitType[parameters[f"split_type_{i}"]],
             split_size=parameters[f"split_size_{i}"],
             upsample_rate=parameters["upsample_rate"],
         ).moments
-        moments_2_1 = geometricus.MomentInvariants.from_coordinates(
+        moments_2_1 = MomentInvariants.from_coordinates(
             "name",
             coords_2,
-            split_type=geometricus.SplitType[parameters[f"split_type_{i}"]],
+            split_type=SplitType[parameters[f"split_type_{i}"]],
             split_size=parameters[f"split_size_{i}"],
             upsample_rate=parameters["upsample_rate"],
         ).moments
@@ -169,7 +167,7 @@ def moment_multiple_svd_superpose_function(
     coords_1, coords_2, parameters, score_function=score_functions.get_caretta_score
 ):
     """
-    Uses moment_superpose followed by dtw_svd_superpose
+    Uses moment_multiple_superpose followed by dtw_svd_superpose
     """
     _, coords_1, coords_2 = moment_multiple_superpose_function(
         coords_1, coords_2, parameters
@@ -197,9 +195,7 @@ def _align_and_superpose(
     dtw_aln_array_1, dtw_aln_array_2, score = dtw.dtw_align(
         score_matrix, gap_open_penalty, gap_extend_penalty
     )
-    pos_1, pos_2 = score_functions.get_common_positions(
-        dtw_aln_array_1, dtw_aln_array_2
-    )
+    pos_1, pos_2 = helper.get_common_positions(dtw_aln_array_1, dtw_aln_array_2)
     common_coords_1, common_coords_2 = coords_1[pos_1], coords_2[pos_2]
     coords_1, coords_2, common_coords_2 = paired_svd_superpose_with_subset(
         coords_1, coords_2, common_coords_1, common_coords_2
@@ -223,7 +219,10 @@ def paired_svd_superpose(coords_1: np.ndarray, coords_2: np.ndarray):
     -------
     rotation matrix, translation matrix for optimal superposition
     """
-    centroid_1, centroid_2 = nb_mean_axis_0(coords_1), nb_mean_axis_0(coords_2)
+    centroid_1, centroid_2 = (
+        helper.nb_mean_axis_0(coords_1),
+        helper.nb_mean_axis_0(coords_2),
+    )
     coords_1_c, coords_2_c = coords_1 - centroid_1, coords_2 - centroid_2
     correlation_matrix = np.dot(coords_2_c.T, coords_1_c)
     u, s, v = np.linalg.svd(correlation_matrix)
@@ -255,8 +254,8 @@ def paired_svd_superpose_with_subset(
     superposed coord_1, superposed coords_2, superposed common_coords_2
     """
     rot, tran = paired_svd_superpose(common_coords_1, common_coords_2)
-    coords_1 = coords_1 - nb_mean_axis_0(common_coords_1)
-    coords_2 = np.dot(coords_2 - nb_mean_axis_0(common_coords_2), rot)
+    coords_1 = coords_1 - helper.nb_mean_axis_0(common_coords_1)
+    coords_2 = np.dot(coords_2 - helper.nb_mean_axis_0(common_coords_2), rot)
     common_coords_2_rot = apply_rotran(common_coords_2, rot, tran)
     return coords_1, coords_2, common_coords_2_rot
 
@@ -301,7 +300,7 @@ def _signal_superpose_index(
     dtw_1, dtw_2, score = dtw.dtw_align(
         score_matrix, gap_open_penalty, gap_extend_penalty
     )
-    pos_1, pos_2 = score_functions.get_common_positions(dtw_1, dtw_2)
+    pos_1, pos_2 = helper.get_common_positions(dtw_1, dtw_2)
     aln_coords_1 = np.zeros((len(pos_1), coords_1.shape[1]))
     aln_coords_2 = np.zeros((len(pos_2), coords_2.shape[1]))
     for i, (p1, p2) in enumerate(zip(pos_1, pos_2)):
@@ -314,7 +313,6 @@ def _signal_superpose_index(
 
 
 @nb.njit
-# @numba_cc.export('svd_superimpose', '(f64[:], f64[:])')
 def svd_superimpose(coords_1: np.ndarray, coords_2: np.ndarray):
     """
     Superimpose paired coordinates on each other using svd
@@ -330,7 +328,10 @@ def svd_superimpose(coords_1: np.ndarray, coords_2: np.ndarray):
     -------
     rotation matrix, translation matrix for optimal superposition
     """
-    centroid_1, centroid_2 = nb_mean_axis_0(coords_1), nb_mean_axis_0(coords_2)
+    centroid_1, centroid_2 = (
+        helper.nb_mean_axis_0(coords_1),
+        helper.nb_mean_axis_0(coords_2),
+    )
     coords_1_c, coords_2_c = coords_1 - centroid_1, coords_2 - centroid_2
     correlation_matrix = np.dot(coords_2_c.T, coords_1_c)
     u, s, v = np.linalg.svd(correlation_matrix)
@@ -344,7 +345,6 @@ def svd_superimpose(coords_1: np.ndarray, coords_2: np.ndarray):
 
 
 @nb.njit
-# @numba_cc.export('apply_rotran', '(f64[:], f64[:], f64[:])')
 def apply_rotran(
     coords: np.ndarray, rotation_matrix: np.ndarray, translation_matrix: np.ndarray
 ) -> np.ndarray:
@@ -364,7 +364,6 @@ def apply_rotran(
     return np.dot(coords, rotation_matrix) + translation_matrix
 
 
-# @numba_cc.export('superpose_with_pos', '(f64[:], f64[:], f64[:], f64[:])')
 @nb.njit
 def superpose_with_pos(coords_1, coords_2, common_coords_1, common_coords_2):
     """
@@ -382,7 +381,7 @@ def superpose_with_pos(coords_1, coords_2, common_coords_1, common_coords_2):
     superposed coord_1, superposed coords_2, superposed common_coords_2
     """
     rot, tran = svd_superimpose(common_coords_1, common_coords_2)
-    coords_1 = coords_1 - nb_mean_axis_0(common_coords_1)
-    coords_2 = np.dot(coords_2 - nb_mean_axis_0(common_coords_2), rot)
+    coords_1 = coords_1 - helper.nb_mean_axis_0(common_coords_1)
+    coords_2 = np.dot(coords_2 - helper.nb_mean_axis_0(common_coords_2), rot)
     common_coords_2_rot = apply_rotran(common_coords_2, rot, tran)
     return coords_1, coords_2, common_coords_2_rot
diff --git a/setup.py b/setup.py
index 1b9c9d23e1de36ff50834658b44a122088262d7b..942d223e90e319bde1445fc22bc607f0f7ea809d 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@ setup(
         "scipy",
         "prody",
         "biopython",
-        "fire",
+        "typer",
         "pyparsing",
     ],
     extras_require={
@@ -22,7 +22,6 @@ setup(
             "dash-core-components==1.2.1",
             "dash-html-components==1.0.1",
             "dash-renderer==1.1.0",
-            "dash-table==4.3.0",
             "plotly==3.7.1",
             "flask",
         ]