/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityDensityFunction;
import gov.sandia.cognition.statistics.distribution.InverseWishartDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Random;

/**
 * Joint distribution over a mean vector and a covariance matrix, composed of
 * a {@link MultivariateGaussian} (for the mean) and an
 * {@link InverseWishartDistribution} (for the covariance).
 *
 * <p>Values of this distribution are {@code d x (d + 1)} matrices: column 0
 * holds the mean vector and columns {@code 1..d} hold the covariance matrix.
 * The mean is distributed as a Gaussian centered on the mean of
 * {@link #getGaussian()} with covariance equal to the covariance matrix
 * divided by {@link #getCovarianceDivisor()}.
 *
 * <p>The covariance stored inside {@link #getGaussian()} is not used by
 * {@link #sample(Random, int)} or by {@link PDF#logEvaluate(Matrix)}; only
 * its mean is.
 */
@PublicationReference(author={"Stanley Sawyer"}, title="Wishart Distributions and Inverse-Wishart Sampling", type=PublicationType.Misc, year=2007, url="http://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf")
public class NormalInverseWishartDistribution
extends AbstractDistribution<Matrix>
implements ClosedFormComputableDistribution<Matrix> {
    /** Dimensionality used by the no-argument constructor, {@value}. */
    public static final int DEFAULT_DIMENSIONALITY = 2;

    /** Covariance divisor used when none is supplied, {@value}. */
    public static final double DEFAULT_COVARIANCE_DIVISOR = 1.0;

    /** Divisor applied to the covariance matrix for the mean's Gaussian; must be &gt; 0. */
    protected double covarianceDivisor;

    /** Gaussian whose mean is the location of the mean vector. */
    protected MultivariateGaussian gaussian;

    /** Inverse-Wishart distribution governing the covariance matrix. */
    protected InverseWishartDistribution inverseWishart;

    /**
     * Creates a distribution of {@link #DEFAULT_DIMENSIONALITY} dimensions
     * with the default covariance divisor.
     */
    public NormalInverseWishartDistribution() {
        this(DEFAULT_DIMENSIONALITY);
    }

    /**
     * Creates a distribution with the default covariance divisor.
     *
     * @param dimensionality dimensionality passed to the component
     *        Gaussian and inverse-Wishart constructors
     */
    public NormalInverseWishartDistribution(int dimensionality) {
        this(dimensionality, DEFAULT_COVARIANCE_DIVISOR);
    }

    /**
     * Creates a distribution from default-constructed components of the
     * given dimensionality.
     *
     * @param dimensionality dimensionality passed to the component
     *        Gaussian and inverse-Wishart constructors
     * @param covarianceDivisor divisor for the covariance, must be &gt; 0
     * @throws IllegalArgumentException if {@code covarianceDivisor <= 0}
     */
    public NormalInverseWishartDistribution(int dimensionality, double covarianceDivisor) {
        this(new MultivariateGaussian(dimensionality), new InverseWishartDistribution(dimensionality), covarianceDivisor);
    }

    /**
     * Creates a distribution from the given components. The components are
     * stored by reference, not copied.
     *
     * @param gaussian Gaussian supplying the location of the mean
     * @param inverseWishart inverse-Wishart governing the covariance
     * @param covarianceDivisor divisor for the covariance, must be &gt; 0
     * @throws IllegalArgumentException if {@code covarianceDivisor <= 0}
     */
    public NormalInverseWishartDistribution(MultivariateGaussian gaussian, InverseWishartDistribution inverseWishart, double covarianceDivisor) {
        this.setGaussian(gaussian);
        this.setInverseWishart(inverseWishart);
        this.setCovarianceDivisor(covarianceDivisor);
    }

    /**
     * Copy constructor; the component distributions of {@code other} are
     * cloned rather than shared.
     *
     * @param other distribution to copy
     */
    public NormalInverseWishartDistribution(NormalInverseWishartDistribution other) {
        this(ObjectUtil.cloneSafe(other.getGaussian()), ObjectUtil.cloneSafe(other.getInverseWishart()), other.getCovarianceDivisor());
    }

    /**
     * Returns a copy of this distribution whose Gaussian and inverse-Wishart
     * components are themselves clones.
     *
     * @return the cloned distribution
     */
    @Override
    public NormalInverseWishartDistribution clone() {
        NormalInverseWishartDistribution clone = (NormalInverseWishartDistribution)super.clone();
        clone.setGaussian(ObjectUtil.cloneSafe(this.getGaussian()));
        clone.setInverseWishart(ObjectUtil.cloneSafe(this.getInverseWishart()));
        return clone;
    }

    /**
     * @return the Gaussian component (live reference, not a copy)
     */
    public MultivariateGaussian getGaussian() {
        return this.gaussian;
    }

    /**
     * @param gaussian the Gaussian component to use (stored by reference)
     */
    public void setGaussian(MultivariateGaussian gaussian) {
        this.gaussian = gaussian;
    }

    /**
     * @return the inverse-Wishart component (live reference, not a copy)
     */
    public InverseWishartDistribution getInverseWishart() {
        return this.inverseWishart;
    }

    /**
     * @param inverseWishart the inverse-Wishart component to use (stored by reference)
     */
    public void setInverseWishart(InverseWishartDistribution inverseWishart) {
        this.inverseWishart = inverseWishart;
    }

    /**
     * @return the divisor applied to the covariance for the mean's Gaussian
     */
    public double getCovarianceDivisor() {
        return this.covarianceDivisor;
    }

    /**
     * Sets the covariance divisor.
     *
     * @param covarianceDivisor new divisor, must be &gt; 0
     * @throws IllegalArgumentException if {@code covarianceDivisor <= 0}
     */
    public void setCovarianceDivisor(double covarianceDivisor) {
        if (covarianceDivisor <= 0.0) {
            throw new IllegalArgumentException("covarianceDivisor must be > 0.0");
        }
        this.covarianceDivisor = covarianceDivisor;
    }

    /**
     * Builds the mean of this distribution in the same {@code d x (d + 1)}
     * layout used for samples: column 0 is the Gaussian's mean and columns
     * {@code 1..d} are the inverse-Wishart's mean.
     *
     * @return newly allocated mean matrix
     */
    @Override
    public Matrix getMean() {
        Matrix C = this.inverseWishart.getMean();
        Vector mean = this.gaussian.getMean();
        int d = this.getInputDimensionality();
        Matrix R = MatrixFactory.getDefault().createMatrix(d, d + 1);
        R.setColumn(0, mean);
        R.setSubMatrix(0, 1, C);
        return R;
    }

    /**
     * Draws samples by first drawing covariance matrices from the
     * inverse-Wishart component and then, for each one, drawing a mean from
     * a Gaussian whose covariance is that matrix divided by the covariance
     * divisor.
     *
     * <p>The Gaussian component of this object is not modified; the means
     * are drawn from a private clone of it.
     *
     * @param random source of randomness
     * @param numSamples number of samples to draw
     * @return list of {@code d x (d + 1)} matrices, column 0 holding the
     *         sampled mean and columns {@code 1..d} the sampled (unscaled)
     *         covariance
     */
    @Override
    public ArrayList<Matrix> sample(Random random, int numSamples) {
        int d = this.gaussian.getInputDimensionality();
        ArrayList<Matrix> samples = new ArrayList<Matrix>(numSamples);
        ArrayList<Matrix> covariances = this.inverseWishart.sample(random, numSamples);

        // Work on a clone so that drawing samples does not overwrite the
        // covariance held by this.gaussian.
        MultivariateGaussian meanSampler = ObjectUtil.cloneSafe(this.gaussian);
        for (Matrix covariance : covariances) {
            Matrix meanAndCovariance = MatrixFactory.getDefault().createMatrix(d, d + 1);

            // Record the unscaled covariance before scaling it in place for
            // the conditional Gaussian of the mean.
            meanAndCovariance.setSubMatrix(0, 1, covariance);
            covariance.scaleEquals(1.0 / this.covarianceDivisor);
            meanSampler.setCovariance(covariance);
            Vector mean = meanSampler.sample(random);
            meanAndCovariance.setColumn(0, mean);
            samples.add(meanAndCovariance);
        }
        return samples;
    }

    /**
     * Packs the parameters into a single vector: the covariance divisor,
     * followed by the Gaussian's mean, followed by the inverse-Wishart's own
     * parameter vector.
     *
     * @return the stacked parameter vector
     */
    @Override
    public Vector convertToVector() {
        Vector c = VectorFactory.getDefault().copyValues(this.covarianceDivisor);
        c = c.stack(this.gaussian.getMean());
        return c.stack(this.inverseWishart.convertToVector());
    }

    /**
     * Restores the parameters from a vector laid out as produced by
     * {@link #convertToVector()}: element 0 is the covariance divisor,
     * elements {@code 1..d} are the Gaussian's mean, and the remaining
     * {@code 1 + d*d} elements are handed to the inverse-Wishart component.
     *
     * @param parameters parameter vector of dimensionality
     *        {@code 1 + d + 1 + d*d}, where {@code d} is
     *        {@link #getInputDimensionality()}
     * @throws IllegalArgumentException if the covariance divisor element is
     *         not &gt; 0
     */
    @Override
    public void convertFromVector(Vector parameters) {
        int d = this.getInputDimensionality();
        parameters.assertDimensionalityEquals(1 + d + 1 + d * d);
        this.setCovarianceDivisor(parameters.getElement(0));
        Vector mean = parameters.subVector(1, d);
        this.gaussian.setMean(mean);
        Vector iwp = parameters.subVector(d + 1, parameters.getDimensionality() - 1);
        this.inverseWishart.convertFromVector(iwp);
    }

    /**
     * @return the input dimensionality of the Gaussian component, or 0 if
     *         no Gaussian has been set
     */
    public int getInputDimensionality() {
        return this.gaussian != null ? this.gaussian.getInputDimensionality() : 0;
    }

    /**
     * @return a new {@link PDF} built with the copy constructor from this
     *         distribution
     */
    @Override
    public PDF getProbabilityFunction() {
        return new PDF(this);
    }

    /**
     * Probability density function of the normal-inverse-Wishart
     * distribution.
     */
    public static class PDF
    extends NormalInverseWishartDistribution
    implements ProbabilityDensityFunction<Matrix> {
        /**
         * Creates a PDF with the default dimensionality and covariance
         * divisor.
         */
        public PDF() {
        }

        /**
         * @param dimensionality dimensionality of the component distributions
         * @param covarianceDivisor divisor for the covariance, must be &gt; 0
         */
        public PDF(int dimensionality, double covarianceDivisor) {
            super(dimensionality, covarianceDivisor);
        }

        /**
         * @param gaussian Gaussian supplying the location of the mean
         * @param inverseWishart inverse-Wishart governing the covariance
         * @param covarianceDivisor divisor for the covariance, must be &gt; 0
         */
        public PDF(MultivariateGaussian gaussian, InverseWishartDistribution inverseWishart, double covarianceDivisor) {
            super(gaussian, inverseWishart, covarianceDivisor);
        }

        /**
         * Copy constructor.
         *
         * @param other distribution whose parameters are copied
         */
        public PDF(NormalInverseWishartDistribution other) {
            super(other);
        }

        /**
         * @return this object, which already is the density function
         */
        @Override
        public PDF getProbabilityFunction() {
            return this;
        }

        /**
         * Computes the log-density of a (mean, covariance) pair.
         *
         * <p>The result is the inverse-Wishart log-density of the covariance
         * plus the log-density of the mean under a Gaussian centered on this
         * distribution's Gaussian mean with covariance equal to the input
         * covariance divided by the covariance divisor. This matches the
         * generative procedure in {@link #sample(Random, int)}. Neither the
         * input nor this object's Gaussian component is modified, assuming
         * {@code getSubMatrix} returns a copy rather than a view (TODO
         * confirm).
         *
         * @param input matrix with {@code d} rows whose column 0 is the mean
         *        and whose columns {@code 1..d} are the covariance
         * @return log of the joint density at {@code input}
         */
        @Override
        public double logEvaluate(Matrix input) {
            int d = input.getNumRows();
            Vector mean = input.getColumn(0);
            Matrix C = input.getSubMatrix(0, d - 1, 1, d);

            // The inverse-Wishart term is evaluated on the covariance as
            // given, i.e. before it is divided by the covariance divisor.
            double lpiw = this.inverseWishart.getProbabilityFunction().logEvaluate(C);

            // The mean is conditionally Gaussian with covariance C / divisor.
            // A clone is used so the evaluation leaves this.gaussian as is.
            C.scaleEquals(1.0 / this.covarianceDivisor);
            MultivariateGaussian conditional = ObjectUtil.cloneSafe(this.gaussian);
            conditional.setCovariance(C);
            double lpg = conditional.getProbabilityFunction().logEvaluate(mean);
            return lpg + lpiw;
        }

        /**
         * Computes the density of a (mean, covariance) pair as the
         * exponential of {@link #logEvaluate(Matrix)}.
         *
         * @param input matrix in the layout expected by
         *        {@link #logEvaluate(Matrix)}
         * @return the joint density at {@code input}
         */
        @Override
        public Double evaluate(Matrix input) {
            return Math.exp(this.logEvaluate(input));
        }
    }
}

