# -*- coding: utf-8 -*-
from __future__ import absolute_import

import numpy as np


import logging
import time


import vrl


from .unary import compute_unary
from .smoothing import smoothing
from .dataset import get_dataset


# Module-level logger obtained from vrl under the name 'sartb'; functions in
# this module are expected to report progress through it.
logger = vrl.utils.get_logger('sartb')
# Import-time side effect: sets vrl's log level to INFO whenever this module
# is imported. NOTE(review): presumably this applies to vrl-managed loggers
# only (not the stdlib root logger) — confirm against vrl.utils.
vrl.utils.set_loglevel(logging.INFO)


def segmentation_with_smoothing(dataset, unary_method='fisher_vector', wishart_init='kmeans',
                                smoothing_algo='graph_cut', ncomp=8, dof=3, step=1, size=2, nblocks=100,
                                block_size=1, seed=1234567):
    """Segment ``dataset``: compute unary terms, then smooth them.

    The unary terms are produced by ``compute_unary`` and handed, together
    with ``smoothing_algo``, to ``smoothing``. The smoothing (inference) step
    is timed and the elapsed time is logged through the module logger.

    Args:
        dataset: dataset object, forwarded unchanged to ``compute_unary``.
        unary_method: unary-term method name (default ``'fisher_vector'``).
        wishart_init: initialisation strategy forwarded to ``compute_unary``
            (default ``'kmeans'``).
        smoothing_algo: algorithm name forwarded to ``smoothing``
            (default ``'graph_cut'``).
        ncomp, dof, step, size, nblocks, block_size, seed: forwarded
            positionally to ``compute_unary``; see that function for their
            meaning.

    Returns:
        Whatever ``smoothing`` returns. ``run_ntimes`` passes this value to
        ``dataset.eval``; presumably it is a label map (TODO confirm).
    """
    unary = compute_unary(dataset, unary_method, wishart_init, ncomp, dof, step, size, nblocks, block_size,
                          seed)

    # time.time() is wall-clock and may jump (NTP, DST), so prefer the
    # monotonic high-resolution perf_counter. Fall back to time.time on
    # interpreters that lack it (Python 2, which this file still supports).
    clock = getattr(time, 'perf_counter', time.time)
    start = clock()
    y = smoothing(unary, smoothing_algo)
    elapsed = clock() - start
    # Report via the module logger configured at import, rather than print,
    # so the message honours the configured log level and handlers.
    logger.info('Inference time = %s', elapsed)

    return y


def run_ntimes(ipath, dataset_name, unary_method, wishart_init, polsar_mode, ncomp, dof, ntimes, overlap,
               nblocks=4, block_size=5, seed=1234567):
    """Repeat the segmentation experiment ``ntimes`` and summarise the scores.

    A master RNG seeded with ``seed`` draws a fresh seed for every run. Each
    run does the following with that seed:

    * reloads the dataset via ``get_dataset``;
    * runs ``segmentation_with_smoothing``;
    * records ``dataset.eval(y)``.

    The mean and standard deviation of the recorded scores are taken along
    axis 0, so they are per-component if ``eval`` returns a vector. Both are
    logged and returned.

    Args:
        ipath, dataset_name, polsar_mode: forwarded to ``get_dataset``
            (``polsar_mode`` as its ``mode`` keyword).
        unary_method, wishart_init, ncomp, dof, nblocks, block_size:
            forwarded to ``segmentation_with_smoothing``.
        ntimes: number of independent runs; must be >= 1.
        overlap: truthy to set ``dataset.overlap = True`` on every run,
            falsy for False.
        seed: seed of the master RNG that generates the per-run seeds.

    Returns:
        ``(mean, std)`` as computed by ``np.mean`` / ``np.std`` over axis 0.

    Raises:
        ValueError: if ``ntimes < 1``. Averaging an empty list would otherwise
            silently yield nan with a RuntimeWarning.
    """
    # Validate before touching the RNG or loading any data.
    if ntimes < 1:
        raise ValueError('ntimes must be >= 1, got {}'.format(ntimes))

    rng = vrl.base.RNG(seed=seed)
    # The original ``overlap is True`` test treated numpy.bool_(True) or 1 as
    # False. Any truthy value now enables overlap.
    use_overlap = bool(overlap)
    scores = []
    for _ in range(ntimes):
        # Per-run seed derived deterministically from the master ``seed``.
        # Kept in a separate name so ``seed`` still means the caller's value.
        run_seed = rng.randint(0, 12345678)
        dataset = get_dataset(ipath, dataset_name, mode=polsar_mode, seed=run_seed)
        dataset.overlap = use_overlap

        y = segmentation_with_smoothing(dataset, unary_method=unary_method, wishart_init=wishart_init,
                                        ncomp=ncomp, dof=dof, nblocks=nblocks, block_size=block_size,
                                        seed=run_seed)
        scores.append(dataset.eval(y))

    mean = np.mean(scores, axis=0)
    std = np.std(scores, axis=0)

    # Use the module logger set up at import time. ``logging.info`` goes to
    # the root logger, whose default WARNING level drops these INFO lines.
    logger.info('mean = %s', mean)
    logger.info('std = %s', std)

    return mean, std
