# -*- coding: utf-8 -*-

import logging
import numpy as np
from skimage.io import imread
import matplotlib.pyplot as plt

import vrl

from .polsar import read_polsar, gen_rgb_from_polsar

# Module-wide default random generator with a fixed seed. The dataset classes
# below bind it as the default value of their ``rng`` argument (default
# arguments are evaluated once, when the class bodies are executed at import).
rng = vrl.base.RNG(seed=1234567)


# Logger obtained through vrl under the name 'sartb'; the level is set to INFO.
# NOTE(review): whether set_loglevel affects only this logger or all loggers is
# defined in vrl.utils — confirm there.
logger = vrl.utils.get_logger('sartb')
vrl.utils.set_loglevel(logging.INFO)


class DatasetNames(object):
    """Read-only, instance-level names of the datasets known to ``get_dataset``.

    Every attribute is a property whose value is the string of the same name,
    e.g. ``DatasetNames().Flevoland == 'Flevoland'``.
    """

    SanFrancisco = property(lambda self: 'SanFrancisco')
    Flevoland = property(lambda self: 'Flevoland')
    FlevolandMuellerMatrix = property(lambda self: 'FlevolandMuellerMatrix')
    FoulumS2 = property(lambda self: 'FoulumS2')
    FoulumC3 = property(lambda self: 'FoulumC3')


def get_dataset(base_path, dataset_name, mode, seed, nlooks=None):
    """Build one of the predefined PolSAR datasets.

    Keyword arguments:
    base_path -- root directory; AIRSAR scenes are read from
                 ``<base_path>/AIRSAR/<name>`` and EMISAR scenes from
                 ``<base_path>/EMISAR/<name>``
    dataset_name -- one of 'SanFrancisco', 'Flevoland',
                    'FlevolandMuellerMatrix', 'FoulumS2', 'FoulumC3'
    mode -- forwarded to the dataset constructor
    seed -- seed of the random generator used to draw training samples
    nlooks -- number of looks; only forwarded to the Foulum datasets

    Raises ValueError if ``dataset_name`` is not one of the names above.
    """
    valid_names = ('SanFrancisco', 'Flevoland', 'FlevolandMuellerMatrix', 'FoulumS2', 'FoulumC3')
    # Validate explicitly (an ``assert`` is stripped under ``python -O``) and
    # before creating the RNG, so bad input fails fast.
    if dataset_name not in valid_names:
        raise ValueError("Unknown dataset {!r}; expected one of {}".format(dataset_name, valid_names))

    rng = vrl.base.RNG(seed=seed)
    airsar_dir = '{}/AIRSAR/{}'
    emisar_dir = '{}/EMISAR/{}'

    # Names are compared with ``==``: identity (``is``) checks only match
    # interned string objects, so names built at runtime (e.g. parsed from the
    # command line) would fall through every branch.

    #### San Francisco Bay Datasets ####
    if dataset_name == 'SanFrancisco':
        return SanFranciscoBayLandCover(airsar_dir.format(base_path, dataset_name), mode=mode, rng=rng)

    #### Flevoland Datasets ####
    if dataset_name == 'Flevoland':
        return FlevolandLandCover(airsar_dir.format(base_path, dataset_name), mode=mode, rng=rng)
    if dataset_name == 'FlevolandMuellerMatrix':
        # Reads its data from the same directory as 'Flevoland'.
        return FlevolandMuellerMatrix(airsar_dir.format(base_path, 'Flevoland'), mode=mode, rng=rng)

    #### Foulum Datasets ####
    if dataset_name == 'FoulumS2':
        return FoulumS2LandCover(emisar_dir.format(base_path, dataset_name), mode=mode, rng=rng,
                                 nlooks=nlooks)
    # Only 'FoulumC3' remains after validation.
    return FoulumC3(emisar_dir.format(base_path, dataset_name), mode=mode, rng=rng, nlooks=nlooks)


