/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.algorithm.AnytimeAlgorithmWrapper;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.KMeansClusterer;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.GaussianCluster;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateMixtureDensityModel;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;

@PublicationReference(author={"Wikipedia"}, title="Mixture Model", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Mixture_model")
public class MixtureOfGaussians {

    /**
     * Fits a mixture of Gaussians to unlabeled vectors with the
     * Expectation-Maximization (EM) algorithm.
     * <p>
     * Initialization picks, for each component, a randomly chosen data point
     * jittered by uniform noise in [-1, 1], soft-assigns every point to those
     * centers, and fits one Gaussian per component to the resulting weights.
     * Each {@link #step()} then recomputes every point's responsibilities from
     * the current components and mixing weights (E-step) and refits every
     * component with the weighted maximum-likelihood estimator (M-step).
     * Stepping stops once the summed absolute change in responsibilities is at
     * most {@link #getTolerance()}.
     */
    @PublicationReference(author={"Jaakkola"}, title="Estimating mixtures: the EM-algorithm", type=PublicationType.Misc, year=2007, url="http://courses.csail.mit.edu/6.867/lectures/notes-em2.pdf")
    public static class EMLearner
    extends AbstractAnytimeBatchLearner<Collection<? extends Vector>, PDF>
    implements Randomized,
    DistributionEstimator<Vector, PDF>,
    MeasurablePerformanceAlgorithm {
        /** Name reported by {@link #getPerformance()}. */
        public static final String PERFORMANCE_NAME = "Assignment Change";
        /** Iteration limit passed to the base learner by the constructors. */
        public static final int DEFAULT_MAX_ITERATIONS = 100;
        /** Default convergence threshold on the summed responsibility change. */
        public static final double DEFAULT_TOLERANCE = 1.0E-5;

        /** Estimator used in the M-step to refit each component. */
        private MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner;
        /** Source of randomness for initialization. */
        protected Random random;
        /** Convergence threshold; must be non-negative. */
        private double tolerance;
        /** The training points, re-weighted per component before each fit. */
        private transient ArrayList<DefaultWeightedValue<Vector>> weightedData;
        /** Per-point responsibilities; entry n has one weight per component. */
        private transient ArrayList<double[]> assignments;
        /** Current component Gaussians; null until initialization succeeds. */
        private transient ArrayList<MultivariateGaussian.PDF> distributions;
        /**
         * Unnormalized mixing weights: after initialization and each E-step,
         * entry k is the sum of all points' responsibilities for component k.
         * Its length also fixes the number of components.
         */
        private transient double[] distributionPrior;
        /** Summed absolute responsibility change of the latest E-step. */
        private transient double assignmentChanged;

        /**
         * Creates a learner for a two-component mixture.
         *
         * @param random source of randomness for initialization
         */
        public EMLearner(Random random) {
            this(2, random);
        }

        /**
         * Creates a learner using a default weighted maximum-likelihood
         * estimator for the components.
         *
         * @param distributionCount number of mixture components; must be positive
         * @param random source of randomness for initialization
         */
        public EMLearner(int distributionCount, Random random) {
            this(distributionCount, new MultivariateGaussian.WeightedMaximumLikelihoodEstimator(), random);
        }

        /**
         * Creates a learner.
         *
         * @param distributionCount number of mixture components; must be positive
         * @param learner weighted estimator used to refit each component
         * @param random source of randomness for initialization
         * @throws IllegalArgumentException if {@code distributionCount} is not positive
         */
        public EMLearner(int distributionCount, MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner, Random random) {
            super(DEFAULT_MAX_ITERATIONS);
            ArgumentChecker.assertIsPositive("distributionCount", distributionCount);
            this.setRandom(random);
            this.setTolerance(DEFAULT_TOLERANCE);
            this.learner = learner;
            this.distributionPrior = new double[distributionCount];
            Arrays.fill(this.distributionPrior, 1.0);
        }

        /**
         * Builds the initial soft assignments and component Gaussians.
         *
         * @return false (nothing is learned) when the data is null or empty,
         *     true otherwise
         */
        @Override
        protected boolean initializeAlgorithm() {
            // Drop any result from a previous run so that a run which cannot
            // start does not report stale components.
            this.distributions = null;
            if (this.data == null || this.data.isEmpty()) {
                return false;
            }
            final int N = this.data.size();
            final int K = this.distributionPrior.length;
            final int dim = CollectionUtil.getFirst(this.data).getDimensionality();

            // Seed each component near a random data point. The jitter keeps
            // two components that pick the same point from starting identical.
            final Vector[] centers = new Vector[K];
            for (int k = 0; k < K; ++k) {
                final int index = this.random.nextInt(N);
                centers[k] = VectorFactory.getDefault().createUniformRandom(dim, -1.0, 1.0, this.getRandom());
                centers[k].plusEquals(CollectionUtil.getElement(this.data, index));
            }

            this.weightedData = new ArrayList<DefaultWeightedValue<Vector>>(N);
            this.assignments = new ArrayList<double[]>(N);
            this.distributionPrior = new double[K];
            this.assignmentChanged = N;
            for (Vector value : this.data) {
                this.weightedData.add(new DefaultWeightedValue<Vector>(value, 0.0));
                final double[] assignment = new double[K];
                double minDistance = Double.POSITIVE_INFINITY;
                for (int k = 0; k < K; ++k) {
                    assignment[k] = value.minus(centers[k]).norm1();
                    minDistance = Math.min(minDistance, assignment[k]);
                }
                // exp(-d_k) / sum_j exp(-d_j) is unchanged by shifting every d by
                // the minimum. The shift keeps the nearest center's term at 1, so
                // points far from all centers no longer underflow to all zeros.
                double sum = 0.0;
                for (int k = 0; k < K; ++k) {
                    assignment[k] = Math.exp(minDistance - assignment[k]);
                    sum += assignment[k];
                }
                if (sum <= 0.0) {
                    sum = 1.0;
                }
                for (int k = 0; k < K; ++k) {
                    assignment[k] /= sum;
                    this.distributionPrior[k] += assignment[k];
                }
                this.assignments.add(assignment);
            }

            this.distributions = new ArrayList<MultivariateGaussian.PDF>(K);
            for (int k = 0; k < K; ++k) {
                this.setWeightsForComponent(k);
                this.distributions.add(this.learner.learn(this.weightedData));
            }
            return true;
        }

        /**
         * Performs one EM iteration.
         *
         * @return false once the summed absolute responsibility change is at
         *     most the tolerance (components are not refit in that case), true
         *     otherwise
         */
        @Override
        protected boolean step() {
            final int N = this.assignments.size();
            final int K = this.distributionPrior.length;
            // The E-step needs the mixing weights from the previous iteration.
            // Copy them before the accumulator is reset.
            final double[] mixingWeights = this.distributionPrior.clone();
            Arrays.fill(this.distributionPrior, 0.0);
            this.assignmentChanged = 0.0;

            final double[] previous = new double[K];
            for (int n = 0; n < N; ++n) {
                final Vector xn = this.weightedData.get(n).getValue();
                final double[] an = this.assignments.get(n);
                System.arraycopy(an, 0, previous, 0, K);

                // Responsibility r_nk is proportional to pi_k * p_k(x_n).
                // The weights need not be normalized: their scale cancels below.
                double sum = 0.0;
                for (int k = 0; k < K; ++k) {
                    an[k] = mixingWeights[k] * this.distributions.get(k).evaluate(xn).doubleValue();
                    sum += an[k];
                }
                if (sum <= 0.0) {
                    sum = 1.0;
                }
                for (int k = 0; k < K; ++k) {
                    an[k] /= sum;
                    this.distributionPrior[k] += an[k];
                    this.assignmentChanged += Math.abs(an[k] - previous[k]);
                }
            }

            if (this.assignmentChanged <= this.getTolerance()) {
                return false;
            }

            // M-step: refit each component to the points weighted by their
            // responsibility for it.
            for (int k = 0; k < K; ++k) {
                this.setWeightsForComponent(k);
                this.distributions.set(k, this.learner.learn(this.weightedData));
            }
            return true;
        }

        /**
         * Copies every point's responsibility for component {@code k} into
         * the weights of {@link #weightedData}.
         *
         * @param k component index
         */
        private void setWeightsForComponent(int k) {
            for (int n = 0; n < this.weightedData.size(); ++n) {
                this.weightedData.get(n).setWeight(this.assignments.get(n)[k]);
            }
        }

        /**
         * Releases the per-run working data and the reference to the training
         * data. The components and mixing weights are kept for
         * {@link #getResult()}.
         */
        @Override
        protected void cleanupAlgorithm() {
            this.weightedData = null;
            this.assignments = null;
            this.data = null;
        }

        /**
         * Returns the current mixture as a snapshot. The component list and
         * the weights array are copied, so later iterations do not change a
         * result that has already been returned.
         *
         * @return the mixture, or null if no components have been fit
         */
        @Override
        public PDF getResult() {
            if (this.distributions == null) {
                return null;
            }
            return new PDF(new ArrayList<MultivariateGaussian.PDF>(this.distributions), this.distributionPrior.clone());
        }

        /**
         * Returns the summed absolute responsibility change of the latest
         * E-step. Before the first step this is the number of data points.
         */
        @Override
        public NamedValue<Double> getPerformance() {
            return new DefaultNamedValue<Double>(PERFORMANCE_NAME, this.assignmentChanged);
        }

        /** @return the convergence threshold */
        public double getTolerance() {
            return this.tolerance;
        }

        /**
         * Sets the convergence threshold.
         *
         * @param tolerance non-negative threshold on the summed responsibility change
         * @throws IllegalArgumentException if {@code tolerance} is negative
         */
        public void setTolerance(double tolerance) {
            ArgumentChecker.assertIsNonNegative("tolerance", tolerance);
            this.tolerance = tolerance;
        }

        @Override
        public Random getRandom() {
            return this.random;
        }

        @Override
        public void setRandom(Random random) {
            this.random = random;
        }
    }

    /**
     * Estimates a mixture of Gaussians by running a k-means clusterer over
     * {@link GaussianCluster}s. Each cluster becomes one mixture component,
     * weighted by its number of members.
     */
    public static class Learner
    extends AnytimeAlgorithmWrapper<PDF, KMeansClusterer<Vector, GaussianCluster>>
    implements DistributionEstimator<Vector, PDF>,
    MeasurablePerformanceAlgorithm {
        /**
         * Creates a learner that wraps the given clusterer.
         *
         * @param algorithm k-means clusterer that produces Gaussian clusters
         */
        public Learner(KMeansClusterer<Vector, GaussianCluster> algorithm) {
            super(algorithm);
        }

        /**
         * Converts the clusterer's current clusters into a mixture.
         *
         * @return a mixture with one component per cluster, with prior weights
         *     set to the (unnormalized) cluster member counts, or null when the
         *     clusterer has no clusters
         */
        @Override
        public PDF getResult() {
            final Collection<? extends GaussianCluster> clusters = this.getAlgorithm().getResult();
            if (clusters == null || clusters.isEmpty()) {
                return null;
            }
            final int K = clusters.size();
            final ArrayList<MultivariateGaussian.PDF> gaussians = new ArrayList<MultivariateGaussian.PDF>(K);
            final double[] priorProbabilities = new double[K];
            int index = 0;
            for (GaussianCluster cluster : clusters) {
                gaussians.add(cluster.getGaussian());
                priorProbabilities[index++] = cluster.getMembers().size();
            }
            return new PDF(gaussians, priorProbabilities);
        }

        /**
         * Runs the wrapped clusterer on the data and converts the result.
         *
         * @param data vectors to cluster
         * @return the mixture from {@link #getResult()}, possibly null
         */
        @Override
        public PDF learn(Collection<? extends Vector> data) {
            this.getAlgorithm().learn(data);
            return this.getResult();
        }

        /** Delegates to the wrapped clusterer's performance measure. */
        @Override
        public NamedValue<? extends Number> getPerformance() {
            return this.getAlgorithm().getPerformance();
        }
    }

    /**
     * Density of a finite mixture of multivariate Gaussians, where each
     * component is weighted by a prior weight.
     */
    public static class PDF
    extends MultivariateMixtureDensityModel.PDF<MultivariateGaussian> {
        /**
         * Creates a mixture of the given components with null prior weights.
         * How the superclass interprets null weights (presumably uniform) is
         * not visible here.
         *
         * @param distributions mixture components
         */
        public PDF(MultivariateGaussian ... distributions) {
            this(Arrays.asList(distributions));
        }

        /**
         * Creates a mixture of the given components with null prior weights.
         *
         * @param distributions mixture components
         */
        public PDF(Collection<? extends MultivariateGaussian> distributions) {
            this(distributions, (double[])null);
        }

        /**
         * Creates a mixture of the given components and prior weights.
         *
         * @param distributions mixture components
         * @param priorWeights one weight per component
         */
        public PDF(Collection<? extends MultivariateGaussian> distributions, double[] priorWeights) {
            super(distributions, priorWeights);
        }

        /**
         * Copy constructor. The components are cloned and the weights are
         * deep-copied.
         *
         * @param other mixture to copy
         */
        public PDF(PDF other) {
            this(ObjectUtil.cloneSmartElementsAsArrayList(other.getDistributions()), ObjectUtil.deepCopy(other.getPriorWeights()));
        }

        @Override
        public PDF clone() {
            return (PDF)super.clone();
        }

        @Override
        public PDF getProbabilityFunction() {
            return this;
        }

        /** @return the input dimensionality of the first component */
        public int getDimensionality() {
            return CollectionUtil.getFirst(this.getDistributions()).getInputDimensionality();
        }

        /**
         * Moment-matches the mixture with a single Gaussian. The mean is the
         * mixture mean mu. The covariance is
         * sum_i (w_i / W) * (Sigma_i + (mu_i - mu)(mu_i - mu)^T),
         * where W is the sum of all prior weights w_i.
         *
         * @return a Gaussian with the mixture's mean and covariance
         */
        public MultivariateGaussian.PDF fitSingleGaussian() {
            final Vector mixtureMean = this.getMean();
            final RingAccumulator<Matrix> covarianceAccumulator = new RingAccumulator<Matrix>();
            final double weightSum = this.getPriorWeightSum();
            final int count = this.getDistributionCount();
            for (int i = 0; i < count; ++i) {
                final MultivariateGaussian component = this.getDistributions().get(i);
                final Vector meanDiff = component.getMean().minus(mixtureMean);
                // Within-component spread plus the spread between component means.
                final Matrix spread = component.getCovariance().plus(meanDiff.outerProduct(meanDiff));
                covarianceAccumulator.accumulate(spread.scale(this.priorWeights[i] / weightSum));
            }
            return new MultivariateGaussian.PDF(mixtureMean, covarianceAccumulator.getSum());
        }

        /**
         * Weights each component's {@code computeZSquared(input)} (presumably
         * the squared Mahalanobis distance) by that component's probability
         * from {@code computeRandomVariableProbabilities(input)}, and returns
         * the sum.
         *
         * @param input point to evaluate
         * @return the probability-weighted sum of the components' z-squared values
         */
        public double computeWeightedZSquared(Vector input) {
            final double[] p = this.computeRandomVariableProbabilities(input);
            double weightedZSquared = 0.0;
            int index = 0;
            for (MultivariateGaussian g : this.getDistributions()) {
                weightedZSquared += g.computeZSquared(input) * p[index++];
            }
            return weightedZSquared;
        }
    }
}

