package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.algorithm.AnytimeAlgorithmWrapper;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.KMeansClusterer;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.GaussianCluster;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateMixtureDensityModel;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;

/**
 * Container for a mixture-of-Gaussians density ({@link PDF}) and two ways of
 * estimating one from a collection of vectors: {@link EMLearner}, which runs
 * soft expectation-maximization, and {@link Learner}, which adapts a k-means
 * clusterer that produces {@link GaussianCluster}s.
 */
@PublicationReference(author = {"Wikipedia"}, title = "Mixture Model", type = PublicationType.WebPage, year = 2009, url = "http://en.wikipedia.org/wiki/Mixture_model")
public class MixtureOfGaussians {

    /**
     * Fits a {@link PDF} with a fixed number of components by soft
     * expectation-maximization (EM).
     * <p>
     * Initialization places each component center at a randomly chosen data
     * point plus uniform noise in [-1, 1] per dimension, gives every point a
     * soft assignment proportional to {@code exp(-L1 distance)} to each
     * center, and fits one weighted Gaussian per component. Each
     * {@link #step()} then recomputes every point's posterior responsibility
     * {@code p(k | x) ~ prior_k * N(x; mean_k, cov_k)} and, unless the total
     * absolute change in responsibilities is within {@link #getTolerance()},
     * refits every component on the new responsibilities.
     */
    @PublicationReference(author = {"Jaakkola"}, title = "Estimating mixtures: the EM-algorithm", type = PublicationType.Misc, year = 2007, url = "http://courses.csail.mit.edu/6.867/lectures/notes-em2.pdf")
    public static class EMLearner extends AbstractAnytimeBatchLearner<Collection<? extends Vector>, PDF> implements Randomized, DistributionEstimator<Vector, PDF>, MeasurablePerformanceAlgorithm {
        /** Name reported by {@link #getPerformance()}. */
        public static final String PERFORMANCE_NAME = "Assignment Change";
        /** Default cap on EM iterations. */
        public static final int DEFAULT_MAX_ITERATIONS = 100;
        /** Default convergence threshold on the total responsibility change. */
        public static final double DEFAULT_TOLERANCE = 1.0E-5d;

        /** Fits one weighted Gaussian per component (the M-step). */
        private final MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner;
        /** Source of randomness for seeding the component centers. */
        protected Random random;
        /** Convergence threshold on the summed absolute responsibility change. */
        private double tolerance;

        // Working state, valid only between initializeAlgorithm() and cleanupAlgorithm().
        /** Every input point, re-weighted per component before each fit. */
        private transient ArrayList<DefaultWeightedValue<Vector>> weightedData;
        /** Per-point responsibility vectors, parallel to {@link #weightedData}. */
        private transient ArrayList<double[]> assignments;
        /** Current component Gaussians. */
        private transient ArrayList<MultivariateGaussian.PDF> distributions;
        /**
         * Unnormalized mixing weights: per component, the sum of all points'
         * responsibilities. Its length is the number of components.
         */
        private transient double[] distributionPrior;
        /** Summed absolute responsibility change from the latest pass. */
        private transient double assignmentChanged;

        /**
         * Creates a learner for a two-component mixture that uses the default
         * weighted maximum-likelihood Gaussian estimator.
         *
         * @param random Source of randomness used to seed the components.
         */
        public EMLearner(Random random) {
            this(2, random);
        }

        /**
         * Creates a learner that uses the default weighted maximum-likelihood
         * Gaussian estimator.
         *
         * @param distributionCount Number of mixture components; must be positive.
         * @param random Source of randomness used to seed the components.
         */
        public EMLearner(int distributionCount, Random random) {
            this(distributionCount, new MultivariateGaussian.WeightedMaximumLikelihoodEstimator(), random);
        }

        /**
         * Creates a learner.
         *
         * @param distributionCount Number of mixture components; must be positive.
         * @param learner Estimator used to fit each weighted component.
         * @param random Source of randomness used to seed the components.
         */
        public EMLearner(int distributionCount, MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner, Random random) {
            super(DEFAULT_MAX_ITERATIONS);
            ArgumentChecker.assertIsPositive("distributionCount", distributionCount);
            setRandom(random);
            setTolerance(DEFAULT_TOLERANCE);
            this.learner = learner;
            this.distributionPrior = new double[distributionCount];
            Arrays.fill(this.distributionPrior, 1.0d);
        }

