# -*- coding: utf-8 -*-
from __future__ import absolute_import


import numpy as np
import logging
import time


import vrl


from .complex_wishart import ComplexWishart
from .polsar import force_positive_definite


# Module-wide logger used by every unary builder below.
# NOTE(review): importing this module also calls vrl.utils.set_loglevel(INFO)
# as a side effect; whether that affects only 'sartb' or all vrl loggers is
# decided inside vrl — confirm before relying on a different level elsewhere.
logger = vrl.utils.get_logger('sartb')
vrl.utils.set_loglevel(logging.INFO)


def likelihood(mixture, sample):
    """Return the mixture density of ``sample``.

    Each of the ``mixture.ncomp`` components is asked for its log-likelihood
    via the mixture's private ``_component(k)`` accessor. The value is
    exponentiated, and the per-component densities are combined with
    ``mixture.weights_`` through a dot product.
    """
    densities = []
    for comp_idx in range(mixture.ncomp):
        log_p = mixture._component(comp_idx).log_likelihood(sample)
        densities.append(np.exp(log_p))
    return np.dot(densities, mixture.weights_)


def get_indexs(size, row, col, nrow, ncol):
    """Return the flat (row-major) indices of a square window around a pixel.

    The window spans ``size // 2`` pixels on each side of ``(row, col)`` and
    is clipped to the ``nrow x ncol`` image. Indices are ``x + y * ncol`` and
    are listed row by row.

    Args:
        size: window width; the half-width is ``size // 2``, so even sizes
            behave like the next odd size (e.g. 2 -> 3x3).
        row, col: centre pixel; must satisfy ``row < nrow`` and ``col < ncol``.
        nrow, ncol: image dimensions.

    Returns:
        list of int flat indices.
    """
    assert(row < nrow)
    assert(col < ncol)
    # Floor division: with true division (Python 3) a float bound makes range() raise.
    half = int(size) // 2
    rows = range(max(0, row - half), min(row + half + 1, nrow))
    cols = range(max(0, col - half), min(col + half + 1, ncol))
    return [x + y * ncol for y in rows for x in cols]


def compute_unary(dataset, unary_method='fisher_vector', wishart_init='kmeans', ncomp=8, dof=3, step=1,
                  size=2, nblocks=100, block_size=1, seed=1234567):
    """Compute per-pixel, per-class unary potentials for ``dataset``.

    Args:
        dataset: object exposing ``polsar``, ``nclasses`` and ``get_samples``.
        unary_method: one of ``'fisher_vector'``, ``'wishart_mixture'`` or
            ``'wishart_classifier'``.
        wishart_init, ncomp, dof: Wishart-mixture hyper-parameters. They are
            ignored by ``'wishart_classifier'``.
        step: evaluate only every ``step``-th row/column.
        size: neighbourhood window size (used by ``'fisher_vector'``).
        nblocks, block_size: forwarded to ``dataset.get_samples``.
        seed: seed for the ``vrl.base.RNG`` shared by the chosen method.

    Returns:
        ``(nrow, ncol, nclasses)`` array of unary scores.

    Raises:
        ValueError: if ``unary_method`` is not one of the supported names.
    """
    valid_methods = ('fisher_vector', 'wishart_classifier', 'wishart_mixture')
    # Explicit check instead of assert, so it survives `python -O`.
    if unary_method not in valid_methods:
        raise ValueError("unknown unary_method {!r}; expected one of {}".format(
            unary_method, valid_methods))
    rng = vrl.base.RNG(seed=seed)
    # Compare with == (value), not `is` (identity): a string built at runtime
    # is usually not the interned literal and would fall through to the
    # wrong branch.
    if unary_method == 'fisher_vector':
        return unary_fv(dataset, ncomp, dof, wishart_init, nblocks, block_size, step, size, rng)
    if unary_method == 'wishart_mixture':
        return unary_wishart_mixture(dataset, ncomp, dof, wishart_init, nblocks, block_size, step, size, rng)
    # Only 'wishart_classifier' remains after validation.
    return unary_complex_wishart(unary_method, dataset, nblocks, block_size, step, size, rng)