class PolSARDataset(object):
    """Base class for a labelled PolSAR scene used for pixel classification.

    Subclasses implement ``_read_polsar`` (must set ``self._polsar``, a
    height x width x channels array) and ``_read_mask`` (must set
    ``self._mask``, the colour ground-truth image, and ``self._label``, a
    height x width array of class indices). The subclasses in this module set
    ``self._classes`` before calling ``PolSARDataset.__init__``.
    """

    def __init__(self, path, mode, rng, nlooks=None):
        """Load the scene, optionally multilook it, then load the labels.

        Keyword arguments:
        path -- directory holding this dataset's files
        mode -- stored as ``self._mode`` for the subclass readers
        rng -- random generator used for sampling; called as ``randint(high)``
               (presumably returning an int in [0, high) — confirm in vrl)
        nlooks -- if not None, the scene is filtered by ``_multilook`` with
                  this many looks (must be a positive perfect square)
        """
        self._path = path
        self._mode = mode
        self._overlap = True
        self._nlooks = nlooks
        self._rng = rng
        self._read_polsar()
        if self._nlooks is not None:
            self._multilook()
        self._rgb = None  # computed lazily by the ``rgb`` property

        self._read_mask()
        self._samples = None  # filled by get_samples(), read by eval()

    @property
    def width(self):
        """Number of columns of the scene."""
        return self._polsar.shape[1]

    @property
    def height(self):
        """Number of rows of the scene."""
        return self._polsar.shape[0]

    @property
    def nsamples(self):
        """Number of pixels that count for evaluation (every pixel in the base class)."""
        return self.width * self.height

    @property
    def nclasses(self):
        """Number of classes."""
        return len(self._classes)

    @property
    def polsar(self):
        """The (possibly multilooked) PolSAR data array."""
        return self._polsar

    @property
    def rgb(self):
        """RGB rendering of the scene, generated on first access and then cached."""
        if self._rgb is None:
            self._rgb = gen_rgb_from_polsar(self._polsar)
        return self._rgb

    @property
    def mode(self):
        return self._mode

    @property
    def overlap(self):
        """Whether training boxes drawn by ``get_samples`` may overlap (default True)."""
        return self._overlap

    @overlap.setter
    def overlap(self, value):
        # Only the exact value False disables overlap; any other value enables it.
        self._overlap = value is not False

    @property
    def nlooks(self):
        return self._nlooks

    def eval(self, prediction):
        """Return the pixel accuracy of a 2-D ``prediction`` on non-training pixels.

        Pixels drawn by the last ``get_samples`` call are excluded. The
        denominator is ``nsamples`` minus the number of distinct training
        pixels. If no samples have been drawn, every pixel is evaluated.
        """
        assert prediction.ndim == 2
        # np.bool was removed in NumPy 1.24; the builtin bool is the supported dtype.
        evaluated = np.ones(prediction.shape, dtype=bool)
        if self._samples is not None:
            train = np.asarray(self._samples, dtype=np.intp).reshape(-1, 2)
            evaluated[train[:, 0], train[:, 1]] = False

        ntrain = np.count_nonzero(~evaluated)
        true_predictions = np.count_nonzero(prediction[evaluated] == self._label[evaluated])
        accuracy = float(true_predictions) / float(self.nsamples - ntrain)
        logging.debug("accuracy = {}".format(accuracy))
        return accuracy

    def show_error_mat(self, prediction):
        """Plot and return the boolean map of pixels where ``prediction`` differs from the label."""
        error = (prediction != self._label)
        _, ax = plt.subplots(1, 1)
        ax.matshow(error)
        ax.set_title("Error plot")
        ax.set_xticks(())
        ax.set_yticks(())
        plt.show()

        return error

    def get_samples(self, n=100, size=1):
        """Draw ``n`` random square boxes per class and return their pixels.

        Keyword arguments:
        n -- number of boxes drawn per class
        size -- side length (in pixels) of each box

        Returns a list with one entry per class (in ``self._classes`` order);
        each entry is a list of ``[row, col]`` pixel coordinates. The result is
        also stored as an array so that ``eval`` can exclude these pixels.

        Raises ValueError if a class has no labelled pixel at all (sampling it
        would otherwise loop forever).
        """
        samples = []
        for lbl in self._classes:
            if not np.any(self._label == lbl):
                raise ValueError("No pixel is labelled with class {}".format(lbl))
            if self.overlap is True:
                samples.append(self._get_samples_with_overlap(lbl, n, size))
            else:
                samples.append(self._get_samples_without_overlap(lbl, n, size))

        self._samples = np.array(samples)

        return samples

    def _get_nsamples(self, lbl, n):
        """Return ``n`` distinct random ``[row, col]`` pixels of class ``lbl``."""
        samples = []
        while len(samples) < n:
            sample = self._get_sample(lbl=lbl)
            if self._sample_not_in_samples(sample, samples):
                samples.append(sample)
        return samples

    def _sample_not_in_samples(self, sample, samples):
        """Return True if no entry of ``samples`` has the same row and column as ``sample``."""
        for s in samples:
            if (s[0] == sample[0]) and (s[1] == sample[1]):
                return False
        return True

    def _get_samples_with_overlap(self, lbl, n, size):
        """Draw ``n`` boxes of class ``lbl`` (overlap allowed) and return all their pixels."""
        samples = []
        for _ in range(n):
            samples.extend(self._get_box_samples(lbl=lbl, size=size))
        return samples

    def _get_samples_without_overlap(self, lbl, n, size):
        """Draw ``n`` mutually non-overlapping boxes of class ``lbl``; return their pixels.

        ``100 * n`` candidate boxes are always drawn and greedily filtered. If
        fewer than ``n`` non-overlapping boxes survive, fall back to sampling
        with overlap.
        """
        boxes = []
        # Draw 100 times more candidates than needed, keep the non-overlapping ones.
        for _ in range(n * 100):
            box = self._get_box_samples(lbl=lbl, size=size)
            if self._no_overlap(size, box, boxes):
                boxes.append(box)

        if len(boxes) < n:
            logging.debug("Samples could not be taken without overlap for label {}. "
                          "Getting samples with overlap.".format(lbl))
            return self._get_samples_with_overlap(lbl, n, size)

        return [sample for box in boxes[:n] for sample in box]

    def _no_overlap(self, size, box, box_samples):
        """Return True if ``box`` shares no pixel with any box in ``box_samples``.

        Boxes are pixel lists whose first entry is the top-left ``[row, col]``
        corner, as produced by ``_get_box_samples``.
        """
        if not box_samples:
            return True
        row, col = box[0][0], box[0][1]
        # Inclusive corners, matching the ``+ 1`` area formula of _have_overlap
        # (exclusive ends made merely adjacent boxes count as overlapping).
        bb1 = [row, col, row + size - 1, col + size - 1]
        for other in box_samples:
            # The other box's far corner must come from the other box, not from ``box``.
            o_row, o_col = other[0][0], other[0][1]
            bb2 = [o_row, o_col, o_row + size - 1, o_col + size - 1]
            if self._have_overlap(bb1, bb2):
                return False

        return True

    def _have_overlap(self, bb1, bb2):
        """Return True if two axis-aligned boxes share at least one pixel.

        bb1 -- reference box ``[a1, b1, a2, b2]``
        bb2 -- query box, same layout
        Corner coordinates are inclusive; ``a``/``b`` are the two image axes,
        in the same order for both boxes.
        """
        bb1_a1, bb1_b1, bb1_a2, bb1_b2 = bb1
        bb2_a1, bb2_b1, bb2_a2, bb2_b2 = bb2

        # intersection box
        int_a1 = np.maximum(bb1_a1, bb2_a1)
        int_b1 = np.maximum(bb1_b1, bb2_b1)
        int_a2 = np.minimum(bb1_a2, bb2_a2)
        int_b2 = np.minimum(bb1_b2, bb2_b2)
        int_area = np.maximum(0, int_a2 - int_a1 + 1) * np.maximum(0, int_b2 - int_b1 + 1)
        return bool(int_area > 0)

    def _get_box_samples(self, lbl, size):
        """Return the pixels of a random ``size`` x ``size`` box lying entirely in class ``lbl``.

        A random pixel of class ``lbl`` is drawn as the top-left corner; draws
        are repeated until the box fits inside the image and all of its pixels
        carry label ``lbl``. Pixels are listed column by column, starting with
        the corner. Loops forever if no such box exists.
        """
        assert lbl in self._classes
        assert size >= 1
        while True:
            y0, x0 = self._get_sample(lbl)
            # The box covers rows y0..y0+size-1 and cols x0..x0+size-1, so it
            # fits when y0 + size <= height (``<`` wrongly rejected boxes
            # touching the last row/column).
            if y0 + size > self.height or x0 + size > self.width:
                continue
            # The corner has label ``lbl``, so "all labels equal" == "all equal lbl".
            if np.all(self._label[y0:y0 + size, x0:x0 + size] == lbl):
                return [[y, x] for x in range(x0, x0 + size) for y in range(y0, y0 + size)]

    def _get_sample(self, lbl):
        """Return a random ``[row, col]`` pixel whose label is ``lbl`` (rejection sampling)."""
        assert lbl in self._classes
        while True:
            w = self._rng.randint(self.width)
            h = self._rng.randint(self.height)
            if self._label[h, w] == lbl:
                return [h, w]

    def _multilook(self):
        """Replace each pixel by the mean over the ``n`` x ``n`` window starting at it, in place.

        ``n = sqrt(nlooks)``. Pixels in the last ``n - 1`` rows/columns have no
        full window and are left unchanged.

        Raises ValueError if ``nlooks`` is not a positive perfect square.
        """
        height, width, ncha = self._polsar.shape
        nlooks = self._nlooks
        n = int(round(np.sqrt(nlooks))) if nlooks >= 1 else 0
        if n < 1 or n * n != nlooks:
            raise ValueError("nlooks must be a positive perfect square, got {}".format(nlooks))
        out_h, out_w = height - n + 1, width - n + 1
        if out_h <= 0 or out_w <= 0:
            return

        polsar = self._polsar
        # Sum the n*n shifted views; all reads finish before the in-place write,
        # so no averaged value feeds into another window.
        acc = np.zeros((out_h, out_w, ncha), dtype=np.result_type(polsar.dtype, np.float64))
        for di in range(n):
            for dj in range(n):
                acc += polsar[di:di + out_h, dj:dj + out_w]
        polsar[:out_h, :out_w] = acc / nlooks

    def show_mask(self):
        """Display the ground-truth colour mask."""
        plt.imshow(self._mask)
        plt.show()

    def get_training_mask(self, samples):
        """Return a copy of the colour mask with the given training pixels blacked out.

        ``samples`` is the per-class list returned by ``get_samples``; the
        first three channels of each listed pixel are set to 0.
        """
        mask = self._mask.copy()
        for class_samples in samples:
            for sample in class_samples:
                mask[sample[0], sample[1], :3] = 0

        return mask

    def show_training_mask(self, samples):
        """Display the mask produced by ``get_training_mask(samples)``."""
        plt.imshow(self.get_training_mask(samples))
        plt.show()


