# -*- coding: utf-8 -*-

import logging
import numpy as np

from pystruct.inference import inference_dispatch
from pystruct.utils import make_grid_edges


import vrl


# Module-level logger (used by ``potts`` below), obtained from the project's
# vrl logging helpers under the name 'sartb'.
logger = vrl.utils.get_logger('sartb')
# NOTE(review): this call executes at import time, so merely importing this
# module applies the INFO level via vrl.utils.set_loglevel. Whether that
# affects only the 'sartb' logger or every vrl logger is not visible here --
# confirm, and consider moving it into a script entry point instead.
vrl.utils.set_loglevel(logging.INFO)


def potts(unary, inference_method='ad3', strength=1.0):
    """Label a 2-D grid by discrete inference in a Potts-model CRF.

    Every grid cell becomes a node. Edges between neighbouring cells come from
    ``pystruct.utils.make_grid_edges`` with its default neighbourhood
    (presumably 4-connected -- confirm against the pystruct docs). All edges
    share the same pairwise potential ``strength * I``: a term of ``strength``
    when both endpoints take the same label, and 0 otherwise. The problem is
    then handed to ``pystruct.inference.inference_dispatch``, which pystruct
    documents as computing the maximizing assignment. So larger entries of
    ``unary`` favour a label, and a positive ``strength`` favours smooth
    labelings.

    Parameters
    ----------
    unary : array_like, shape (nrow, ncol, nstates)
        Per-cell, per-label unary potentials. Converted to a float64 array.
    inference_method : str, optional
        Backend name forwarded unchanged to ``inference_dispatch``
        (default ``'ad3'``).
    strength : float, optional
        Weight of the Potts pairwise term. The default of 1.0 reproduces the
        previous behaviour, which used a plain identity matrix.

    Returns
    -------
    ndarray, shape (nrow, ncol)
        The labels returned by ``inference_dispatch``, reshaped back onto the
        grid.

    Raises
    ------
    ValueError
        If ``unary`` is not 3-dimensional.
    """
    # Normalise to a float64 ndarray so lists and integer/float32 arrays are
    # accepted uniformly. This is a no-op for an existing float64 ndarray.
    unary = np.asarray(unary, dtype=np.float64)
    if unary.ndim != 3:
        raise ValueError(
            "unary must have shape (nrow, ncol, nstates), got shape %r"
            % (unary.shape,))
    nrow, ncol, nstates = unary.shape

    logger.info("Potts: Computing pairwise")
    # Same-label bonus on the diagonal, zero everywhere else (Potts model).
    pairwise = strength * np.eye(nstates)

    # make_grid_edges only needs the leading (nrow, ncol) grid dimensions of
    # its argument, so the 3-D unary array can be passed directly.
    edges = make_grid_edges(unary)
    # One row of potentials per node, nodes in row-major (C) grid order.
    unaries = unary.reshape(-1, nstates)

    logger.info("Potts: Inference")
    y = inference_dispatch(unaries, pairwise, edges,
                           inference_method=inference_method)

    return y.reshape(nrow, ncol)
