/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.DefaultCluster;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo;
import gov.sandia.cognition.statistics.bayesian.conjugate.MultivariateGaussianMeanBayesianEstimator;
import gov.sandia.cognition.statistics.bayesian.conjugate.MultivariateGaussianMeanCovarianceBayesianEstimator;
import gov.sandia.cognition.statistics.distribution.BetaDistribution;
import gov.sandia.cognition.statistics.distribution.ChineseRestaurantProcess;
import gov.sandia.cognition.statistics.distribution.GammaDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateStudentTDistribution;
import gov.sandia.cognition.statistics.distribution.NormalInverseWishartDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Random;

/**
 * A Dirichlet process mixture model (DPMM) fitted by Markov chain Monte
 * Carlo, based in part on Algorithm 2 of Neal (2000).
 * <p>
 * Each call to {@link #mcmcUpdate()} does the following:
 * <ol>
 * <li>Reassigns every observation either to one of the current clusters or
 * to a new cluster.</li>
 * <li>Rebuilds the cluster distributions from those assignments through the
 * configured {@link Updater}.</li>
 * <li>If {@link #getReestimateAlpha()} is true, resamples the concentration
 * parameter alpha following Escobar and West (1995).</li>
 * </ol>
 * An {@link Updater} must be set before learning.
 *
 * @param <ObservationType> Type of the observations being clustered.
 */
@PublicationReferences(references={
    @PublicationReference(author={"Radford M. Neal"}, title="Markov Chain Sampling Methods for Dirichlet Process Mixture Models", type=PublicationType.Journal, year=2000, publication="Journal of Computational and Graphical Statistics, Vol. 9, No. 2", pages={249, 265}, notes={"Based in part on Algorithm 2 from Neal"}),
    @PublicationReference(author={"Michael D. Escobar", "Mike West"}, title="Bayesian Density Estimation and Inference Using Mixtures", type=PublicationType.Journal, publication="Journal of the American Statistical Association", year=1995)})