def unary_fv(dataset, ncomp, dof, wishart_init, nblocks, block_size, step, size, rng):
    """Compute Fisher-vector based unary potentials.

    Steps:
      1. A Wishart mixture is trained on the per-pixel channel vectors of
         ``dataset.polsar``. Pixels whose determinant equals the minimum or
         is at/above the 95th percentile are excluded from training.
      2. A ``vrl.ml.FisherVector`` encoder is fitted on the same samples.
      3. One reference Fisher vector per class is built from the samples
         returned by ``dataset.get_samples``.
      4. For every evaluated pixel ``(r, c)``, the window from ``get_indexs``
         is encoded. The unary for class ``i`` is the dot product of that
         encoding with the class reference vector.

    Args:
        dataset: object exposing ``polsar`` (rows x cols x channels),
            ``nclasses`` and ``get_samples(n=..., size=...)``.
        ncomp, dof, wishart_init: Wishart mixture hyper-parameters.
        nblocks, block_size: forwarded to ``dataset.get_samples``.
        step: only every ``step``-th row/column is evaluated; the remaining
            entries of the result stay 0.
        size: neighbourhood window size passed to ``get_indexs``.
        rng: random state shared by the mixture and the encoder.

    Returns:
        ``(nrow, ncol, nclasses)`` float64 array of unary scores.
    """
    logger.info("Reading polsar image")
    polsar = np.array(dataset.polsar, dtype=np.float64)
    force_positive_definite(polsar)

    nstates = dataset.nclasses

    # One channel vector per pixel; each vector is a flattened square matrix
    # (3x3 when ncha == 9).
    nrow, ncol, ncha = polsar.shape
    X = polsar.reshape((nrow * ncol, ncha))
    dim = int(round(np.sqrt(ncha)))

    # Training subset: the strict inequalities drop the pixel(s) at the minimum
    # determinant and everything at/above the 95th percentile.
    detX = np.linalg.det(X.reshape((-1, dim, dim)))
    lb = np.percentile(detX, 0)
    ub = np.percentile(detX, 95)
    idx = np.where(np.bitwise_and(detX > lb, detX < ub))[0]
    X_ = X[idx]

    logger.info("Training wishart mixture")
    wmm = vrl.ml.WishartMixture(ncomp=ncomp, nvar=ncha, dof=dof, random_state=rng)
    wmm._em_params['init'] = wishart_init
    wmm._em_params['alpha'] = 1e5
    # NOTE(review): n_init is given as a float (1e3); confirm vrl accepts a
    # non-integer here.
    wmm._em_params['n_init'] = 1e3

    start = time.time()
    wmm.fit(X_)
    logger.info('WMM time = {}'.format(time.time() - start))

    logger.info("Fitting Fisher vector model")
    efv = vrl.ml.FisherVector(model=wmm, normalizer='var', random_state=rng)
    efv.fit(X_)

    logger.info("Computing reference Fisher vector")
    # get_samples yields, per class, (row, col) pairs of training pixels.
    index_to_samples = dataset.get_samples(n=nblocks, size=block_size)
    assert(len(index_to_samples) == nstates)
    ref_fv = []
    for i in range(nstates):
        indexs = [x + y * ncol for y, x in index_to_samples[i]]
        ref_fv.append(efv.transform(X[indexs]))

    logger.info("Computing Fisher vector for nodes")
    # Assumes efv.transform returns a 1-D encoding (its len() is the FV size).
    fv_size = len(ref_fv[0])
    fvs = np.zeros((nrow, ncol, fv_size), np.float64)
    elapsed = []
    for r in range(0, nrow, step):
        for c in range(0, ncol, step):
            start = time.time()
            indexs = get_indexs(size, r, c, nrow, ncol)
            fvs[r, c, :] = efv.transform(X[indexs])
            elapsed.append(time.time() - start)

    logger.info('Mean time FV = {}'.format(np.mean(elapsed)))

    unary = np.zeros((nrow, ncol, nstates), np.float64)
    for i in range(nstates):
        logger.debug("Computing unary for state Nº {}".format(i))
        for r in range(0, nrow, step):
            for c in range(0, ncol, step):
                unary[r, c, i] = np.dot(ref_fv[i], fvs[r, c])

    return unary


