from clearml.automation.controller import PipelineDecorator
from clearml import TaskTypes

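# Every function decorated with PipelineDecorator.component is executed as its own ClearML Task.
# cache=True lets ClearML reuse a previous execution of the step when its code and inputs are unchanged,
# and `packages` points at the repo's requirements file used to build the step's environment.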
@PipelineDecorator.component(
        return_values=['dataset_name_training', 'dataset_name_test', 'dataset_project'],
        cache=True,
        repo="https://git.wur.nl/mdt-research-it-solutions/clearml-demo.git",
        task_type=TaskTypes.data_processing,
        packages="./requirements.txt"
        )
def step_one(training_path: str = 'data/mnist_png/training',
            test_path: str = 'data/mnist_png/testing',
            dataset_project: str = "pipeline",
            dataset_name_training: str = "training_dataset",
            dataset_name_test: str = "testing_dataset"):

    print('step_one')
    # import inside the component so the step is self-contained when it is executed remotely
    from clearml import Dataset

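    # create two versioned ClearML Datasets, one for the training images and one for the test images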
    dataset_train = Dataset.create(
        dataset_name=dataset_name_training, dataset_project=dataset_project
    )
    dataset_test = Dataset.create(
        dataset_name=dataset_name_test, dataset_project=dataset_project
    )

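    # register the local folders, upload their contents to the configured storage, and finalize (lock) the dataset versions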
    dataset_train.add_files(path=training_path)
    dataset_test.add_files(path=test_path)
    dataset_train.upload()
    dataset_test.upload()
    dataset_train.finalize()
    dataset_test.finalize()

    return dataset_name_training, dataset_name_test, dataset_project

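# The training step: pulls local copies of the datasets produced by step_one, trains a CNN on them,
# and returns the trained model (ClearML stores the return value as an artifact of the step's task).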
@PipelineDecorator.component(
        return_values=['model'],
        cache=True,
        repo="https://git.wur.nl/mdt-research-it-solutions/clearml-demo.git",
        task_type=TaskTypes.training,
        packages="./requirements.txt"
    )
def step_two(dataset_name_training,
            dataset_name_test,
            dataset_project,
            epochs: int = 10,
            train_batch_size: int = 256,
            validation_batch_size: int = 256,
            train_num_workers: int = 0,
            validation_num_workers: int = 0,
            resize: int = 28,
            lr: float = 1e-3
        ):

    print('step_two')
    # imports are done inside the component so the step is self-contained when it is executed remotely
    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    from torch.utils.data import DataLoader
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from model import CNNModel
    from model_utils import train, validate
    from clearml import Logger
    from clearml import Dataset

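    # fetch local (cached) copies of the training and test datasets created in step_one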
    mnist_train = Dataset.get(
        dataset_name=dataset_name_training, dataset_project=dataset_project
    ).get_local_copy()
    mnist_test = Dataset.get(
        dataset_name=dataset_name_test, dataset_project=dataset_project
    ).get_local_copy()

    # get logger
    logger = Logger.current_logger()

    # the training transforms
    train_transform = transforms.Compose([
        transforms.Resize(resize),
        #transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomVerticalFlip(p=0.5),
        #transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        #transforms.RandomRotation(degrees=(30, 70)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5]
        )
    ])
    # the validation transforms
    valid_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5]
        )
    ])

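    # ImageFolder expects one sub-directory per class label (here, one folder per digit)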
    # training dataset
    train_dataset = datasets.ImageFolder(
        root=mnist_train,
        transform=train_transform
    )
    # validation dataset
    valid_dataset = datasets.ImageFolder(
        root=mnist_test,
        transform=valid_transform
    )
    # training data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=train_batch_size, shuffle=True,
        num_workers=train_num_workers, pin_memory=True
    )
    # validation data loaders
    valid_loader = DataLoader(
        valid_dataset, batch_size=validation_batch_size, shuffle=False,
        num_workers=validation_num_workers, pin_memory=True
    )

    device = ('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Computation device: {device}\n")

    model = CNNModel().to(device)
    print(model)

    # total parameters and trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")

    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # loss function
    criterion = nn.CrossEntropyLoss()

    # lists to keep track of losses and accuracies
    train_loss, valid_loss = [], []
    train_acc, valid_acc = [], []

    # start the training
    for epoch in range(epochs):
        print(f"[INFO]: Epoch {epoch+1} of {epochs}")
        train_epoch_loss, train_epoch_acc = train(model, train_loader,
                                                  optimizer, criterion, device)
        valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
                                                     criterion, device)
        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)
        train_acc.append(train_epoch_acc)
        valid_acc.append(valid_epoch_acc)
        print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
        logger.report_scalar(
                "loss", "train", iteration=epoch, value=train_epoch_loss
            )
        logger.report_scalar(
                "accuracy", "train", iteration=epoch, value=train_epoch_acc
            )
        print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
        logger.report_scalar(
                "loss", "validation", iteration=epoch, value=valid_epoch_loss
            )
        logger.report_scalar(
                "accuracy", "validation", iteration=epoch, value=valid_epoch_acc
            )
    return model

# The actual pipeline execution context.
# Notice that all pipeline component function calls are executed remotely;
# the pipeline logic only waits for a component to complete when its return value is actually used.
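# The pipeline controller itself is enqueued on the 'test' queue (pipeline_execution_queue);
# the individual steps run on the default execution queue set in __main__ unless overridden per component.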
@PipelineDecorator.pipeline(
        name='pipeline test',
        pipeline_execution_queue="test",
        repo="https://git.wur.nl/mdt-research-it-solutions/clearml-demo.git",
        project='pipeline_deco',
        version='0.0.5',
        add_pipeline_tags=True
    )
def executing_pipeline(
        training_path='data/mnist_png/training',
        test_path='data/mnist_png/testing'
    ):
    from clearml import OutputModel
    import torch

    # Use the pipeline arguments to start the pipeline and pass them to the first step
    print('launch step one')

    dataset_name_training, dataset_name_test, dataset_project = step_one(
                                            training_path=training_path,
                                            test_path=test_path,
                                            dataset_project="pipeline",
                                            dataset_name_training="training_dataset",
                                            dataset_name_test="testing_dataset"
                                        )
    # Use the values returned from the first step (`step_one`) and pass them to the next step (`step_two`).
    # Notice! Unless we actually access a returned object, the pipeline logic does not load the
    # underlying artifact itself. When a returned value is passed into a new step, the pipeline
    # waits for the producing step/function (`step_one`) to complete its execution first.
    print('launch step two')

    model = step_two(
                    dataset_name_training=dataset_name_training,
                    dataset_name_test=dataset_name_test,
                    dataset_project=dataset_project,
                    epochs=10,
                    train_batch_size=256,
                    validation_batch_size=256,
                    train_num_workers=0,
                    validation_num_workers=0,
                    resize=28,
                    lr=1e-3
                )

    # store the model as TorchScript so it can be loaded into Triton without needing the Python model class
    torch.jit.script(model).save('serving_model.pt')
    OutputModel().update_weights('serving_model.pt')
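    # For reference, a minimal sketch of how the exported TorchScript file could be loaded later for
    # inference (assuming a local copy of 'serving_model.pt' and an input tensor `x` shaped like the
    # training images):
    #   loaded = torch.jit.load('serving_model.pt')
    #   loaded.eval()
    #   with torch.no_grad():
    #       prediction = loaded(x).argmax(dim=1)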

if __name__ == '__main__':
    # set the pipeline steps default execution queue (per specific step we can override it with the decorator)
    PipelineDecorator.set_default_execution_queue('test')
    # Run the pipeline steps as subprocesses on the current machine, great for local executions
    # (for easy development / debugging, use `PipelineDecorator.debug_pipeline()` to execute steps as regular functions)
    #PipelineDecorator.run_locally()
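    # or, to step through every component as a plain local function call:
    #PipelineDecorator.debug_pipeline()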

    # Start the pipeline execution logic.
    executing_pipeline(
        training_path='data/mnist_png/training',
        test_path='data/mnist_png/testing'
    )

    print('process completed')