# -*- coding: utf-8 -*-

import numpy as np
from scipy.special import gamma, gammaln


class ComplexWishart(object):
    """Scaled complex Wishart density for multilook covariance matrices.

    For a ``d x d`` sample covariance ``Z`` averaged over ``n`` looks and a
    covariance ``C``, this class evaluates::

        p(Z | C) = n^(n d) |Z|^(n - d) exp(-n tr(C^-1 Z)) / (R(n, d) |C|^n)

    with ``R(n, d) = pi^(d (d - 1) / 2) * Gamma(n - d + 1) * ... * Gamma(n)``.

    ``C`` must be assigned (``wishart.C = cov``) before evaluating the
    density; assigning it precomputes the ``C``-dependent constants.
    """

    def __init__(self, nlooks=4, dimension=3):
        self._n = nlooks  # number of looks n (need not be an integer)
        self._d = dimension  # matrix dimension d
        # No covariance yet; the density cannot be evaluated until C is set.
        self._C = None

    @property
    def C(self):
        """Covariance matrix of the distribution (``None`` until assigned)."""
        return self._C

    @C.setter
    def C(self, value):
        self._C = value
        self.recompute_constants()

    def recompute_constants(self):
        """Precompute the terms of the density that depend only on ``C``."""
        # Linear-scale constants, kept for backward compatibility. They can
        # overflow/underflow for large n and are not used for evaluation.
        self._constant_num = self._n ** (self._n * self._d)
        self._den = self._R() * (np.linalg.det(self._C) ** self._n)
        self._one_over_den = 1. / self._den
        self._inv_C = np.linalg.inv(self._C)
        # Log-scale constants used by log_likelihood. slogdet yields
        # log|det C|, which equals log det C when C is positive definite.
        _, logdet_C = np.linalg.slogdet(self._C)
        self._log_constant_num = self._n * self._d * np.log(self._n)
        self._log_den = self._log_R() + self._n * logdet_C

    def likelihood(self, Z):
        """Return the density ``p(Z | C)`` as a real scalar.

        Computed as ``exp(log_likelihood(Z))`` so the huge/tiny factors
        ``n^(n d)``, ``Gamma(.)`` and ``|C|^n`` are never formed directly.
        """
        return np.exp(self.log_likelihood(Z))

    def log_likelihood(self, Z):
        """Return ``log p(Z | C)`` evaluated directly in log space.

        Parameters
        ----------
        Z : array_like, shape (d, d)
            Sample covariance matrix; assumed Hermitian positive definite.

        Returns
        -------
        float
            The log-density. Only the real part of ``tr(C^-1 Z)`` is kept:
            for Hermitian ``C`` and ``Z`` the trace is real up to rounding.
        """
        _, logdet_Z = np.linalg.slogdet(Z)
        trace = np.real(np.trace(np.dot(self._inv_C, Z)))
        return (self._log_constant_num
                + (self._n - self._d) * logdet_Z
                - self._n * trace
                - self._log_den)

    def _R(self):
        """Return ``R(n, d) = pi^(d (d-1)/2) * Gamma(n-d+1) * ... * Gamma(n)``.

        Loops over the d gamma arguments rather than over integer values, so
        a non-integer number of looks is accepted; for integer n the factors
        and their multiplication order are the same as before.
        """
        R = np.pi ** (self._d * (self._d - 1) / 2.)
        for i in range(self._d):
            R *= gamma(self._n - self._d + 1 + i)
        return R

    def _log_R(self):
        """Return ``log R(n, d)`` using log-gamma, avoiding ``_R`` overflow."""
        log_R = np.log(np.pi) * (self._d * (self._d - 1) / 2.)
        for i in range(self._d):
            log_R += gammaln(self._n - self._d + 1 + i)
        return log_R


class ComplexWishartMixture(object):
    """Weighted mixture of complex Wishart components.

    All components share the same number of looks and matrix dimension;
    mixing weights start out uniform.
    """

    def __init__(self, ncomps=2, nlooks=4, dimension=3):
        if ncomps < 1:
            raise ValueError("ncomps must be at least 1")
        self._ncomps = ncomps
        self._n = nlooks
        self._d = dimension
        self._components = [ComplexWishart(nlooks, dimension) for _ in range(ncomps)]
        self._weights = [1. / ncomps] * ncomps

    def log_likelihood(self, Z):
        """Return the mixture log-density ``log(sum_k w_k p_k(Z))``.

        Combined with log-sum-exp over ``log w_k + log p_k(Z)`` so tiny
        component densities do not underflow to zero. The real part of each
        component log-likelihood is taken because a component may return a
        complex value with a negligible imaginary part for Hermitian input.
        """
        log_terms = [np.log(w) + np.real(comp.log_likelihood(Z))
                     for w, comp in zip(self._weights, self._components)]
        return np.logaddexp.reduce(log_terms)

    def fit(self, X):
        """Run the expectation pass over the samples ``X``.

        For each sample ``Z_i`` this computes the component likelihoods
        ``q(Z_i | C_k)``, the posteriors ``gamma_i[k]`` and accumulates the
        total log-likelihood.

        NOTE(review): ``_initial_estimate`` is not defined in the code shown
        here, and ``ss0``/``ss1`` are never updated, so the M-step appears to
        be unfinished — confirm against the rest of the file.
        """
        nsamples = len(X)
        ncomps = self._ncomps
        self._initial_estimate(X)
        lh_by_comps = [0. for _ in range(ncomps)]  # q(Z_i|C_k)
        pst_by_comps = [0. for _ in range(ncomps)]  # gamma_i[k]
        ll_curr = 0.
        ss0 = [0. for _ in range(ncomps)]  # ss0[k] = sum(gamma_i[k] * Zi)
        ss1 = [0. for _ in range(ncomps)]  # ss1[k] = sum(gamma_i[k])
        # Iterate over sample indices (the original iterated over the int
        # count itself, which always raised TypeError).
        for ns in range(nsamples):
            for nc in range(ncomps):
                # Densities are real; drop any rounding-level imaginary part.
                lh_by_comps[nc] = np.real(self._components[nc].likelihood(X[ns]))
            sample_lh = np.dot(self._weights, lh_by_comps)
            pst_by_comps = np.multiply(self._weights, lh_by_comps) / sample_lh
            ll_curr += np.log(sample_lh)

