# -*- coding: utf-8 -*-

import numpy as np

from pygco import cut_simple
import logging


import vrl


# Module-level logger, obtained from the project's vrl helpers under the
# name 'sartb'; used by the functions in this file for progress messages.
logger = vrl.utils.get_logger('sartb')
# NOTE(review): executed at import time. Presumably this sets the INFO level
# for the loggers managed by vrl.utils -- confirm how far set_loglevel reaches.
vrl.utils.set_loglevel(logging.INFO)


def graph_cut(unary):
    """Label a 2-D grid by running pygco's ``cut_simple`` on scaled potentials.

    Both the unary term and an identity-shaped pairwise term are multiplied
    by the same negative factor and cast to ``int32`` before being handed to
    ``cut_simple``.

    Parameters
    ----------
    unary : array_like, shape (nrow, ncol, nstates)
        Per-cell, per-state values.  Presumably scores where larger means
        "more likely" (they are negated into costs here) -- confirm against
        the caller.

    Returns
    -------
    numpy.ndarray, shape (nrow, ncol)
        The result of ``cut_simple`` reshaped onto the input grid.

    Raises
    ------
    ValueError
        If ``unary`` is not 3-dimensional.
    """
    unary = np.asarray(unary)
    if unary.ndim != 3:
        raise ValueError(
            "unary must have shape (nrow, ncol, nstates); got %r"
            % (unary.shape,)
        )
    nrow, ncol, nstates = unary.shape

    logger.info("Graph Cut: Computing pairwise")
    factor = -100
    # -100 on the diagonal, 0 elsewhere: equal neighbouring labels get the
    # lower pairwise value.
    pairwise = (factor * np.eye(nstates)).astype(np.int32)
    # astype truncates toward zero.  order="C" forces a C-contiguous result
    # even when the input is a transposed/strided view; the default 'K'
    # order would keep the input's memory layout, which the compiled
    # cut_simple routine may reject.
    unary = (factor * unary).astype(np.int32, order="C")

    logger.info("Graph Cut: Inference")
    y = cut_simple(unary, pairwise)
    logger.info("Graph Cut: Inference Done")

    return y.reshape(nrow, ncol)