class SanFranciscoBay2Classes(PolSARDataset):
    """Urban/forest crop of the San Francisco Bay scene (rows 537:737, cols 218:418).

    Label 0 = Urban (blue in the mask), label 1 = Forest (green). Every pixel
    not marked green in the mask ends up with label 0.
    """

    def __init__(self, path, mode='real', rng=rng):
        # (Urban, Forest) = (Blue, Green)
        self._classes = (0, 1)
        PolSARDataset.__init__(self, path, mode, rng)

    def _read_polsar(self):
        scene = read_polsar(self._path, self._mode)
        rows, cols = slice(537, 737), slice(218, 418)
        self._polsar = scene[rows, cols, :]

    def _read_mask(self):
        self._mask = imread(self._path + '/mask_537_737_218_418.png')
        is_blue = self._mask[:, :, 2] == 255
        is_green = self._mask[:, :, 1] == 255
        self._masks = [is_blue, is_green]
        labels = np.zeros((self.height, self.width), dtype=np.uint32)
        labels[is_green] = 1
        self._label = labels


class SanFranciscoBayForestOcean(PolSARDataset):
    """Ocean/forest crop of the San Francisco Bay scene (rows 0:300, cols 0:300).

    Label 0 = Ocean (blue in the mask), label 1 = Forest (green). Every pixel
    not marked green in the mask ends up with label 0.
    """

    def __init__(self, path, mode='real', rng=rng):
        # (Ocean, Forest) = (Blue, Green)
        self._classes = (0, 1)
        PolSARDataset.__init__(self, path, mode, rng)

    def _read_polsar(self):
        window = (slice(None, 300), slice(None, 300), slice(None))
        self._polsar = read_polsar(self._path, self._mode)[window]

    def _read_mask(self):
        self._mask = imread(self._path + '/mask_0_300_0_300.png')
        is_blue = self._mask[:, :, 2] == 255
        is_green = self._mask[:, :, 1] == 255
        self._masks = [is_blue, is_green]
        labels = np.zeros((self.height, self.width), dtype=np.uint32)
        labels[is_green] = 1
        self._label = labels