        /**
         * Seeds the components from jittered random data points and fits an
         * initial Gaussian per component.
         *
         * @return {@code false} (no result) when the data is null or empty,
         *     otherwise {@code true}.
         */
        @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
        protected boolean initializeAlgorithm() {
            if (this.data == null || this.data.isEmpty()) {
                // Nothing to fit. Also drop components left over from an
                // earlier run so getResult2() cannot report stale results.
                this.distributions = null;
                return false;
            }

            final int size = this.data.size();
            final int length = this.distributionPrior.length;

            // Wrap each point once; weights are overwritten per component in fitComponent().
            this.weightedData = new ArrayList<>(size);
            for (Vector point : this.data) {
                this.weightedData.add(new DefaultWeightedValue<>(point, 0.0d));
            }
            final int dimensionality = this.weightedData.get(0).getValue().getDimensionality();

            // Seed each center at a random data point jittered by uniform
            // noise in [-1, 1]; draws happen in index-then-noise order per center.
            final Vector[] centers = new Vector[length];
            for (int k = 0; k < length; k++) {
                final int index = this.random.nextInt(size);
                centers[k] = VectorFactory.getDefault().createUniformRandom(dimensionality, -1.0d, 1.0d, getRandom());
                centers[k].plusEquals(this.weightedData.get(index).getValue());
            }

            // Initial soft assignments proportional to exp(-L1 distance to each center).
            this.assignments = new ArrayList<>(size);
            this.distributionPrior = new double[length];
            this.assignmentChanged = size;
            for (DefaultWeightedValue<Vector> weighted : this.weightedData) {
                final Vector point = weighted.getValue();
                final double[] assignment = new double[length];
                double total = 0.0d;
                for (int k = 0; k < length; k++) {
                    assignment[k] = Math.exp(-point.minus(centers[k]).norm1());
                    total += assignment[k];
                }
                if (total <= 0.0d) {
                    // All terms underflowed to zero; avoid dividing by zero.
                    total = 1.0d;
                }
                for (int k = 0; k < length; k++) {
                    assignment[k] /= total;
                    this.distributionPrior[k] += assignment[k];
                }
                this.assignments.add(assignment);
            }

            this.distributions = new ArrayList<>(length);
            for (int k = 0; k < length; k++) {
                this.distributions.add(fitComponent(k));
            }
            return true;
        }

        /**
         * Runs one EM iteration: recomputes every responsibility (E-step)
         * and, if they moved by more than the tolerance, refits every
         * component (M-step).
         *
         * @return {@code false} once the summed absolute responsibility change
         *     is at most {@link #getTolerance()}, otherwise {@code true}.
         */
        @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
        protected boolean step() {
            final int size = this.weightedData.size();
            final int length = this.distributionPrior.length;

            // The E-step needs the mixing weights from the previous iteration,
            // so snapshot them before re-accumulating. They are unnormalized
            // totals; the normalizing constant cancels in the posterior below.
            final double[] previousPrior = this.distributionPrior.clone();
            Arrays.fill(this.distributionPrior, 0.0d);
            this.assignmentChanged = 0.0d;

            final double[] previousAssignment = new double[length];
            for (int n = 0; n < size; n++) {
                final Vector point = this.weightedData.get(n).getValue();
                final double[] assignment = this.assignments.get(n);
                System.arraycopy(assignment, 0, previousAssignment, 0, length);

                // E-step: p(k | x) proportional to prior_k * N(x; mean_k, cov_k).
                double total = 0.0d;
                for (int k = 0; k < length; k++) {
                    final double weighted = previousPrior[k] * this.distributions.get(k).evaluate(point);
                    assignment[k] = weighted;
                    total += weighted;
                }
                if (total <= 0.0d) {
                    // No component gives this point positive weight; avoid dividing by zero.
                    total = 1.0d;
                }
                for (int k = 0; k < length; k++) {
                    assignment[k] /= total;
                    this.distributionPrior[k] += assignment[k];
                    this.assignmentChanged += Math.abs(assignment[k] - previousAssignment[k]);
                }
            }

            if (this.assignmentChanged <= getTolerance()) {
                return false;
            }

            // M-step: refit every component on the new responsibilities.
            for (int k = 0; k < length; k++) {
                this.distributions.set(k, fitComponent(k));
            }
            return true;
        }

