/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.math.statistics.distribution;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import java.util.Arrays;
import java.util.Random;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.statistics.distribution.AbstractMultivariateDistribution;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

/**
 * A mixture of multivariate Gaussians: a set of {@link MultivariateGaussian}
 * components, each with an associated mixing weight.
 * <p>
 * Provides sampling, (log-)density estimation, and per-component posterior
 * (responsibility) estimation. The dimensionality of the model is taken from
 * the column dimension of the first component's mean.
 * <p>
 * The component and weight arrays are stored by reference and exposed through
 * public fields and getters; no defensive copies are made.
 */
public class MixtureOfGaussians
extends AbstractMultivariateDistribution {
    /**
     * Amount added to the diagonal of a covariance matrix that is not
     * symmetric positive definite before its Cholesky factorisation is retried
     * in {@link #logProbability(double[][], MultivariateGaussian[])}.
     */
    public static final double MIN_COVAR_RECONDITION = 1.0E-7;
    /** The mixture components. */
    public MultivariateGaussian[] gaussians;
    /** The mixing weight of each component, indexed like {@link #gaussians}. */
    public double[] weights;

    /**
     * Construct the mixture from the given components and weights. The arrays
     * are stored as-is (not copied or validated).
     *
     * @param gaussians the mixture components
     * @param weights   the mixing weight of each component
     */
    public MixtureOfGaussians(MultivariateGaussian[] gaussians, double[] weights) {
        this.gaussians = gaussians;
        this.weights = weights;
    }

    /**
     * Draw a single sample from the mixture.
     *
     * @param rng the random number generator
     * @return the sampled vector
     */
    @Override
    public double[] sample(Random rng) {
        return this.sample(1, rng)[0];
    }

    /**
     * Draw a number of samples from the mixture. A component is first chosen
     * for every sample by inverting the cumulative sum of the weights; each
     * chosen component then generates all of its samples in one call.
     *
     * @param n_samples the number of samples to draw
     * @param rng       the random number generator
     * @return one sampled vector per row
     */
    @Override
    public double[][] sample(int n_samples, Random rng) {
        double[] weightCdf = ArrayUtils.cumulativeSum(this.weights);
        double[][] X = new double[n_samples][this.gaussians[0].getMean().getColumnDimension()];
        int[] comps = new int[n_samples];
        int lastComponent = this.gaussians.length - 1;
        for (int i = 0; i < n_samples; ++i) {
            int idx = Arrays.binarySearch(weightCdf, rng.nextDouble());
            if (idx < 0) {
                // Key not found: binarySearch returns -(insertionPoint) - 1.
                // The insertion point is the first CDF entry greater than the
                // drawn value, i.e. the component whose interval contains it.
                idx = -(idx + 1);
            }
            // Guard against the cumulative weights summing to slightly less than 1.
            comps[i] = Math.min(idx, lastComponent);
        }
        for (int c = 0; c < this.gaussians.length; ++c) {
            int[] idxs = ArrayUtils.search(comps, c);
            if (idxs.length == 0) {
                continue;
            }
            double[][] drawn = this.gaussians[c].sample(idxs.length, rng);
            for (int j = 0; j < drawn.length; ++j) {
                X[idxs[j]] = drawn[j];
            }
        }
        return X;
    }

    /**
     * Estimate the log-probability of a single sample under the mixture.
     *
     * @param sample the sample vector
     * @return the log-probability
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    @Override
    public double estimateLogProbability(double[] sample) {
        return this.estimateLogProbability(new double[][]{sample})[0];
    }

    /**
     * Estimate the log-probability of each sample under the mixture, i.e. the
     * log of the sum over components of the weighted component densities. The
     * sum is evaluated with the log-sum-exp trick so that samples with very
     * low density do not underflow to negative infinity.
     *
     * @param samples the samples, one per row; must contain at least one row
     * @return the log-probability of each sample
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    @Override
    public double[] estimateLogProbability(double[][] samples) {
        this.checkDimensions(samples);
        return this.logsumexp(this.computeWeightedLogProb(samples));
    }

    /**
     * Compute the log-density of every sample under every given Gaussian using
     * the Cholesky factor of each covariance matrix. If a covariance matrix is
     * not symmetric positive definite, {@link #MIN_COVAR_RECONDITION} is added
     * to its diagonal and the factorisation is retried once.
     *
     * @param x         the samples, one per row; must contain at least one row
     * @param gaussians the Gaussians to evaluate
     * @return a matrix indexed {@code [sample][gaussian]} of log-densities
     */
    public static double[][] logProbability(double[][] x, MultivariateGaussian[] gaussians) {
        int ndims = x[0].length;
        int nmix = gaussians.length;
        int nsamples = x.length;
        Matrix X = new Matrix(x);
        // Normalisation term shared by every sample and component.
        double logNorm = (double)ndims * Math.log(Math.PI * 2);
        double[][] log_prob = new double[nsamples][nmix];
        for (int i = 0; i < nmix; ++i) {
            Matrix cv_chol;
            Matrix mu = gaussians[i].getMean();
            Matrix cv = gaussians[i].getCovariance();
            CholeskyDecomposition chol = cv.chol();
            if (chol.isSPD()) {
                cv_chol = chol.getL();
            } else {
                // Recondition the covariance by nudging its diagonal, then retry.
                Matrix m = cv.plus(Matrix.identity(ndims, ndims).timesEquals(MIN_COVAR_RECONDITION));
                cv_chol = m.chol().getL();
            }
            // log|cov| = 2 * sum(log(diag(L))) where cov = L * L^T.
            double cv_log_det = 0.0;
            double[][] cv_chol_d = cv_chol.getArray();
            for (int j = 0; j < ndims; ++j) {
                cv_log_det += Math.log(cv_chol_d[j][j]);
            }
            cv_log_det *= 2.0;
            // Solve L * z = (x - mu)^T; the squared norm of each z is the
            // Mahalanobis distance of the corresponding sample.
            // NOTE(review): assumes MatrixUtils.minusRow leaves X (and hence the
            // caller's x) unmodified - confirm against MatrixUtils.
            Matrix cv_sol = cv_chol.solve(MatrixUtils.minusRow(X, mu.getArray()[0]).transpose()).transpose();
            for (int k = 0; k < nsamples; ++k) {
                double sum = 0.0;
                for (int j = 0; j < ndims; ++j) {
                    double z = cv_sol.get(k, j);
                    sum += z * z;
                }
                log_prob[k][i] = -0.5 * (sum + cv_log_det + logNorm);
            }
        }
        return log_prob;
    }

    /**
     * Compute the per-component log-density of every sample and add the log of
     * the corresponding mixing weight.
     *
     * @param samples the samples, one per row
     * @return a matrix indexed {@code [sample][component]} of weighted log-densities
     */
    protected double[][] computeWeightedLogProb(double[][] samples) {
        double[][] lpr = this.logProbability(samples);
        // Take each weight's log once rather than once per sample.
        double[] logWeights = new double[this.weights.length];
        for (int j = 0; j < logWeights.length; ++j) {
            logWeights[j] = Math.log(this.weights[j]);
        }
        for (double[] row : lpr) {
            for (int j = 0; j < row.length; ++j) {
                row[j] += logWeights[j];
            }
        }
        return lpr;
    }

    /**
     * Compute the (unweighted) log-density of every sample under every
     * component of this mixture, delegating to each component's own
     * {@code estimateLogProbability}.
     *
     * @param x the samples, one per row
     * @return a matrix indexed {@code [sample][component]} of log-densities
     */
    public double[][] logProbability(double[][] x) {
        int nmix = this.gaussians.length;
        int nsamples = x.length;
        double[][] log_prob = new double[nsamples][nmix];
        for (int i = 0; i < nmix; ++i) {
            double[] lp = this.gaussians[i].estimateLogProbability(x);
            for (int j = 0; j < nsamples; ++j) {
                log_prob[j][i] = lp[j];
            }
        }
        return log_prob;
    }

    /**
     * Compute the log-posterior of every component for a single sample.
     *
     * @param sample the sample vector
     * @return the log-posterior of each component
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    public double[] predictLogPosterior(double[] sample) {
        return this.predictLogPosterior(new double[][]{sample})[0];
    }

    /**
     * Compute the log-posterior of every component for every sample: the
     * weighted component log-density minus the mixture log-probability.
     *
     * @param samples the samples, one per row; must contain at least one row
     * @return a matrix indexed {@code [sample][component]} of log-posteriors
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    public double[][] predictLogPosterior(double[][] samples) {
        this.checkDimensions(samples);
        double[][] lpr = this.computeWeightedLogProb(samples);
        double[] logprob = this.logsumexp(lpr);
        double[][] responsibilities = new double[samples.length][this.gaussians.length];
        for (int i = 0; i < samples.length; ++i) {
            for (int j = 0; j < this.gaussians.length; ++j) {
                responsibilities[i][j] = lpr[i][j] - logprob[i];
            }
        }
        return responsibilities;
    }

    /**
     * Compute, for every sample, its log-probability under the mixture and the
     * posterior probability (responsibility) of each component.
     *
     * @param samples the samples, one per row; must contain at least one row
     * @return a pair of (log-probability per sample, responsibilities indexed
     *         {@code [sample][component]})
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    public IndependentPair<double[], double[][]> scoreSamples(double[][] samples) {
        this.checkDimensions(samples);
        double[][] lpr = this.computeWeightedLogProb(samples);
        double[] logprob = this.logsumexp(lpr);
        double[][] responsibilities = new double[samples.length][this.gaussians.length];
        for (int i = 0; i < samples.length; ++i) {
            for (int j = 0; j < this.gaussians.length; ++j) {
                responsibilities[i][j] = Math.exp(lpr[i][j] - logprob[i]);
            }
        }
        return IndependentPair.pair(logprob, responsibilities);
    }

    /**
     * Verify that the first sample has the same number of dimensions as the
     * mean of the first component.
     *
     * @param samples the samples, one per row; must contain at least one row
     * @throws IllegalArgumentException if the dimensions differ
     */
    private void checkDimensions(double[][] samples) {
        if (samples[0].length != this.gaussians[0].getMean().getColumnDimension()) {
            throw new IllegalArgumentException("The number of dimensions of the given data is not compatible with the model");
        }
    }

    /**
     * Compute {@code log(sum(exp(row)))} for every row, shifting by the row
     * maximum before exponentiating to avoid overflow and underflow.
     *
     * @param data the input rows
     * @return the log-sum-exp of each row
     */
    private double[] logsumexp(double[][] data) {
        double[] lse = new double[data.length];
        for (int i = 0; i < data.length; ++i) {
            double[] row = data[i];
            double max = ArrayUtils.maxValue(row);
            if (Double.isInfinite(max)) {
                // Shifting by an infinite maximum would produce NaN
                // (inf - inf); the result is simply that maximum.
                lse[i] = max;
                continue;
            }
            double acc = 0.0;
            for (int j = 0; j < row.length; ++j) {
                acc += Math.exp(row[j] - max);
            }
            lse[i] = max + Math.log(acc);
        }
        return lse;
    }

    /**
     * @return the mixture components (the internal array, not a copy)
     */
    public MultivariateGaussian[] getGaussians() {
        return this.gaussians;
    }

    /**
     * @return the mixing weights (the internal array, not a copy)
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Estimate the probability density of a single sample under the mixture.
     *
     * @param sample the sample vector
     * @return the exponential of the sample's log-probability
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    @Override
    public double estimateProbability(double[] sample) {
        return Math.exp(this.estimateLogProbability(sample));
    }

    /**
     * Predict the component with the highest log-posterior for the given
     * sample.
     *
     * @param data the sample vector
     * @return the index of the most likely component
     * @throws IllegalArgumentException if the sample dimensionality does not match the model
     */
    public int predict(double[] data) {
        double[] posterior = this.predictLogPosterior(data);
        return ArrayUtils.maxIndex(posterior);
    }
}