class SanFranciscoBay3Classes(PolSARDataset):
    """Three-class crop of the San Francisco Bay scene (rows 238:564, cols 200:594).

    Label 0 = Urban (red in the mask), 1 = Forest (green), 2 = Sea (blue).
    Every pixel not marked green or blue ends up with label 0.
    """

    def __init__(self, path, mode='real', rng=rng):
        # (Urban, Forest, Sea) = (Red, Green, Blue)
        self._classes = (0, 1, 2)
        PolSARDataset.__init__(self, path, mode, rng)

    def _read_polsar(self):
        scene = read_polsar(self._path, self._mode)
        self._polsar = scene[238:564, 200:594, :]

    def _read_mask(self):
        # Judging by its name, the mask file covers columns 0:594; dropping the
        # first 200 columns aligns it with the crop taken in _read_polsar.
        full_mask = imread(self._path + '/mask_238_564_0_594.png')
        self._mask = full_mask[:, 200:, :]
        self._masks = [self._mask[:, :, channel] == 255 for channel in (0, 1, 2)]
        labels = np.zeros((self.height, self.width), dtype=np.uint32)
        for value in (1, 2):
            # Later classes overwrite earlier ones where mask colours coincide.
            labels[self._masks[value]] = value
        self._label = labels


#dataset for comparison with the paper land cover classification (2014)
class SanFranciscoBayLandCover(PolSARDataset):
    """Three-class crop of the San Francisco Bay scene (rows 250:750, cols 50:550).

    Label 0 = Urban (red in the mask), 1 = Forest (green), 2 = Sea (blue).
    Every pixel not marked green or blue ends up with label 0.
    """

    def __init__(self, path, mode='real', rng=rng):
        self._classes = (0, 1, 2)  # (Urban, Forest, Sea) = (Red, Green, Blue)
        PolSARDataset.__init__(self, path, mode, rng)

    def _read_polsar(self):
        polsar = read_polsar(self._path, self._mode)
        self._polsar = polsar[250:750, 50:550, :]

    def _read_mask(self):
        mask = imread(self._path + '/mask_250_750_50_550.png')
        self._mask = mask
        self._masks = [mask[:, :, 0] == 255, mask[:, :, 1] == 255, mask[:, :, 2] == 255]
        self._label = np.zeros((self.height, self.width), dtype=np.uint32)
        self._label[self._masks[1]] = 1
        self._label[self._masks[2]] = 2

    def show_segmentation(self, prediction):
        """Display a 2-D array of predicted labels as an RGB image.

        Class 0 is drawn red, 1 green, 2 blue; any other value stays white.
        """
        # Size the canvas from the prediction itself: ``self._rgb`` is only
        # filled lazily by the ``rgb`` property and is None until accessed.
        segmentation = np.full(np.shape(prediction) + (3,), 255, dtype=np.uint8)
        segmentation[prediction == 0] = [255, 0, 0]
        segmentation[prediction == 1] = [0, 255, 0]
        segmentation[prediction == 2] = [0, 0, 255]
        plt.imshow(segmentation)
        plt.show()


