# tracker2d.py: track 2D blob coordinates in image sequences
# (OpenCV background subtraction + SimpleBlobDetector).
import glob
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional

import cv2
import numpy as np
from PIL import Image


@dataclass
class BackgroundSubtractorSettings:
    """
    Parameters used by Tracker2D.get_back_subtractor() to build the OpenCV
    background subtractor.
    """

    type: str = 'KNN'  # 'KNN' or 'MOG2'
    varThreshold: int = 12  # used only when type == 'MOG2'
    dist2Threshold: int = 50  # used only when type == 'KNN'
    shadow_threshold: int = 1  # used only when type == 'KNN' (passed to setShadowThreshold)


@dataclass
class BlobDetectorSettings:
    """
    Parameters for the OpenCV SimpleBlobDetector built by Tracker2D.get_blob_detector().

    The field order defines the positional constructor signature.
    """

    # Threshold range handed to minThreshold / maxThreshold
    threshold_min: int = 1
    threshold_max: int = 255
    # Accepted blob area range (minArea / maxArea)
    area_min: int = 200
    area_max: int = 10000
    # Accepted inertia ratio range (minInertiaRatio / maxInertiaRatio)
    inertia_ratio_min: float = 0.0
    inertia_ratio_max: float = 1.0
    # NOTE(review): the two fields below are not read by get_blob_detector(); confirm where they are used
    aspect_ratio_min: float = 0.8
    eucli_dist_max: float = 5.0