def unary_wishart_mixture(dataset, ncomp, dof, wishart_init, nblocks, block_size, step, size, rng):
    """Compute unary potentials from one Wishart mixture per class.

    For every class, a ``vrl.ml.WishartMixture`` is trained on that class's
    sample pixels from ``dataset.get_samples``. The unary of pixel ``(r, c)``
    for class ``i`` is mixture ``i``'s log-likelihood of the pixel's channel
    vector, clamped to ``[-50, 50]``.

    Args:
        dataset: object exposing ``polsar``, ``nclasses`` and ``get_samples``.
        ncomp, dof, wishart_init: Wishart mixture hyper-parameters.
        nblocks, block_size: forwarded to ``dataset.get_samples``.
        step: only every ``step``-th row/column is evaluated; others stay 0.
        size: unused; kept so every unary builder shares one signature.
        rng: random state passed to each mixture.

    Returns:
        ``(nrow, ncol, nclasses)`` float64 array of unary scores.
    """
    logger.info("Reading polsar image")
    polsar = np.array(dataset.polsar, dtype=np.float64)
    force_positive_definite(polsar)

    nstates = dataset.nclasses

    # One channel vector per pixel.
    nrow, ncol, ncha = polsar.shape
    X = polsar.reshape((nrow * ncol, ncha))

    logger.info("Training Wishart Mixture")
    wishart_mixtures = []
    # Per class, the (row, col) training pixels drawn as `nblocks` blocks of
    # `block_size` (see dataset.get_samples).
    samples = dataset.get_samples(n=nblocks, size=block_size)
    assert(len(samples) == nstates)
    for i in range(nstates):
        indexs = [x + y * ncol for y, x in samples[i]]

        wmm = vrl.ml.WishartMixture(ncomp=ncomp, nvar=ncha, dof=dof, random_state=rng)
        wmm._em_params['init'] = wishart_init
        wmm._em_params['alpha'] = 1e5
        wmm.fit(X[indexs])
        wishart_mixtures.append(wmm)

    logger.info("Fitting potts model")
    unary = np.zeros((nrow, ncol, nstates), np.float64)
    # Use the real channel count instead of a hard-coded 9.
    X = X.reshape((nrow, ncol, ncha))
    for i in range(nstates):
        logger.debug("Potts: Computing unary for state Nº {}".format(i))
        for r in range(0, nrow, step):
            for c in range(0, ncol, step):
                # NOTE(review): with built-in min/max a NaN log-likelihood is
                # mapped to +50 (min(50, nan) returns 50); confirm that is acceptable.
                unary[r, c, i] = max(-50, min(50, wishart_mixtures[i].log_likelihood(X[r, c])))

    return unary


def unary_complex_wishart(method, dataset, nblocks, block_size, step, size, rng):
    """Compute unary potentials from one ComplexWishart model per class.

    Each class's training pixels (from ``dataset.get_samples``) are reshaped
    to ``3x3`` matrices and shuffled in place with ``rng``. A
    ``ComplexWishart`` is fitted on them. The unary of pixel ``(r, c)`` for
    class ``i`` is ``models[i].score`` of that pixel's ``3x3`` matrix.

    ``method`` and ``size`` are accepted for signature compatibility but are
    not used. Only every ``step``-th row/column is evaluated; the other
    entries of the returned ``(nrow, ncol, nclasses)`` array stay 0.
    """
    logger.info("Reading polsar image")
    image = dataset.polsar
    n_classes = dataset.nclasses

    # Flatten to one channel vector per pixel (9 channels -> 3x3 below).
    height, width, channels = image.shape
    pixels = image.reshape((height * width, channels))

    logger.info("Training complex wishart")
    models = []
    training = dataset.get_samples(n=nblocks, size=block_size)
    assert len(training) == n_classes
    for label in range(n_classes):
        flat_ids = [col + row * width for row, col in training[label]]
        model = ComplexWishart()
        batch = pixels[flat_ids]
        batch = batch.reshape(batch.shape[0], 3, 3)
        rng.shuffle(batch)
        model.fit(batch)
        models.append(model)

    logger.info("Fitting potts model")
    unary = np.zeros((height, width, n_classes), np.float64)
    matrices = pixels.reshape((height, width, 3, 3))
    for label, model in enumerate(models):
        logger.debug("Potts: Computing unary for state Nº {}".format(label))
        for r in range(0, height, step):
            for c in range(0, width, step):
                unary[r, c, label] = model.score(matrices[r, c])

    return unary