class SanFranciscoBayUrbanForestOcean(PolSARDataset):
    """Three-class crop of the San Francisco Bay scene (rows 566:766, cols 86:286).

    Label 0 = Urban (red in the mask), 1 = Forest (green), 2 = Ocean (blue).
    Every pixel not marked green or blue ends up with label 0.
    """

    def __init__(self, path, mode='real', rng=rng):
        # (Urban, Forest, Ocean) = (Red, Green, Blue)
        self._classes = (0, 1, 2)
        PolSARDataset.__init__(self, path, mode, rng)

    def _read_polsar(self):
        scene = read_polsar(self._path, self._mode)
        self._polsar = scene[566:766, 86:286, :]

    def _read_mask(self):
        self._mask = imread(self._path + '/mask_566_766_86_286.png')[:, :, :]
        self._masks = [self._mask[:, :, channel] == 255 for channel in (0, 1, 2)]
        labels = np.zeros((self.height, self.width), dtype=np.uint32)
        for value in (1, 2):
            # Later classes overwrite earlier ones where mask colours coincide.
            labels[self._masks[value]] = value
        self._label = labels


#dataset for comparison with the paper land cover classification (2014)
class FlevolandLandCover(PolSARDataset):
    """Flevoland AIRSAR scene labelled with 15 land-cover classes.

    The ground truth is the colour image ``mask_stian.png``: a pixel whose RGB
    colour equals ``self._colors[i]`` gets label ``i``. Pixels matching no
    class colour keep label 255 ("undefined") and are not counted in
    ``nsamples``.
    """

    def __init__(self, path, mode='real', rng=rng):
        self._classes_names = ('BareSoil', 'Barley', 'Beet', 'Buildings', 'Forest', 'Grasses', 'Lucerne',
                               'Peas', 'Potatoes', 'Rapeseed', 'Beans', 'Water', 'Wheat', 'Wheat2', 'Wheat3')
        self._classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
        # RGB colour of each class in the mask, in label order. The trailing
        # white entry gets a boolean map in ``self._masks`` but no label, so
        # white pixels stay undefined (255).
        self._colors = ([0xab, 0x8a, 0x50], [0x94, 0x00, 0x00], [0xb7, 0x00, 0xff], [0xff, 0xd9, 0x9d],
                        [0x00, 0x83, 0x4a], [0x00, 0xff, 0x00], [0x00, 0xfc, 0xff], [0x5a, 0x0b, 0xe1],
                        [0xff, 0xff, 0x00], [0xff, 0x80, 0x00], [0xff, 0x00, 0x00], [0x00, 0x00, 0xff],
                        [0xff, 0xb6, 0xe5], [0xbf, 0xbf, 0xff], [0xbf, 0xff, 0xbf], [0xff, 0xff, 0xff])
        PolSARDataset.__init__(self, path, mode, rng)
        self._undefined_samples = np.sum(self._label == 255)

    @property
    def nsamples(self):
        """Number of pixels with a defined label."""
        return self.width * self.height - self._undefined_samples

    def _read_polsar(self):
        self._polsar = read_polsar(self._path, self._mode)

    def _read_mask(self):
        mask = imread(self._path + '/mask_stian.png')
        self._mask = mask
        # One boolean map per palette entry: True where R, G and B all match.
        # Derived from self._colors so the palette and the masks cannot drift apart.
        rgb = mask[:, :, :3]
        self._masks = [np.all(rgb == color, axis=-1) for color in self._colors]

        self._label = np.full((self.height, self.width), 255, dtype=np.uint32)  # Undefined is 255
        for i in range(len(self._classes)):
            self._label[self._masks[i]] = i

    def show_segmentation(self, prediction):
        """Display a 2-D array of predicted labels using the class palette.

        Class ``i`` is drawn in ``self._colors[i]``; other values stay white.
        """
        # Size the canvas from the prediction itself: ``self._rgb`` is only
        # filled lazily by the ``rgb`` property and is None until accessed.
        segmentation = np.full(np.shape(prediction) + (3,), 255, dtype=np.uint8)
        for i, color in enumerate(self._colors[:-1]):
            segmentation[prediction == i] = color

        plt.imshow(segmentation)
        plt.show()