class Tracker2D:
    def __init__(self):
        """
        Class to track 2d coordinates of blobs in images.

        Typical use: load_images() -> do_tracking() -> save_csv(), and optionally
        gen_stroboscopic_image() -> save_stroboscopic_image().
        """
        self.show_plot = False

        self.images, self.frames = [], []

        # Tracked blobs: parallel lists under the keys 'x', 'y', 'frame' and 'area'
        self.points = defaultdict(list)

        self.back_subtractor_settings = BackgroundSubtractorSettings()
        self.blob_detector_settings = BlobDetectorSettings()

        # do_tracking() and gen_stroboscopic_image() both use self.back_subtractor,
        # so it must be created here alongside the blob detector.
        self.get_back_subtractor()
        self.get_blob_detector()

    def load_images(self, image_names: List[str], path: str, frames: Optional[List[int]] = None) -> None:
        """
        Load images as grayscale and store them sorted by frame number.

        Loading a new image set also recreates the background subtractor, so a
        background model trained on a previous sequence is not reused.

        Args:
            image_names: a list of image file names
            path: the directory where the images are
            frames: frame number of each image, in the same order as image_names
                (defaults to 1, 2, ..., len(image_names))

        Raises:
            ValueError: if frames does not have exactly one entry per image name.
            FileNotFoundError: if OpenCV cannot read one of the images.
        """
        if frames is not None and len(frames) != len(image_names):
            raise ValueError("Got {0} frame numbers for {1} images".format(len(frames), len(image_names)))

        images = []
        for image_name in image_names:
            image_path = os.path.join(path, image_name)
            image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if image is None:  # cv2.imread reports failure by returning None instead of raising
                raise FileNotFoundError("Could not read image: {0}".format(image_path))
            images.append(image)

        if frames is None:
            frames = list(range(1, len(images) + 1))

        in_sort = np.argsort(frames)
        self.images = [images[in_s] for in_s in in_sort]
        self.frames = [frames[in_s] for in_s in in_sort]

        self.get_back_subtractor()

    def do_tracking(self):
        """
        Detect blobs in every loaded image and store their coordinates in self.points.

        The background subtractor is first fed all images except the last 6. Each image
        is then passed through it again (learningRate=-1), and the resulting foreground
        mask is median-filtered and searched for blobs.

        Sets self.background, self.foreground_masks (one mask per image) and
        self.points (keys 'x', 'y', 'frame', 'area'; one entry per detected blob,
        coordinates in pixels with zero on the top-left corner).
        """
        assert len(self.images) != 0, "Images need to be loaded first"

        for image in self.images[:-6]:
            self.back_subtractor.apply(image, learningRate=-1)

        self.background = cv2.medianBlur(self.back_subtractor.getBackgroundImage(), 5)  # Median filter

        self.points = defaultdict(list)
        self.foreground_masks = []
        for image, frame in zip(self.images, self.frames):
            foreground_mask = cv2.medianBlur(self.back_subtractor.apply(image, learningRate=-1), 3)
            self.foreground_masks.append(foreground_mask)

            # Detect on the inverted mask so that foreground pixels become dark
            # (SimpleBlobDetector's default colour filter looks for dark blobs).
            for blob in self.blob_detector.detect(255 - foreground_mask):
                x, y = blob.pt  # (width, height) with zero on top-left corner
                self.points['x'].append(x)
                self.points['y'].append(y)
                self.points['frame'].append(frame)
                # NOTE: KeyPoint.size is a diameter in OpenCV, so this is a size proxy, not a pixel count
                self.points['area'].append(blob.size ** 2)

    @staticmethod
    def _region_mask(image_shape, x, y, radius, shape):
        """Return a uint8 mask that is 255 inside the `shape` region around (x, y) and 0 elsewhere."""
        height, width = image_shape[:2]
        mask = np.zeros(image_shape, np.uint8)
        if 'rectangle' in shape or 'square' in shape:
            mask = cv2.rectangle(mask, (x - radius, y - radius), (x + radius, y + radius), (255), -1)
        elif 'circle' in shape:
            mask = cv2.circle(mask, (x, y), radius, (255), -1)
        elif 'strip' in shape:
            if 'height' in shape:  # vertical strip spanning the full image height
                mask = cv2.rectangle(mask, (x - radius, 0), (x + radius, height), (255), -1)
            if 'width' in shape:  # horizontal strip spanning the full image width
                mask = cv2.rectangle(mask, (0, y - radius), (width, y + radius), (255), -1)
        return mask

    def gen_stroboscopic_image(self, radius, shape='square', step=1, flip_time=False, background_color=-1):
        """
        Build a stroboscopic image in self.strobe_img (uint8).

        For every `step`-th image, the region around the first tracked point of that frame
        is blended onto a common background, weighted by the image's foreground mask.

        Args:
            radius: half-size in pixels of the pasted region.
            shape: region shape, matched by substring: 'square' or 'rectangle', 'circle',
                or 'strip' combined with 'height' (vertical strip) and/or 'width'
                (horizontal strip).
            step: use every step-th image.
            flip_time: if True, iterate from the last image to the first.
            background_color: if within [0, 1], use a uniform background of grey level
                255 * background_color; otherwise use the median-filtered background model.

        Prints a warning and returns early when no images or no tracked points are available.
        """
        if len(self.images) == 0:
            print("WARN: you need to load images first")
            return

        if not self.points.get('frame'):
            print("WARN: you need to track or load 2d points first")
            return

        if not any(key in shape for key in ('rectangle', 'square', 'circle', 'strip')):
            print('WARN: cannot recognize shape (', shape, ')!!')

        for image in self.images[:-6]:
            self.back_subtractor.apply(image, learningRate=-1)

        if flip_time:  # Will start analysing last frame
            images, frames = list(reversed(self.images)), list(reversed(self.frames))
        else:
            images, frames = list(self.images), list(self.frames)

        self.background = cv2.medianBlur(self.back_subtractor.getBackgroundImage(), 5)  # Median filter

        if 0 <= background_color <= 1:
            self.strobe_img = np.ones(images[0].shape, np.uint8) * 255 * background_color
        else:
            self.strobe_img = self.background.copy()

        # An array (not a list) so the element-wise comparison below works for any frame type
        point_frames = np.asarray(self.points['frame'])
        points_x, points_y = self.points['x'], self.points['y']

        # A list keeps float masks float (an ndarray copy would cast them to uint8)
        self.foreground_masks = list(images)
        for i, (image, frame) in enumerate(zip(images, frames)):
            if i % step != 0:
                continue

            # learningRate=0: use the background model without updating it
            foreground_mask = cv2.medianBlur(self.back_subtractor.apply(image, learningRate=0), 3)
            self.foreground_masks[i] = foreground_mask

            matches = np.flatnonzero(point_frames == frame)
            if matches.size == 0:
                continue
            index = matches[0]  # several blobs may share a frame: use the first one

            if np.isnan(points_x[index]):
                continue
            x, y = int(points_x[index]), int(points_y[index])

            region = self._region_mask(image.shape, x, y, radius, shape)
            weight = np.multiply(foreground_mask / 255, region / 255)
            self.foreground_masks[i] = weight
            self.strobe_img = image * weight + self.strobe_img * (1 - weight)

        self.strobe_img = self.strobe_img.astype(np.uint8)

    def get_back_subtractor(self):
        """
        (Re)create self.back_subtractor from self.back_subtractor_settings.

        Raises:
            ValueError: if the settings' type is neither 'MOG2' nor 'KNN'.
        """
        settings = self.back_subtractor_settings
        if settings.type == 'MOG2':
            self.back_subtractor = cv2.createBackgroundSubtractorMOG2(
                history=-1, varThreshold=settings.varThreshold, detectShadows=True)
        elif settings.type == 'KNN':
            self.back_subtractor = cv2.createBackgroundSubtractorKNN(
                history=-1, dist2Threshold=settings.dist2Threshold, detectShadows=True)
            self.back_subtractor.setShadowThreshold(settings.shadow_threshold)  # 0.5 by default
        else:
            raise ValueError("Unknown background subtractor type: {0!r} (expected 'MOG2' or 'KNN')".format(settings.type))

    def get_blob_detector(self):
        """
        Create self.blob_detector (an OpenCV SimpleBlobDetector) from self.blob_detector_settings.

        Area and inertia filters are enabled; circularity and convexity filters are disabled.
        """
        # Setup SimpleBlobDetector parameters.
        params = cv2.SimpleBlobDetector_Params()

        params.minThreshold = self.blob_detector_settings.threshold_min
        params.maxThreshold = self.blob_detector_settings.threshold_max

        params.filterByArea = True
        params.minArea = self.blob_detector_settings.area_min
        params.maxArea = self.blob_detector_settings.area_max

        params.filterByInertia = True
        params.minInertiaRatio = self.blob_detector_settings.inertia_ratio_min
        params.maxInertiaRatio = self.blob_detector_settings.inertia_ratio_max

        params.filterByCircularity = False
        params.filterByConvexity = False

        # OpenCV < 3 exposes a constructor; later versions use a factory function
        ver = cv2.__version__.split('.')
        if int(ver[0]) < 3:
            self.blob_detector = cv2.SimpleBlobDetector(params)
        else:
            self.blob_detector = cv2.SimpleBlobDetector_create(params)

    def save_csv(self, save_name, save_path):
        """
        Save the tracked 2d coordinates to '<save_path>/<save_name>-2d_points.csv'.

        Columns are frame, x_px, y_px (one row per detected blob).
        """
        np.savetxt(os.path.join(save_path, save_name + '-2d_points.csv'),
                   np.c_[self.points['frame'], self.points['x'], self.points['y']], delimiter=',', header='frame,x_px,y_px')

    def save_stroboscopic_image(self, save_name, save_path):
        """Save self.strobe_img (built by gen_stroboscopic_image) to '<save_path>/<save_name>-strobe.tif'."""
        Image.fromarray(self.strobe_img).save(os.path.join(save_path, save_name + '-strobe.tif'))