public class DirichletProcessMixtureModel<ObservationType>
extends AbstractMarkovChainMonteCarlo<ObservationType, DirichletProcessMixtureModel.Sample<ObservationType>> {

    /** Default initial value of the concentration parameter alpha, {@value}. */
    public static final double DEFAULT_ALPHA = 1.0;

    /** Default number of clusters in the initial sample, {@value}. */
    public static final int DEFAULT_NUM_INITIAL_CLUSTERS = 2;

    /** Default setting for whether alpha is resampled on each update, {@value}. */
    public static final boolean DEFAULT_REESTIMATE_ALPHA = true;

    /** Builds the prior predictive and the per-cluster distributions from data. */
    protected Updater<ObservationType> updater;

    /** Number of (identical) clusters placed in the initial sample. */
    private int numInitialClusters;

    /** Whether {@link #mcmcUpdate()} resamples alpha. */
    protected boolean reestimateAlpha;

    /** Alpha used for the initial sample. */
    protected double initialAlpha;

    /**
     * Lazily built prior predictive of the data. It scores the "new cluster"
     * option during assignment. It is cleared whenever a new initial sample
     * is created and is never shared with clones.
     */
    protected transient ProbabilityFunction<ObservationType> conditionalPriorPredictive;

    /**
     * Scratch buffer of unnormalized assignment weights: one slot per
     * current cluster plus a last slot for a new cluster.
     */
    protected transient double[] clusterWeights;

    /** Sampler for the auxiliary variable eta used in the alpha update. */
    protected transient BetaDistribution etaSampler;

    /**
     * Gamma sampler used by {@link #updateAlpha(double, int)}. Despite the
     * field name, it draws alpha itself rather than its reciprocal.
     */
    protected transient GammaDistribution alphaInverseSampler;

    /**
     * Creates a new model with these defaults:
     * <ul>
     * <li>alpha re-estimation enabled ({@link #DEFAULT_REESTIMATE_ALPHA});</li>
     * <li>initial alpha {@link #DEFAULT_ALPHA};</li>
     * <li>{@link #DEFAULT_NUM_INITIAL_CLUSTERS} initial clusters.</li>
     * </ul>
     * No {@link Updater} is set; one must be provided before learning.
     */
    public DirichletProcessMixtureModel() {
        this.setReestimateAlpha(DEFAULT_REESTIMATE_ALPHA);
        this.setInitialAlpha(DEFAULT_ALPHA);
        this.setNumInitialClusters(DEFAULT_NUM_INITIAL_CLUSTERS);
    }

    /**
     * Clones this model, deep-copying the updater with
     * {@link ObjectUtil#cloneSafe}. Transient caches and samplers are not
     * shared with the clone; the clone rebuilds them lazily.
     *
     * @return A clone of this model.
     */
    @Override
    public DirichletProcessMixtureModel<ObservationType> clone() {
        @SuppressWarnings("unchecked")
        final DirichletProcessMixtureModel<ObservationType> clone =
            (DirichletProcessMixtureModel<ObservationType>) super.clone();
        clone.setUpdater(ObjectUtil.cloneSafe(this.getUpdater()));
        // The fields below are mutated during sampling (the weight buffer and
        // the samplers' parameters). A shallow copy would let the original
        // and the clone corrupt each other, so the clone starts fresh.
        clone.conditionalPriorPredictive = null;
        clone.clusterWeights = null;
        clone.etaSampler = null;
        clone.alphaInverseSampler = null;
        return clone;
    }

    /**
     * Performs one MCMC update of the current sample. The steps are:
     * <ol>
     * <li>Reassign every observation to a cluster or to a new cluster.</li>
     * <li>If the previous sample has no posterior log-likelihood yet, fill it
     * in from the log-conditional accumulated during assignment.</li>
     * <li>Replace the current clusters with ones rebuilt from the new
     * assignments.</li>
     * <li>Optionally resample alpha.</li>
     * </ol>
     */
    @Override
    protected void mcmcUpdate() {
        if (this.conditionalPriorPredictive == null) {
            this.conditionalPriorPredictive = this.updater.createPriorPredictive(this.data);
        }
        final Sample<ObservationType> current = this.currentParameter;
        final int K = current.getNumClusters();
        final DPMMLogConditional logConditional = new DPMMLogConditional();
        final ArrayList<Collection<ObservationType>> clusterAssignments =
            this.assignObservationsToClusters(K, logConditional);
        final int numObservations = CollectionUtil.size(this.data);

        // The assignment pass scored the data against the clusters as they
        // were before this update, which is the state recorded in
        // previousParameter.
        final Sample<ObservationType> previous = this.previousParameter;
        if (previous != null && previous.getPosteriorLogLikelihood() == null) {
            previous.setPosteriorLogLikelihood(
                previous.computePosteriorLogLikelihood(numObservations, logConditional.logConditional));
        }

        current.setClusters(this.updateClusters(clusterAssignments));
        if (this.getReestimateAlpha()) {
            current.alpha = this.updateAlpha(current.alpha, numObservations);
        }
    }

    /**
     * Builds the clusters for the next sample from the given assignments.
     * <p>
     * An assignment group is discarded when it has at most one member, or
     * when {@link #createCluster} returns null for it. Groups are processed
     * in index order.
     * <p>
     * NOTE(review): dropping singletons presumably pairs with the
     * {@code (num - 1)} weighting in {@link #assignObservationToCluster},
     * under which a one-member cluster gets zero weight. Confirm this is
     * intended.
     *
     * @param clusterAssignments Observations grouped by assigned cluster index.
     * @return The newly created clusters; there may be fewer than the number
     * of groups.
     */
    protected ArrayList<DPMMCluster<ObservationType>> updateClusters(ArrayList<Collection<ObservationType>> clusterAssignments) {
        final int Kp1 = clusterAssignments.size();
        final ArrayList<DPMMCluster<ObservationType>> clusters =
            new ArrayList<DPMMCluster<ObservationType>>(Kp1);
        for (int k = 0; k < Kp1; ++k) {
            final Collection<ObservationType> assignments = clusterAssignments.get(k);
            if (assignments.size() <= 1) {
                continue;
            }
            final DPMMCluster<ObservationType> cluster = this.createCluster(assignments, this.updater);
            if (cluster != null) {
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Assigns every observation in the data to one of the {@code K} current
     * clusters or to a new cluster (index {@code K}).
     *
     * @param K Number of clusters in the current sample.
     * @param logConditional Accumulator that receives each observation's
     * log-conditional term.
     * @return {@code K + 1} collections, where entry {@code k} holds the
     * observations assigned to cluster {@code k}.
     */
    protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(int K, DPMMLogConditional logConditional) {
        if (this.clusterWeights == null || this.clusterWeights.length != K + 1) {
            this.clusterWeights = new double[K + 1];
        }
        final ArrayList<Collection<ObservationType>> clusterAssignments =
            new ArrayList<Collection<ObservationType>>(K + 1);
        for (int k = 0; k <= K; ++k) {
            clusterAssignments.add(new LinkedList<ObservationType>());
        }
        for (ObservationType observation : this.data) {
            final int clusterAssignment =
                this.assignObservationToCluster(observation, this.clusterWeights, logConditional);
            clusterAssignments.get(clusterAssignment).add(observation);
        }
        return clusterAssignments;
    }

    /**
     * Samples a cluster index for one observation.
     * <p>
     * The unnormalized weights are:
     * <ul>
     * <li>{@code alpha * priorPredictive(x)} for a new cluster (index K);</li>
     * <li>{@code (n_k - 1) * p_k(x)} for each non-empty current cluster
     * {@code k};</li>
     * <li>zero for each empty current cluster.</li>
     * </ul>
     * As a side effect, {@code log(1e-100 + sum_k n_k * p_k(x))} is added to
     * {@code logConditional}.
     *
     * @param observation Observation to assign.
     * @param weights Scratch buffer with at least {@code K + 1} slots; it is
     * overwritten.
     * @param logConditional Accumulator for the log-conditional term.
     * @return The chosen index in {@code [0, K]}.
     * @throws IllegalArgumentException If no index is selected, for example
     * when the weights contain NaN.
     */
    protected int assignObservationToCluster(ObservationType observation, double[] weights, DPMMLogConditional logConditional) {
        final Sample<ObservationType> current = this.currentParameter;
        final double alpha = current.alpha;
        final int K = current.getNumClusters();
        final double newClusterWeight =
            alpha * (Double) this.conditionalPriorPredictive.evaluate(observation);
        weights[K] = newClusterWeight;
        double weightSum = newClusterWeight;
        // Starting at 1e-100 keeps the log below finite even when every
        // cluster assigns zero density to the observation.
        double conditional = 1.0E-100;
        for (int k = 0; k < K; ++k) {
            final DPMMCluster<ObservationType> cluster = current.clusters.get(k);
            final int num = cluster.getMembers().size();
            if (num > 0) {
                final double c = (Double) cluster.getProbabilityFunction().evaluate(observation);
                final double weight = (num - 1) * c;
                weights[k] = weight;
                weightSum += weight;
                conditional += num * c;
            } else {
                weights[k] = 0.0;
            }
        }
        logConditional.logConditional += Math.log(conditional);

        // Inverse-CDF draw over the unnormalized weights.
        double p = weightSum * this.random.nextDouble();
        for (int k = 0; k <= K; ++k) {
            p -= weights[k];
            if (p <= 0.0) {
                return k;
            }
        }
        throw new IllegalArgumentException("Did not select cluster: " + weightSum);
    }

    /**
     * Creates a cluster from the given members. The cluster's distribution
     * is drawn by {@code localUpdater.createClusterPosterior}.
     *
     * @param clusterAssignment Members of the cluster.
     * @param localUpdater Updater used to build the cluster distribution.
     * @return The new cluster, or null if {@code clusterAssignment} is null
     * or empty.
     */
    protected DPMMCluster<ObservationType> createCluster(Collection<ObservationType> clusterAssignment, Updater<ObservationType> localUpdater) {
        if (clusterAssignment == null || clusterAssignment.isEmpty()) {
            return null;
        }
        final ProbabilityFunction<ObservationType> probabilityFunction =
            localUpdater.createClusterPosterior(clusterAssignment, this.random);
        return new DPMMCluster<ObservationType>(clusterAssignment, probabilityFunction);
    }

    /**
     * Samples a new concentration parameter alpha given the current number
     * of clusters {@code k}. It uses the auxiliary-variable scheme of
     * Escobar and West (1995, Section 6) with a Gamma(a = 1, b = 1) prior on
     * alpha (shape a, rate b):
     * <ol>
     * <li>{@code eta ~ Beta(alpha + 1, n)};</li>
     * <li>with probability {@code pi}, draw
     * {@code alpha ~ Gamma(a + k, rate b - log eta)}; otherwise draw
     * {@code alpha ~ Gamma(a + k - 1, rate b - log eta)}. Here {@code pi} is
     * defined by {@code pi / (1 - pi) = (a + k - 1) / (n (b - log eta))}.</li>
     * </ol>
     * The conditional is undefined without at least one cluster and one
     * observation, so in that case {@code alpha} is returned unchanged.
     *
     * @param alpha Current alpha.
     * @param numObservations Number of observations {@code n}.
     * @return The resampled alpha.
     */
    protected double updateAlpha(double alpha, int numObservations) {
        final int updatedK = this.currentParameter.getNumClusters();
        if (updatedK <= 0 || numObservations <= 0) {
            return alpha;
        }
        if (this.etaSampler == null) {
            this.etaSampler = new BetaDistribution();
        }
        this.etaSampler.setAlpha(alpha + 1.0);
        this.etaSampler.setBeta(numObservations);
        final double eta = (Double) this.etaSampler.sample(this.random);
        final double logEta = Math.log(eta);

        // Shape (a) and rate (b) of the Gamma prior on alpha.
        final double a = 1.0;
        final double b = 1.0;
        final double rate = b - logEta;

        // Escobar & West give the mixture weight as odds; convert to a probability.
        final double odds = (a + updatedK - 1.0) / ((double) numObservations * rate);
        final double mixtureProbability = odds / (1.0 + odds);

        if (this.alphaInverseSampler == null) {
            this.alphaInverseSampler = new GammaDistribution();
        }
        if (this.random.nextDouble() < mixtureProbability) {
            this.alphaInverseSampler.setShape(a + updatedK);
        } else {
            this.alphaInverseSampler.setShape(a + updatedK - 1.0);
        }
        // NOTE(review): assumes GammaDistribution uses the shape/scale
        // parameterization (mean = shape * scale), so the rate enters as its
        // reciprocal. Confirm against the GammaDistribution documentation.
        this.alphaInverseSampler.setScale(1.0 / rate);
        return (Double) this.alphaInverseSampler.sample(this.random);
    }

    /**
     * Creates the starting sample. It contains {@link #getNumInitialClusters()}
     * identical clusters, each holding all of the data and sharing one
     * distribution drawn from the updater's cluster posterior over the full
     * data set. Alpha is set to {@link #getInitialAlpha()}.
     * <p>
     * Any cached prior predictive is discarded here, because it belongs to
     * whatever data set was used before.
     *
     * @return The initial sample.
     */
    @Override
    public Sample<ObservationType> createInitialLearnedObject() {
        this.conditionalPriorPredictive = null;
        final int numClusters = this.getNumInitialClusters();
        final ArrayList<DPMMCluster<ObservationType>> clusters =
            new ArrayList<DPMMCluster<ObservationType>>(numClusters);
        final ProbabilityFunction<ObservationType> probabilityFunction =
            this.updater.createClusterPosterior(this.data, this.random);
        // A private copy, so the initial clusters never alias the caller's collection.
        final ArrayList<ObservationType> dataArray = new ArrayList<ObservationType>(this.data);
        for (int k = 0; k < numClusters; ++k) {
            clusters.add(new DPMMCluster<ObservationType>(dataArray, probabilityFunction));
        }
        return new Sample<ObservationType>(this.getInitialAlpha(), clusters);
    }

    /**
     * Gets the updater.
     *
     * @return The updater that builds prior predictive and cluster distributions.
     */
    public Updater<ObservationType> getUpdater() {
        return this.updater;
    }

    /**
     * Sets the updater.
     *
     * @param updater The updater that builds prior predictive and cluster
     * distributions.
     */
    public void setUpdater(Updater<ObservationType> updater) {
        this.updater = updater;
    }

    /**
     * Gets the number of initial clusters.
     *
     * @return Number of clusters in the initial sample.
     */
    public int getNumInitialClusters() {
        return this.numInitialClusters;
    }

    /**
     * Sets the number of initial clusters.
     *
     * @param numInitialClusters Number of clusters in the initial sample.
     */
    public void setNumInitialClusters(int numInitialClusters) {
        this.numInitialClusters = numInitialClusters;
    }

    /**
     * Gets whether alpha is re-estimated.
     *
     * @return True if {@link #mcmcUpdate()} resamples alpha.
     */
    public boolean getReestimateAlpha() {
        return this.reestimateAlpha;
    }

    /**
     * Sets whether alpha is re-estimated.
     *
     * @param reestimateAlpha True to resample alpha on each update.
     */
    public void setReestimateAlpha(boolean reestimateAlpha) {
        this.reestimateAlpha = reestimateAlpha;
    }

    /**
     * Gets the initial alpha.
     *
     * @return Alpha used for the initial sample.
     */
    public double getInitialAlpha() {
        return this.initialAlpha;
    }

    /**
     * Sets the initial alpha. It is validated (must be &gt; 0) only when
     * the initial sample is created.
     *
     * @param initialAlpha Alpha used for the initial sample.
     */
    public void setInitialAlpha(double initialAlpha) {
        this.initialAlpha = initialAlpha;
    }

    /**
     * Updater backed by a {@link MultivariateGaussianMeanBayesianEstimator}.
     * <ul>
     * <li>The prior predictive is the estimator's predictive distribution
     * for the posterior learned from the data.</li>
     * <li>Each cluster distribution is the estimator's conditional Gaussian
     * at a parameter vector sampled from the posterior (presumably the mean
     * vector).</li>
     * </ul>
     */
    public static class MultivariateMeanUpdater
    extends AbstractCloneableSerializable
    implements Updater<Vector> {

        /** Conjugate estimator used to compute posteriors. */
        protected MultivariateGaussianMeanBayesianEstimator estimator;

        /** Creates an updater with a null estimator; set one before use. */
        public MultivariateMeanUpdater() {
            this(null);
        }

        /**
         * Creates an updater with a default estimator of the given dimensionality.
         *
         * @param dimensionality Dimensionality of the observations.
         */
        public MultivariateMeanUpdater(int dimensionality) {
            this(new MultivariateGaussianMeanBayesianEstimator(dimensionality));
        }

        /**
         * Creates an updater around the given estimator.
         *
         * @param estimator Conjugate estimator to use.
         */
        public MultivariateMeanUpdater(MultivariateGaussianMeanBayesianEstimator estimator) {
            this.estimator = estimator;
        }

        @Override
        public MultivariateMeanUpdater clone() {
            final MultivariateMeanUpdater clone = (MultivariateMeanUpdater) super.clone();
            clone.estimator = ObjectUtil.cloneSafe(this.estimator);
            return clone;
        }

        @Override
        public MultivariateGaussian.PDF createPriorPredictive(Iterable<? extends Vector> data) {
            final MultivariateGaussian posterior = (MultivariateGaussian) this.estimator.learn(data);
            return this.estimator.createPredictiveDistribution(posterior).getProbabilityFunction();
        }

        @Override
        public MultivariateGaussian.PDF createClusterPosterior(Iterable<? extends Vector> values, Random random) {
            final MultivariateGaussian posterior = (MultivariateGaussian) this.estimator.learn(values);
            final Vector parameters = (Vector) posterior.sample(random);
            return this.estimator.createConditionalDistribution(parameters).getProbabilityFunction();
        }
    }

    /**
     * Updater backed by a
     * {@link MultivariateGaussianMeanCovarianceBayesianEstimator}.
     * <ul>
     * <li>The prior predictive is the estimator's (Student-t) predictive
     * distribution for the Normal-inverse-Wishart posterior learned from the
     * data.</li>
     * <li>Each cluster distribution is the conditional Gaussian at a
     * parameter matrix sampled from that posterior.</li>
     * </ul>
     */
    public static class MultivariateMeanCovarianceUpdater
    extends AbstractCloneableSerializable
    implements Updater<Vector> {

        /** Conjugate estimator used to compute posteriors. */
        private MultivariateGaussianMeanCovarianceBayesianEstimator estimator;

        /** Creates an updater with a null estimator; set one before use. */
        public MultivariateMeanCovarianceUpdater() {
            this(null);
        }

        /**
         * Creates an updater with a default estimator of the given dimensionality.
         *
         * @param dimensionality Dimensionality of the observations.
         */
        public MultivariateMeanCovarianceUpdater(int dimensionality) {
            this(new MultivariateGaussianMeanCovarianceBayesianEstimator(dimensionality));
        }

        /**
         * Creates an updater around the given estimator.
         *
         * @param estimator Conjugate estimator to use.
         */
        public MultivariateMeanCovarianceUpdater(MultivariateGaussianMeanCovarianceBayesianEstimator estimator) {
            this.estimator = estimator;
        }

        @Override
        public MultivariateMeanCovarianceUpdater clone() {
            final MultivariateMeanCovarianceUpdater clone = (MultivariateMeanCovarianceUpdater) super.clone();
            clone.estimator = ObjectUtil.cloneSafe(this.estimator);
            return clone;
        }

        @Override
        public MultivariateStudentTDistribution.PDF createPriorPredictive(Iterable<? extends Vector> data) {
            final NormalInverseWishartDistribution posterior =
                (NormalInverseWishartDistribution) this.estimator.learn(data);
            return this.estimator.createPredictiveDistribution(posterior).getProbabilityFunction();
        }

        @Override
        public MultivariateGaussian.PDF createClusterPosterior(Iterable<? extends Vector> values, Random random) {
            final NormalInverseWishartDistribution posterior =
                (NormalInverseWishartDistribution) this.estimator.learn(values);
            final Matrix parameters = (Matrix) posterior.sample(random);
            return ((MultivariateGaussian) this.estimator.createConditionalDistribution(parameters)).getProbabilityFunction();
        }
    }

    /**
     * Builds the distributions a {@link DirichletProcessMixtureModel} needs
     * from observations.
     *
     * @param <ObservationType> Type of the observations.
     */
    public static interface Updater<ObservationType>
    extends CloneableSerializable {

        /**
         * Creates the prior predictive distribution. It is used to score
         * assigning an observation to a brand-new cluster.
         *
         * @param data All observations.
         * @return The prior predictive probability function.
         */
        public ProbabilityFunction<ObservationType> createPriorPredictive(Iterable<? extends ObservationType> data);

        /**
         * Creates the probability function for a cluster, given its members.
         *
         * @param values Members of the cluster.
         * @param random Source of randomness for any sampling involved.
         * @return The cluster's probability function.
         */
        public ProbabilityFunction<ObservationType> createClusterPosterior(Iterable<? extends ObservationType> values, Random random);
    }

    /**
     * One state of the Markov chain: the concentration parameter alpha and
     * the current clusters, plus a cached posterior log-likelihood.
     *
     * @param <ObservationType> Type of the clustered observations.
     */
    public static class Sample<ObservationType>
    extends AbstractCloneableSerializable {

        /** Concentration parameter; must be &gt; 0 when set via {@link #setAlpha}. */
        protected double alpha;

        /** Clusters of this sample. */
        protected ArrayList<DPMMCluster<ObservationType>> clusters;

        /** Cached posterior log-likelihood, or null if not computed. */
        private Double posteriorLogLikelihood;

        /**
         * Creates a sample.
         *
         * @param alpha Concentration parameter; must be &gt; 0.
         * @param clusters Clusters of the sample.
         * @throws IllegalArgumentException If alpha is not &gt; 0.
         */
        public Sample(double alpha, ArrayList<DPMMCluster<ObservationType>> clusters) {
            this.setAlpha(alpha);
            this.setClusters(clusters);
            this.setPosteriorLogLikelihood(null);
        }

        /**
         * Clones this sample and its clusters. The cached posterior
         * log-likelihood is cleared in the clone.
         */
        @Override
        public Sample<ObservationType> clone() {
            @SuppressWarnings("unchecked")
            final Sample<ObservationType> clone = (Sample<ObservationType>) super.clone();
            clone.setClusters(ObjectUtil.cloneSmartElementsAsArrayList(this.getClusters()));
            clone.setPosteriorLogLikelihood(null);
            return clone;
        }

        /**
         * Computes the posterior log-likelihood of the given data under this
         * sample. It is the sum, over the data, of
         * {@code log(1e-100 + sum_k n_k * p_k(x))}, plus the log of the
         * Chinese restaurant process prior of the cluster sizes.
         *
         * @param data Observations to score.
         * @return The posterior log-likelihood.
         */
        public double computePosteriorLogLikelihood(Iterable<? extends ObservationType> data) {
            final int numObservations = CollectionUtil.size(data);
            double logSum = 0.0;
            for (ObservationType value : data) {
                // The 1e-100 floor keeps the log finite.
                double p = 1.0E-100;
                for (DPMMCluster<ObservationType> cluster : this.clusters) {
                    final int weight = cluster.getMembers().size();
                    final double likelihood = (Double) cluster.getProbabilityFunction().evaluate(value);
                    p += weight * likelihood;
                }
                logSum += Math.log(p);
            }
            return logSum + this.computeLogClusterPrior(numObservations);
        }

        /**
         * Computes the posterior log-likelihood from a log-conditional term
         * that was already accumulated.
         *
         * @param numObservations Total number of observations.
         * @param logConditional Accumulated log-conditional of the data.
         * @return The CRP log-prior of the cluster sizes plus {@code logConditional}.
         */
        public double computePosteriorLogLikelihood(int numObservations, double logConditional) {
            return this.computeLogClusterPrior(numObservations) + logConditional;
        }

        /**
         * Computes the log of the Chinese restaurant process prior of this
         * sample's cluster sizes, given alpha and the total observation count.
         */
        private double computeLogClusterPrior(int numObservations) {
            final int K = this.getNumClusters();
            final ChineseRestaurantProcess.PMF pmf =
                new ChineseRestaurantProcess.PMF(this.getAlpha(), numObservations);
            final Vector counts = VectorFactory.getDefault().createVector(K);
            for (int k = 0; k < K; ++k) {
                counts.setElement(k, this.clusters.get(k).getMembers().size());
            }
            return pmf.logEvaluate(counts);
        }

        /** Removes every cluster that has no members. */
        public void removeUnusedClusters() {
            // Walk backwards so removals don't shift indices still to be visited.
            for (int j = this.clusters.size() - 1; j >= 0; --j) {
                if (this.clusters.get(j).getMembers().isEmpty()) {
                    this.clusters.remove(j);
                }
            }
        }

        /**
         * Gets alpha.
         *
         * @return The concentration parameter.
         */
        public double getAlpha() {
            return this.alpha;
        }

        /**
         * Sets the concentration parameter.
         *
         * @param alpha New alpha; must be &gt; 0 (NaN is rejected).
         * @throws IllegalArgumentException If alpha is not &gt; 0.
         */
        protected void setAlpha(double alpha) {
            // Written as !(alpha > 0) so that NaN is rejected as well.
            if (!(alpha > 0.0)) {
                throw new IllegalArgumentException("Alpha must be > 0.0 ");
            }
            this.alpha = alpha;
        }

        /**
         * Gets the number of clusters.
         *
         * @return Number of clusters in this sample.
         */
        public int getNumClusters() {
            return this.clusters.size();
        }

        /**
         * Gets the clusters.
         *
         * @return The (live, mutable) list of clusters.
         */
        public ArrayList<DPMMCluster<ObservationType>> getClusters() {
            return this.clusters;
        }

        /**
         * Sets the clusters.
         *
         * @param clusters New list of clusters.
         */
        protected void setClusters(ArrayList<DPMMCluster<ObservationType>> clusters) {
            this.clusters = clusters;
        }

        /**
         * Gets the cached posterior log-likelihood.
         *
         * @return The cached value, or null if not computed.
         */
        public Double getPosteriorLogLikelihood() {
            return this.posteriorLogLikelihood;
        }

        /**
         * Sets the cached posterior log-likelihood.
         *
         * @param posteriorLogLikelihood New value, or null to clear it.
         */
        public void setPosteriorLogLikelihood(Double posteriorLogLikelihood) {
            this.posteriorLogLikelihood = posteriorLogLikelihood;
        }
    }

    /**
     * A cluster of observations paired with the probability function used
     * to score observations against it.
     *
     * @param <ObservationType> Type of the clustered observations.
     */
    public static class DPMMCluster<ObservationType>
    extends DefaultCluster<ObservationType> {

        /** Distribution used to score observations against this cluster. */
        private ProbabilityFunction<? super ObservationType> probabilityFunction;

        /**
         * Creates a cluster.
         *
         * @param assignedData Members of the cluster.
         * @param probabilityFunction Distribution of the cluster.
         */
        public DPMMCluster(Collection<? extends ObservationType> assignedData, ProbabilityFunction<? super ObservationType> probabilityFunction) {
            super(assignedData);
            this.setProbabilityFunction(probabilityFunction);
        }

        /** Clones this cluster, deep-copying its probability function. */
        @Override
        public DPMMCluster<ObservationType> clone() {
            final DPMMCluster<ObservationType> clone = (DPMMCluster<ObservationType>) super.clone();
            clone.setProbabilityFunction(ObjectUtil.cloneSafe(this.getProbabilityFunction()));
            return clone;
        }

        /**
         * Gets the probability function.
         *
         * @return The cluster's probability function.
         */
        public ProbabilityFunction<? super ObservationType> getProbabilityFunction() {
            return this.probabilityFunction;
        }

        /**
         * Sets the probability function.
         *
         * @param probabilityFunction The cluster's probability function.
         */
        public void setProbabilityFunction(ProbabilityFunction<? super ObservationType> probabilityFunction) {
            this.probabilityFunction = probabilityFunction;
        }
    }

    /**
     * Mutable accumulator for the log-conditional of the data, summed over
     * observations during one assignment pass.
     */
    protected static class DPMMLogConditional
    extends AbstractCloneableSerializable {
        /** Running sum of per-observation log-conditional terms. */
        double logConditional = 0.0;
    }
}