#dataset for comparison with paper "Pol-SAR Classification Based on Generalized Polar Decomposition of Mueller Matrix"
class FlevolandMuellerMatrix(PolSARDataset):
    """Flevoland AIRSAR scene labelled with 11 land-cover classes.

    The ground truth is the colour image ``mask.png``: a pixel whose RGB colour
    equals ``self._colors[i]`` gets label ``i``. Pixels matching no class
    colour keep label 255 ("undefined") and are not counted in ``nsamples``.
    """

    def __init__(self, path, mode='real', rng=rng):
        self._classes_names = ('BareSoil', 'Beet', 'Forest', 'Grasses', 'Lucerne', 'Peas', 'Potatoes',
                               'Rapeseed', 'Beans', 'Water', 'Wheat')
        self._classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
        # RGB colour of each class in this dataset's mask, in label order. This
        # used to be a copy of the 15-class Flevoland palette, which made
        # show_segmentation paint classes 1..10 in other classes' colours. The
        # trailing white entry gets a boolean map but no label.
        self._colors = ([0xab, 0x8a, 0x50], [0xb7, 0x00, 0xff], [0x00, 0x83, 0x4a], [0x00, 0xff, 0x00],
                        [0x00, 0xfc, 0xff], [0x5a, 0x0b, 0xe1], [0xff, 0xff, 0x00], [0xff, 0x80, 0x00],
                        [0xff, 0x00, 0x00], [0x00, 0x00, 0xff], [0xbf, 0xff, 0xbf], [0xff, 0xff, 0xff])
        PolSARDataset.__init__(self, path, mode, rng)
        self._undefined_samples = np.sum(self._label == 255)

    @property
    def nsamples(self):
        """Number of pixels with a defined label."""
        return self.width * self.height - self._undefined_samples

    def _read_polsar(self):
        self._polsar = read_polsar(self._path, self._mode)

    def _read_mask(self):
        mask = imread(self._path + '/mask.png')
        self._mask = mask
        # One boolean map per palette entry: True where R, G and B all match.
        rgb = mask[:, :, :3]
        self._masks = [np.all(rgb == color, axis=-1) for color in self._colors]

        self._label = np.full((self.height, self.width), 255, dtype=np.uint32)  # Undefined is 255
        for i in range(len(self._classes)):
            self._label[self._masks[i]] = i

    def show_segmentation(self, prediction):
        """Display a 2-D array of predicted labels using the class palette.

        Class ``i`` is drawn in ``self._colors[i]``; other values stay white.
        """
        # Size the canvas from the prediction itself: ``self._rgb`` is only
        # filled lazily by the ``rgb`` property and is None until accessed.
        segmentation = np.full(np.shape(prediction) + (3,), 255, dtype=np.uint8)
        for i, color in enumerate(self._colors[:-1]):
            segmentation[prediction == i] = color

        plt.imshow(segmentation)
        plt.show()