if __name__ == '__main__':
    image_format = 'tif'
    image_ext = '.{0}'.format(image_format)

    nb_cam = 3
    rec_names = {1: 'cam1_20200303_030117', 2: 'cam2_20200303_030120', 3: 'cam3_20200303_030120'}
    image_paths = {1: '/media/user/MosquitoEscape_Photron1/Photron1/_Process/_DownSized/',
                   2: '/media/user/MosquitoEscape_Photron2/Photron2/_Process/_DownSized/',
                   3: '/media/user/MosquitoEscape_Photron3/Photron3/_Process/_DownSized/'}

    for camn in range(1, nb_cam + 1):
        rec_name = rec_names[camn]
        # Images are read from, and results written to, the recording's own directory
        rec_path = os.path.join(image_paths[camn], rec_name)
        print(rec_path)

        # One tracker per camera so no background-model state carries over between recordings
        tracker = Tracker2D()

        all_image_paths = glob.glob(os.path.join(rec_path, '*' + image_ext))
        image_names = [os.path.basename(image_path) for image_path in all_image_paths]
        # Assumes file names are '<rec_name><frame number>.<image_format>' — TODO confirm
        frames = [int(os.path.splitext(image_name)[0][len(rec_name):]) for image_name in image_names]

        tracker.load_images(image_names, rec_path, frames)
        tracker.do_tracking()

        tracker.save_csv(rec_name, rec_path)