/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.math.statistics.distribution;

import Jama.Matrix;
import cern.jet.random.Normal;
import cern.jet.random.engine.MersenneTwister;
import java.util.Arrays;
import java.util.Random;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian;

/**
 * A multivariate Gaussian distribution whose covariance matrix is diagonal, so each
 * dimension has its own variance and there are no cross-dimension covariances.
 * <p>
 * The mean vector is read from row 0 of {@code mean} (a 1 x N row matrix when built with
 * the {@code int} constructor); the diagonal of the covariance is held in {@link #variance}.
 */
public class DiagonalMultivariateGaussian
extends AbstractMultivariateGaussian {
    /** Precomputed log(2*pi), used by the log-space normalising constant. */
    private static final double LOG_2PI = Math.log(2.0 * Math.PI);

    /**
     * The diagonal entries of the covariance matrix, one variance per dimension.
     * The array is used directly (not copied), so mutating it changes the distribution.
     */
    public double[] variance;

    /**
     * Construct a distribution from an explicit mean and per-dimension variances.
     * Neither argument is copied.
     *
     * @param mean     the mean; row 0 is used as the mean vector
     *                 (presumably a 1 x N row matrix — TODO confirm against callers)
     * @param variance the diagonal of the covariance matrix
     */
    public DiagonalMultivariateGaussian(Matrix mean, double[] variance) {
        this.mean = mean;
        this.variance = variance;
    }

    /**
     * Construct a standard Gaussian with zero mean and unit variance in every dimension.
     *
     * @param ndims the number of dimensions
     */
    public DiagonalMultivariateGaussian(int ndims) {
        this.mean = new Matrix(1, ndims);
        this.variance = new double[ndims];
        Arrays.fill(this.variance, 1.0);
    }

    /**
     * Build the covariance matrix by passing {@link #variance} to {@code MatrixUtils.diag}.
     *
     * @return the covariance matrix
     */
    @Override
    public Matrix getCovariance() {
        return MatrixUtils.diag(this.variance);
    }

    /**
     * Get a single element of the (implicit) diagonal covariance matrix.
     *
     * @param row the row index, in {@code [0, N)}
     * @param col the column index, in {@code [0, N)}
     * @return {@code variance[row]} on the diagonal, {@code 0.0} elsewhere
     * @throws IndexOutOfBoundsException if either index is outside {@code [0, N)}
     */
    @Override
    public double getCovariance(int row, int col) {
        final int n = this.variance.length;
        // Both bounds are exclusive: index n is out of range for an n x n matrix.
        if (row < 0 || row >= n || col < 0 || col >= n) {
            throw new IndexOutOfBoundsException(
                    "(" + row + ", " + col + ") is outside the " + n + "x" + n + " covariance matrix");
        }
        return row == col ? this.variance[row] : 0.0;
    }

    /**
     * Evaluate the probability density at a sample.
     * <p>
     * The density is computed as {@code exp(log-density)} rather than from an explicit
     * determinant, so high-dimensional inputs do not overflow or underflow in the
     * normalising constant.
     *
     * @param sample the point to evaluate; at least N elements are read
     * @return the density value
     */
    @Override
    public double estimateProbability(double[] sample) {
        final double[] meanvector = this.mean.getArray()[0];
        return Math.exp(logNormalisingConstant() - 0.5 * mahalanobisSquared(sample, meanvector));
    }

    /**
     * Evaluate the log of the probability density at a sample.
     *
     * @param sample the point to evaluate; at least N elements are read
     * @return the log-density value
     */
    @Override
    public double estimateLogProbability(double[] sample) {
        final double[] meanvector = this.mean.getArray()[0];
        return logNormalisingConstant() - 0.5 * mahalanobisSquared(sample, meanvector);
    }

    /**
     * Evaluate the log of the probability density at each of several samples.
     * The normalising constant is computed once and shared across all samples.
     *
     * @param samples the points to evaluate, one per row; at least N elements of each row are read
     * @return the log-density of each row, in the same order as {@code samples}
     */
    @Override
    public double[] estimateLogProbability(double[][] samples) {
        final double[] meanvector = this.mean.getArray()[0];
        final double logConst = logNormalisingConstant();
        final double[] lp = new double[samples.length];
        for (int j = 0; j < samples.length; ++j) {
            lp[j] = logConst - 0.5 * mahalanobisSquared(samples[j], meanvector);
        }
        return lp;
    }

    /**
     * Draw samples from this distribution.
     * <p>
     * Each element is {@code sqrt(variance[i]) * z + mean[i]}, where {@code z} is a standard
     * normal deviate taken from {@code rng}. Using the caller's generator makes results
     * reproducible for a seeded {@code rng} and different across successive calls.
     *
     * @param nsamples the number of samples to draw
     * @param rng      the source of randomness; if {@code null}, a fresh {@link Random} is used
     * @return an {@code nsamples x N} array with one sample per row
     */
    @Override
    public double[][] sample(int nsamples, Random rng) {
        if (nsamples == 0) {
            return new double[0][0];
        }
        // The original version ignored rng (always a freshly, identically seeded engine).
        // Honour it here, and keep accepting null for backward compatibility.
        final Random source = rng != null ? rng : new Random();
        final int n = this.mean.getColumnDimension();
        final double[][] out = new double[nsamples][n];
        final double[] meanv = this.mean.getArray()[0];
        for (int i = 0; i < n; ++i) {
            // For a diagonal covariance the Cholesky factor is just the per-dimension std-dev.
            final double stddev = Math.sqrt(this.variance[i]);
            for (int j = 0; j < nsamples; ++j) {
                out[j][i] = stddev * source.nextGaussian() + meanv[i];
            }
        }
        return out;
    }

    /**
     * Compute the log of the Gaussian normalising constant,
     * {@code -(N/2) log(2*pi) - (1/2) sum_i log(variance[i])}.
     * <p>
     * It is computed directly in log space. Forming {@code (2*pi)^N} or the determinant
     * explicitly can overflow or underflow in high dimensions. For N = 0 the result is 0.
     */
    private double logNormalisingConstant() {
        double logDet = 0.0;
        for (final double v : this.variance) {
            logDet += Math.log(v);
        }
        return -0.5 * (this.variance.length * LOG_2PI + logDet);
    }

    /**
     * Compute the squared Mahalanobis distance between {@code x} and {@code mu} under the
     * diagonal covariance, {@code sum_i (x[i] - mu[i])^2 / variance[i]}.
     */
    private double mahalanobisSquared(double[] x, double[] mu) {
        double sum = 0.0;
        for (int i = 0; i < this.variance.length; ++i) {
            final double diff = x[i] - mu[i];
            sum += diff * diff / this.variance[i];
        }
        return sum;
    }
}