        /**
         * Fits component {@code k} as a weighted Gaussian, weighting every
         * point by its current responsibility for that component.
         */
        private MultivariateGaussian.PDF fitComponent(int k) {
            for (int n = 0; n < this.weightedData.size(); n++) {
                this.weightedData.get(n).setWeight(this.assignments.get(n)[k]);
            }
            return this.learner.learn((Collection<? extends WeightedValue<? extends Vector>>) this.weightedData);
        }

        /**
         * Releases the per-run working data and the input reference. Keeps
         * the fitted components and priors so the result stays available.
         */
        @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
        protected void cleanupAlgorithm() {
            this.weightedData = null;
            this.assignments = null;
            this.data = null;
        }

        /**
         * Returns the current mixture. The result owns copies of the component
         * list and prior array, so later iterations do not mutate it.
         * (Name as emitted by the decompiler for {@code getResult}.)
         *
         * @return The fitted mixture, or {@code null} if nothing has been fit.
         */
        @Override // gov.sandia.cognition.algorithm.AnytimeAlgorithm
        public PDF getResult2() {
            if (this.distributions == null) {
                return null;
            }
            return new PDF(new ArrayList<MultivariateGaussian.PDF>(this.distributions), this.distributionPrior.clone());
        }

        /**
         * Reports the summed absolute responsibility change from the latest pass.
         *
         * @return A value named {@link #PERFORMANCE_NAME}.
         */
        @Override // gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm
        public NamedValue<Double> getPerformance() {
            return new DefaultNamedValue<>(PERFORMANCE_NAME, this.assignmentChanged);
        }

        /**
         * Gets the convergence threshold on the total responsibility change.
         *
         * @return The tolerance.
         */
        public double getTolerance() {
            return this.tolerance;
        }

        /**
         * Sets the convergence threshold on the total responsibility change.
         *
         * @param d The tolerance; must be non-negative.
         */
        public void setTolerance(double d) {
            ArgumentChecker.assertIsNonNegative("tolerance", d);
            this.tolerance = d;
        }

        @Override // gov.sandia.cognition.util.Randomized
        public Random getRandom() {
            return this.random;
        }

        @Override // gov.sandia.cognition.util.Randomized
        public void setRandom(Random random) {
            this.random = random;
        }
    }

    /**
     * Builds a {@link PDF} from a wrapped k-means clusterer over
     * {@link GaussianCluster}s. Each cluster's Gaussian becomes one component,
     * weighted by the cluster's member count.
     */
    public static class Learner extends AnytimeAlgorithmWrapper<PDF, KMeansClusterer<Vector, GaussianCluster>> implements DistributionEstimator<Vector, PDF>, MeasurablePerformanceAlgorithm {

        /**
         * Creates a learner around the given clusterer.
         *
         * @param kMeansClusterer Clusterer that does the actual fitting.
         */
        public Learner(KMeansClusterer<Vector, GaussianCluster> kMeansClusterer) {
            super(kMeansClusterer);
        }

        /**
         * Converts the wrapped clusterer's current clusters into a mixture.
         * (Name as emitted by the decompiler for {@code getResult}.)
         *
         * @return The mixture, or {@code null} when the clusterer has no clusters.
         */
        @Override // gov.sandia.cognition.algorithm.AnytimeAlgorithm
        public PDF getResult2() {
            final ArrayList<GaussianCluster> clusters = getAlgorithm().getResult2();
            if (clusters == null || clusters.isEmpty()) {
                return null;
            }
            final int clusterCount = clusters.size();
            final ArrayList<MultivariateGaussian> gaussians = new ArrayList<>(clusterCount);
            final double[] priors = new double[clusterCount];
            for (int k = 0; k < clusterCount; k++) {
                // Bind the cluster once: both its Gaussian and its size are needed.
                final GaussianCluster cluster = clusters.get(k);
                gaussians.add(cluster.getGaussian());
                priors[k] = cluster.getMembers().size();
            }
            return new PDF(gaussians, priors);
        }

        /**
         * Runs the wrapped clusterer on the data and converts its clusters.
         *
         * @param collection Vectors to fit.
         * @return The mixture, or {@code null} if no clusters were produced.
         */
        @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
        public PDF learn(Collection<? extends Vector> collection) {
            getAlgorithm().learn(collection);
            return getResult2();
        }

        /** Delegates to the wrapped clusterer's performance measure. */
        @Override // gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm
        public NamedValue<? extends Number> getPerformance() {
            return getAlgorithm().getPerformance();
        }
    }

    /** A mixture density whose components are {@link MultivariateGaussian}s. */
    public static class PDF extends MultivariateMixtureDensityModel.PDF<MultivariateGaussian> {

