/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian.conjugate;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.bayesian.AbstractBayesianParameter;
import gov.sandia.cognition.statistics.bayesian.BayesianParameter;
import gov.sandia.cognition.statistics.bayesian.conjugate.AbstractConjugatePriorBayesianEstimator;
import gov.sandia.cognition.statistics.bayesian.conjugate.ConjugatePriorBayesianEstimatorPredictor;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateStudentTDistribution;
import gov.sandia.cognition.statistics.distribution.NormalInverseWishartDistribution;
import gov.sandia.cognition.util.Pair;
import java.util.Arrays;

/**
 * Conjugate-prior Bayesian estimator for the joint mean and covariance of a
 * {@link MultivariateGaussian}, using a {@link NormalInverseWishartDistribution}
 * as the belief (prior and posterior) over those parameters.
 * <p>
 * Local variable names follow Gelman et al.:
 * {@code lambda} is the belief's Gaussian mean,
 * {@code nu} is the belief's covariance divisor,
 * {@code alpha} is the inverse-Wishart degrees of freedom, and
 * {@code beta} is the inverse-Wishart inverse scale.
 */
@PublicationReferences(references={@PublicationReference(author={"Andrew Gelman", "John B. Carlin", "Hal S. Stern", "Donald B. Rubin"}, title="Bayesian Data Analysis, Second Edition", type=PublicationType.Book, year=2004, pages={87, 88}), @PublicationReference(author={"Wikipedia"}, title="Conjugate Prior", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Conjugate_prior")})
public class MultivariateGaussianMeanCovarianceBayesianEstimator
extends AbstractConjugatePriorBayesianEstimator<Vector, Matrix, MultivariateGaussian, NormalInverseWishartDistribution>
implements ConjugatePriorBayesianEstimatorPredictor<Vector, Matrix, MultivariateGaussian, NormalInverseWishartDistribution> {

    /**
     * Creates an estimator whose prior is a default-constructed
     * {@link NormalInverseWishartDistribution}.
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator() {
        this(new NormalInverseWishartDistribution());
    }

    /**
     * Creates an estimator whose prior is
     * {@code new NormalInverseWishartDistribution(dimensionality)}.
     *
     * @param dimensionality dimensionality of the observed vectors
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator(int dimensionality) {
        this(new NormalInverseWishartDistribution(dimensionality));
    }

    /**
     * Creates an estimator with the given prior. The conditional distribution
     * is a new {@link MultivariateGaussian} with the prior's input dimensionality.
     *
     * @param belief prior belief over the Gaussian's mean and covariance
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator(NormalInverseWishartDistribution belief) {
        this(new MultivariateGaussian(belief.getInputDimensionality()), belief);
    }

    /**
     * Creates an estimator from an explicit conditional distribution and prior.
     *
     * @param conditional Gaussian whose mean and covariance are being estimated
     * @param prior prior belief over the Gaussian's mean and covariance
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator(MultivariateGaussian conditional, NormalInverseWishartDistribution prior) {
        this(new Parameter(conditional, prior));
    }

    /**
     * Creates an estimator around an already-built Bayesian parameter.
     *
     * @param parameter the Bayesian parameter wrapping conditional and prior
     */
    protected MultivariateGaussianMeanCovarianceBayesianEstimator(BayesianParameter<Matrix, MultivariateGaussian, NormalInverseWishartDistribution> parameter) {
        super(parameter);
    }

    /**
     * Creates a new {@link Parameter} binding the given conditional and prior.
     *
     * @param conditional Gaussian whose mean and covariance are being estimated
     * @param prior prior belief over the Gaussian's mean and covariance
     * @return a new parameter object
     */
    public Parameter createParameter(MultivariateGaussian conditional, NormalInverseWishartDistribution prior) {
        return new Parameter(conditional, prior);
    }

    /**
     * Updates {@code target} in place with a single observation.
     *
     * @param target belief to update
     * @param data the observed vector
     */
    @Override
    public void update(NormalInverseWishartDistribution target, Vector data) {
        this.update(target, (Iterable<? extends Vector>)Arrays.asList(data));
    }

    /**
     * Updates {@code prior} in place into the posterior given {@code data}.
     * <p>
     * With n observations, sample mean xbar and scatter matrix
     * S = sum_i (x_i - xbar)(x_i - xbar)^T, the update is:
     * <pre>
     *   alphahat  = alpha + n
     *   nuhat     = nu + n
     *   lambdahat = (nu * lambda + n * xbar) / nuhat
     *   betahat   = beta + S + (n * nu / nuhat) (xbar - lambda)(xbar - lambda)^T
     * </pre>
     * An empty {@code data} collection leaves {@code prior} unchanged.
     *
     * @param prior belief to update; it is modified in place
     * @param data observed vectors
     */
    @Override
    public void update(NormalInverseWishartDistribution prior, Iterable<? extends Vector> data) {
        final int n = CollectionUtil.size(data);
        if (n <= 0) {
            // No evidence: the posterior is the prior. Continuing would divide by zero.
            return;
        }

        final Vector lambda = prior.getGaussian().getMean();
        final double nu = prior.getCovarianceDivisor();
        final int alpha = prior.getInverseWishart().getDegreesOfFreedom();
        final Matrix beta = prior.getInverseWishart().getInverseScale();

        // Only the mean is taken from the helper. The scatter matrix is accumulated
        // explicitly so the update does not depend on how the helper normalises its
        // covariance. The helper divides by (n - 1); the old code multiplied back by n.
        final Pair<Vector, Matrix> meanAndCovariance = MultivariateStatisticsUtil.computeMeanAndCovariance(data);
        final Vector sampleMean = meanAndCovariance.getFirst();
        final Matrix scatter = computeScatter(data, sampleMean);

        final int alphahat = alpha + n;
        final double nuhat = nu + (double)n;

        // lambdahat = (nu * lambda + n * xbar) / nuhat, built without dividing by n.
        final Vector lambdahat = (Vector)lambda.scale(nu / nuhat);
        lambdahat.plusEquals((Vector)sampleMean.scale((double)n / nuhat));

        // delta is a fresh vector, so the helper's mean is not mutated.
        final Vector delta = (Vector)sampleMean.minus(lambda);
        final Matrix betahat = scatter;
        betahat.plusEquals(beta);
        betahat.plusEquals(delta.outerProduct((Vector)delta.scale((double)n * nu / nuhat)));

        prior.getGaussian().setMean(lambdahat);
        prior.setCovarianceDivisor(nuhat);
        prior.getInverseWishart().setDegreesOfFreedom(alphahat);
        prior.getInverseWishart().setInverseScale(betahat);
    }

    /**
     * Computes the un-normalised scatter matrix sum_i (x_i - mean)(x_i - mean)^T.
     *
     * @param data observed vectors; must be non-empty and share {@code mean}'s dimensionality
     * @param mean the sample mean of {@code data}
     * @return a new square matrix of size {@code mean.getDimensionality()}
     */
    private static Matrix computeScatter(Iterable<? extends Vector> data, Vector mean) {
        final int dim = mean.getDimensionality();
        final Matrix scatter = MatrixFactory.getDefault().createMatrix(dim, dim);
        for (Vector x : data) {
            final Vector d = (Vector)x.minus(mean);
            scatter.plusEquals(d.outerProduct(d));
        }
        return scatter;
    }

    /**
     * Returns the covariance divisor of {@code belief}. It is used as the
     * equivalent number of observations the belief represents.
     *
     * @param belief the belief to measure
     * @return the belief's covariance divisor
     */
    @Override
    public double computeEquivalentSampleSize(NormalInverseWishartDistribution belief) {
        return belief.getCovarianceDivisor();
    }

    /**
     * Creates the posterior-predictive distribution, a multivariate Student-t.
     * <p>
     * The distribution has:
     * <ul>
     * <li>degrees of freedom {@code dofs = alpha - d + 1}, where d is the
     * inverse-Wishart input dimensionality;</li>
     * <li>location equal to the belief's Gaussian mean;</li>
     * <li>scale {@code beta * (nu + 1) / (nu * dofs)}.</li>
     * </ul>
     * The Student-t is constructed from the inverse of that scale matrix.
     * The returned distribution references the posterior's mean vector; it is not copied.
     *
     * @param posterior the posterior belief
     * @return the predictive distribution for a new observation
     */
    public MultivariateStudentTDistribution createPredictiveDistribution(NormalInverseWishartDistribution posterior) {
        Vector mean = posterior.getGaussian().getMean();
        double dofs = (double)(posterior.getInverseWishart().getDegreesOfFreedom() - posterior.getInverseWishart().getInputDimensionality()) + 1.0;
        Matrix covariance = (Matrix)posterior.getInverseWishart().getInverseScale().scale((posterior.getCovarianceDivisor() + 1.0) / (posterior.getCovarianceDivisor() * dofs));
        Matrix precision = covariance.inverse();
        return new MultivariateStudentTDistribution(dofs, mean, precision);
    }

    /**
     * Bayesian parameter that exposes the conditional Gaussian's mean and
     * covariance as a single {@code dim x (dim + 1)} matrix. Column 0 holds the
     * mean and columns 1..dim hold the covariance.
     */
    public static class Parameter
    extends AbstractBayesianParameter<Matrix, MultivariateGaussian, NormalInverseWishartDistribution> {
        /** Name under which this parameter is registered. */
        public static final String NAME = "meanAndCovariance";

        /**
         * Creates the parameter.
         *
         * @param conditional Gaussian whose mean and covariance this parameter controls
         * @param prior prior belief over the mean and covariance
         */
        public Parameter(MultivariateGaussian conditional, NormalInverseWishartDistribution prior) {
            super(conditional, NAME, prior);
        }

        /**
         * Unpacks {@code value} into the conditional Gaussian's mean (column 0)
         * and covariance (columns 1..dim).
         *
         * @param value a {@code dim x (dim + 1)} matrix
         * @throws IllegalArgumentException if {@code value} has the wrong shape
         */
        @Override
        public void setValue(Matrix value) {
            final MultivariateGaussian gaussian = (MultivariateGaussian)this.conditionalDistribution;
            final int dim = gaussian.getInputDimensionality();
            if (value.getNumRows() != dim || value.getNumColumns() != dim + 1) {
                throw new IllegalArgumentException("Expected (dim x dim+1) Matrix");
            }
            final Vector mean = value.getColumn(0);
            final Matrix covariance = value.getSubMatrix(0, dim - 1, 1, dim);
            gaussian.setMean(mean);
            gaussian.setCovariance(covariance);
        }

        /**
         * Packs the conditional Gaussian's mean and covariance into a new
         * {@code dim x (dim + 1)} matrix. The mean is in column 0 and the
         * covariance is in columns 1..dim.
         *
         * @return newly created packed parameter matrix
         */
        @Override
        public Matrix getValue() {
            final MultivariateGaussian gaussian = (MultivariateGaussian)this.conditionalDistribution;
            final int dim = gaussian.getInputDimensionality();
            final Matrix parameter = MatrixFactory.getDefault().createMatrix(dim, dim + 1);
            parameter.setColumn(0, gaussian.getMean());
            parameter.setSubMatrix(0, 1, gaussian.getCovariance());
            return parameter;
        }
    }
}