class FoulumBase(PolSARDataset):
    """Foulum EMISAR scene labelled with 7 classes.

    The ground truth is the colour image ``mask.png``: a pixel whose RGB colour
    equals ``self._colors[i]`` gets label ``i``. Pixels matching no class
    colour keep label 255 ("undefined") and are not counted in ``nsamples``.
    ``nlooks`` is forwarded to ``PolSARDataset`` for optional multilooking.
    """

    def __init__(self, path, mode='real', rng=rng, nlooks=None):
        self._classes_names = ('Conifer', 'Oat', 'Pea', 'Rape', 'Rye', 'Urban', 'Wheat')
        self._classes = (0, 1, 2, 3, 4, 5, 6)
        # RGB colour of each class in the mask, in label order. The trailing
        # white entry gets a boolean map but no label, so white pixels stay
        # undefined (255).
        self._colors = ([0x00, 0x00, 0xff], [0xff, 0xff, 0x00], [0x00, 0xff, 0xff], [0x00, 0xff, 0x00],
                        [0xff, 0x00, 0xff], [0x7f, 0x00, 0x00], [0xff, 0x00, 0x00], [0xff, 0xff, 0xff])
        PolSARDataset.__init__(self, path, mode, rng, nlooks)
        self._undefined_samples = np.sum(self._label == 255)

    @property
    def nsamples(self):
        """Number of pixels with a defined label."""
        return self.width * self.height - self._undefined_samples

    def _read_polsar(self):
        self._polsar = read_polsar(self._path, self._mode)

    def _read_mask(self):
        mask = imread(self._path + '/mask.png')
        self._mask = mask
        # One boolean map per palette entry: True where R, G and B all match.
        # Derived from self._colors so the palette and the masks cannot drift apart.
        rgb = mask[:, :, :3]
        self._masks = [np.all(rgb == color, axis=-1) for color in self._colors]

        self._label = np.full((self.height, self.width), 255, dtype=np.uint32)  # Undefined is 255
        for i in range(len(self._classes)):
            self._label[self._masks[i]] = i

    def show_segmentation(self, prediction):
        """Display a 2-D array of predicted labels using the class palette.

        Class ``i`` is drawn in ``self._colors[i]``; other values stay white.
        """
        # Size the canvas from the prediction itself: ``self._rgb`` is only
        # filled lazily by the ``rgb`` property and is None until accessed.
        segmentation = np.full(np.shape(prediction) + (3,), 255, dtype=np.uint8)
        for i, color in enumerate(self._colors[:-1]):
            segmentation[prediction == i] = color

        plt.imshow(segmentation)
        plt.show()


#dataset for comparison with the paper land cover classification (2014)
class FoulumS2LandCover(FoulumBase):
    """Foulum EMISAR dataset; all loading and labelling is inherited from ``FoulumBase``."""

    def __init__(self, path, mode='real', rng=rng, nlooks=None):
        FoulumBase.__init__(self, path, mode=mode, rng=rng, nlooks=nlooks)


class FoulumC3(FoulumBase):
    """Foulum EMISAR dataset; all loading and labelling is inherited from ``FoulumBase``."""

    def __init__(self, path, mode='real', rng=rng, nlooks=None):
        FoulumBase.__init__(self, path, mode=mode, rng=rng, nlooks=nlooks)