        /**
         * Creates a mixture from the given components, passing {@code null}
         * prior weights to the superclass.
         *
         * @param multivariateGaussianArr Mixture components.
         */
        public PDF(MultivariateGaussian... multivariateGaussianArr) {
            this(Arrays.asList(multivariateGaussianArr));
        }

        /**
         * Creates a mixture from the given components, passing {@code null}
         * prior weights to the superclass.
         *
         * @param collection Mixture components.
         */
        public PDF(Collection<? extends MultivariateGaussian> collection) {
            this(collection, null);
        }

        /**
         * Creates a mixture from components and their prior weights.
         *
         * @param collection Mixture components.
         * @param dArr Prior weight per component; handling of {@code null} is
         *     up to the superclass.
         */
        public PDF(Collection<? extends MultivariateGaussian> collection, double[] dArr) {
            super(collection, dArr);
        }

        /**
         * Copy constructor: clones the components (via
         * {@code ObjectUtil.cloneSmartElementsAsArrayList}) and deep-copies
         * the prior weights.
         *
         * @param pdf Mixture to copy.
         */
        public PDF(PDF pdf) {
            this(ObjectUtil.cloneSmartElementsAsArrayList(pdf.getDistributions()), (double[]) ObjectUtil.deepCopy(pdf.getPriorWeights()));
        }

        /** Clone (name as emitted by the decompiler for {@code clone}). */
        @Override // gov.sandia.cognition.statistics.distribution.MultivariateMixtureDensityModel.PDF, gov.sandia.cognition.statistics.distribution.MultivariateMixtureDensityModel, gov.sandia.cognition.statistics.distribution.LinearMixtureModel, gov.sandia.cognition.util.AbstractCloneableSerializable
        public PDF mo539clone() {
            return (PDF) super.mo539clone();
        }

        /** This object already is its own probability function. */
        @Override // gov.sandia.cognition.statistics.distribution.MultivariateMixtureDensityModel.PDF, gov.sandia.cognition.statistics.distribution.MultivariateMixtureDensityModel, gov.sandia.cognition.statistics.ComputableDistribution
        public PDF getProbabilityFunction() {
            return this;
        }

        /**
         * Returns the input dimensionality of the first component.
         * NOTE(review): assumes at least one component; behavior on an empty
         * mixture depends on {@code CollectionUtil.getFirst} — confirm.
         *
         * @return Input dimensionality.
         */
        public int getDimensionality() {
            return ((MultivariateGaussian) CollectionUtil.getFirst(getDistributions())).getInputDimensionality();
        }

        /**
         * Moment-matches the whole mixture with a single Gaussian. The result
         * has the mixture's mean {@code mu} and covariance
         * {@code sum_k w_k * (Sigma_k + (mu_k - mu)(mu_k - mu)^T)}, where
         * {@code w_k = prior_k / sum(prior)}.
         *
         * @return The matching single Gaussian.
         */
        public MultivariateGaussian.PDF fitSingleGaussian() {
            final Vector mean = getMean();
            final double priorWeightSum = getPriorWeightSum();
            final RingAccumulator<Matrix> covariance = new RingAccumulator<>();
            for (int k = 0; k < getDistributionCount(); k++) {
                final MultivariateGaussian component = (MultivariateGaussian) getDistributions().get(k);
                final Vector delta = component.getMean().minus(mean);
                final Matrix spread = component.getCovariance().plus(delta.outerProduct(delta));
                covariance.accumulate(spread.scale(this.priorWeights[k] / priorWeightSum));
            }
            return new MultivariateGaussian.PDF(mean, covariance.getSum());
        }

        /**
         * Sums each component's {@code computeZSquared(vector)} weighted by the
         * corresponding entry of {@code computeRandomVariableProbabilities(vector)}.
         * Those entries are presumably the components' posterior probabilities
         * given {@code vector} — confirm against the superclass.
         *
         * @param vector Point to evaluate.
         * @return The weighted sum.
         */
        public double computeWeightedZSquared(Vector vector) {
            final double[] weights = computeRandomVariableProbabilities(vector);
            final int count = getDistributions().size();
            double sum = 0.0d;
            for (int k = 0; k < count; k++) {
                final MultivariateGaussian component = (MultivariateGaussian) getDistributions().get(k);
                sum += component.computeZSquared(vector) * weights[k];
            }
            return sum;
        }
    }
}
