Skip to content
Snippets Groups Projects
Commit dc2808c8 authored by Freek Daniëls's avatar Freek Daniëls
Browse files

Merge branch 'develop-freek' into 'develop'

Update code/4. Data loader.ipynb

See merge request !6
parents 947f57c0 438e6978
No related branches found
No related tags found
2 merge requests!7Develop,!6Update code/4. Data loader.ipynb
%% Cell type:code id: tags:
``` python
from __future__ import print_function
from __future__ import division
import pandas as pd
import numpy as np
import glob
import os
import random
import time
import copy
import json
from dataclasses import dataclass
from PIL import Image
from skimage import io
import matplotlib.pyplot as plt
import torch.utils.data
import torch
import torchvision.transforms.functional as TF
from IPython.display import clear_output
```
%% Cell type:code id: tags:
``` python
%run "0. Config.ipynb"
%run "1. Utils.ipynb"
```
%% Cell type:code id: tags:
``` python
class SeagrassDataset(torch.utils.data.Dataset):
def __init__(self, path_to_images, path_to_masks, path_to_configurations, path_to_files, transforms=False):
self.path_to_images = path_to_images
self.path_to_masks = path_to_masks
self.annotations = []
files = pd.read_csv(path_to_files,usecols=["file"])
for row in files['file']:
configurations = path_to_configurations + row + ".json"
with open(configurations) as json_file:
data = json.load(json_file)
name = data['name']
for configuration in data['configurations']:
annotation = {}
annotation['name'] = name
annotation['configuration'] = configuration
self.annotations.append(annotation)
self.annotations.append(annotation)
self.transforms = transforms
def __getitem__(self, idx):
annotation = self.annotations[idx]
name = annotation['name']
configuration = annotation["configuration"]
configuration_mask = configuration["mask"]
mask_left = configuration_mask['left']
mask_right = configuration_mask['right']
mask_top = configuration_mask['top']
mask_bottom = configuration_mask['bottom']
image = io.imread(self.path_to_images + name)
mask = io.imread(self.path_to_masks + name)
if self.transforms:
image = pad_image(image,image_padding_for_augmentation)
mask = pad_image(mask,mask_padding_for_augmentation)
image_offset = image_padding_for_augmentation
mask_offset = mask_padding_for_augmentation
else:
image = pad_image(image,image_padding)
image_offset = image_padding
mask_offset = 0
image_left = mask_left-image_offset+image_offset
image_top = mask_top-image_offset+image_offset
image_right = mask_right+image_offset+image_offset
image_bottom = mask_bottom+image_offset+image_offset
mask_left = mask_left-mask_offset+mask_offset
mask_top = mask_top-mask_offset+mask_offset
mask_right = mask_right+mask_offset+mask_offset
mask_bottom = mask_bottom+mask_offset+mask_offset
image = Image.fromarray(image).convert("RGB")
mask = Image.fromarray(mask)
if self.transforms:
image = image.crop((image_left,image_top,image_right,image_bottom))
mask = mask.crop((mask_left,mask_top,mask_right,mask_bottom))
if random.random() > 0.5:
image = TF.hflip(image)
mask = TF.hflip(mask)
if random.random() > 0.5:
image = TF.vflip(image)
mask = TF.vflip(mask)
degrees = random.random()*180
image = TF.affine(image,degrees,translate=[0,0], scale=1, shear=0.0)
mask = TF.affine(mask,degrees,translate=[0,0], scale=1, shear=0.0)
image = TF.center_crop(image,image_size)
mask = TF.center_crop(mask,mask_size)
image = TF.to_tensor(image)
image = TF.normalize(image,mean,std)
mask = torch.tensor(np.array(mask)).long()
return image, mask, name
def __len__(self):
return len(self.annotations)
```
%% Cell type:code id: tags:
``` python
datasets = {"train":SeagrassDataset(path_to_images,path_to_masks,path_to_configurations,training_file,True),
"val":SeagrassDataset(path_to_images,path_to_masks,path_to_configurations,validation_file),
"test":SeagrassDataset(path_to_images,path_to_masks,path_to_configurations,test_file)}
```
%% Cell type:code id: tags:
``` python
image,mask,name = next(iter(datasets["train"]))
image = np.transpose(np.asarray(image),axes=[1,2,0])*std+mean
mask = np.asarray(mask)
plt.figure(figsize=(10,5))
plt.subplot(1, 2, 1)
plt.axis('off')
plt.tight_layout()
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.axis('off')
plt.tight_layout()
plt.imshow(get_coloured_mask(mask,labels))
plt.show()
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment