package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;

@PublicationReference(author = {"Wikipedia"}, title = "Multinomial distribution", type = PublicationType.WebPage, year = 2009, url = "http://en.wikipedia.org/wiki/Multinomial_distribution")
/* loaded from: input_file:gov/sandia/cognition/statistics/distribution/MultinomialDistribution.class */
public class MultinomialDistribution extends AbstractDistribution<Vector> implements ClosedFormComputableDiscreteDistribution<Vector> {
    public static final int DEFAULT_NUM_CLASSES = 2;
    public static final int DEFAULT_NUM_TRIALS = 1;
    private int numTrials;
    private Vector parameters;

    /* loaded from: input_file:gov/sandia/cognition/statistics/distribution/MultinomialDistribution$Domain.class */
    public static class Domain extends AbstractCollection<Vector> {
        private int numClasses;
        private int numTrials;

        /* loaded from: input_file:gov/sandia/cognition/statistics/distribution/MultinomialDistribution$Domain$MultinomialIterator.class */
        protected static class MultinomialIterator extends AbstractCloneableSerializable implements Iterator<Vector> {
            private int value;
            private int numClasses;
            private int numTrials;
            private MultinomialIterator child;

            /**
             * Creates an iterator over the count assignments of {@code i2} trials among {@code i} classes.
             *
             * <p>With a single class the only possible count is {@code i2}, so the iterator starts there
             * and has no child. With more classes this level starts at a count of zero and delegates the
             * remaining classes and trials to a freshly built child iterator.
             *
             * @param i number of classes; must be positive
             * @param i2 number of trials; must be non-negative
             * @throws IllegalArgumentException if {@code i <= 0} or {@code i2 < 0}
             */
            public MultinomialIterator(int i, int i2) {
                if (i <= 0) {
                    throw new IllegalArgumentException("NumClasses <= 0");
                }
                if (i2 < 0) {
                    throw new IllegalArgumentException("numTrials < 0");
                }
                this.numClasses = i;
                this.numTrials = i2;
                if (i > 1) {
                    // This level begins at zero, leaving every trial for the remaining classes.
                    this.value = 0;
                    this.child = new MultinomialIterator(i - 1, i2);
                } else {
                    // Leaf level: the last class must absorb all of the trials.
                    this.value = i2;
                }
            }

            /**
             * Reports whether another count vector remains.
             *
             * @return {@code true} while this level's count is still below the trial total; otherwise
             *         defers to the child iterator, or, at the leaf level, checks that the single
             *         remaining value has not yet been consumed
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.value < this.numTrials) {
                    return true;
                }
                if (this.child != null) {
                    return this.child.hasNext();
                }
                return this.value <= this.numTrials;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            /**
             * Returns the next count vector in the enumeration.
             *
             * <p>At the leaf level (no child) the vector holds this level's single value. Otherwise the
             * result is this level's value stacked on top of the child's next vector; when the child is
             * exhausted, this level's value is advanced and a new child is built for the remaining trials.
             *
             * @return the next count vector
             * @throws java.util.NoSuchElementException if the iteration has no more elements
             */
            @Override // java.util.Iterator
            public Vector next() {
                // Honor the Iterator contract. Without this guard an exhausted leaf silently returned a
                // count larger than numTrials, and an exhausted non-leaf corrupted its own state before
                // failing with an unrelated IllegalArgumentException from the child constructor.
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException("No more elements in MultinomialDomain");
                }
                if (this.child == null) {
                    Vector createVector = VectorFactory.getDefault().createVector(1, this.value);
                    this.value++;
                    return createVector;
                }
                if (!this.child.hasNext()) {
                    // hasNext() was true and the child is drained, so value < numTrials holds here and
                    // the remaining-trials argument below cannot become negative.
                    this.value++;
                    this.child = new MultinomialIterator(this.numClasses - 1, this.numTrials - this.value);
                }
                return VectorFactory.getDefault().createVector(1, this.value).stack(this.child.next());
            }

            /**
             * Removal is not supported by this iterator.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                final String message = "Cannot remove from MultinomialDomain";
                throw new UnsupportedOperationException(message);
            }
        }

        /**
         * Creates a domain description for the given number of classes and trials.
         * No validation is performed here; the arguments are only stored.
         *
         * @param i number of classes
         * @param i2 number of trials
         */
        public Domain(int i, int i2) {
            this.numTrials = i2;
            this.numClasses = i;
        }

        /**
         * Creates a fresh iterator over the vectors of this domain.
         *
         * @return a new {@code MultinomialIterator} built from this domain's class and trial counts
         */
        @Override // java.util.AbstractCollection, java.util.Collection, java.lang.Iterable
        public Iterator<Vector> iterator() {
            final MultinomialIterator walker = new MultinomialIterator(this.numClasses, this.numTrials);
            return walker;
        }

        /**
         * Computes the number of vectors in this domain as the binomial coefficient
         * of {@code (numClasses + numTrials - 1)} and {@code (numClasses - 1)}.
         *
         * @return the value returned by {@code MathUtil.binomialCoefficient} for those arguments
         */
        @Override // java.util.AbstractCollection, java.util.Collection
        public int size() {
            final int upper = (this.numClasses + this.numTrials) - 1;
            final int lower = this.numClasses - 1;
            return MathUtil.binomialCoefficient(upper, lower);
        }

        /**
         * Computes the log-domain counterpart of {@link #size()}, using the same binomial
         * coefficient arguments.
         *
         * @return the value returned by {@code MathUtil.logBinomialCoefficient} for
         *         {@code (numClasses + numTrials - 1)} and {@code (numClasses - 1)}
         */
        public double logSize() {
            final int upper = (this.numClasses + this.numTrials) - 1;
            final int lower = this.numClasses - 1;
            return MathUtil.logBinomialCoefficient(upper, lower);
        }
    }

    /* loaded from: input_file:gov/sandia/cognition/statistics/distribution/MultinomialDistribution$PMF.class */
    public static class PMF extends MultinomialDistribution implements ProbabilityMassFunction<Vector>, VectorInputEvaluator<Vector, Double> {
        /**
         * Creates a PMF through the implicit call to the no-argument
         * {@code MultinomialDistribution} constructor.
         */
        public PMF() {
        }

        /**
         * Creates a PMF by forwarding both arguments to the matching
         * {@code MultinomialDistribution(int, int)} constructor.
         *
         * @param i number of classes
         * @param i2 number of trials
         */
        public PMF(int i, int i2) {
            super(i, i2);
        }

        /**
         * Creates a PMF by forwarding both arguments to the matching
         * {@code MultinomialDistribution(Vector, int)} constructor.
         *
         * @param vector parameter vector of the distribution
         * @param i number of trials
         */
        public PMF(Vector vector, int i) {
            super(vector, i);
        }

        /**
         * Copy constructor; forwards to the {@code MultinomialDistribution} copy constructor.
         *
         * @param multinomialDistribution distribution to copy
         */
        public PMF(MultinomialDistribution multinomialDistribution) {
            super(multinomialDistribution);
        }

        /**
         * Reports the expected input dimensionality, which is the dimensionality of the
         * parameter vector.
         *
         * @return dimensionality of {@code getParameters()}
         */
        @Override // gov.sandia.cognition.math.matrix.VectorInputEvaluator
        public int getInputDimensionality() {
            final Vector weights = getParameters();
            return weights.getDimensionality();
        }

        /**
         * Evaluates the probability mass at the given input by exponentiating
         * {@link #logEvaluate(Vector)}.
         *
         * @param vector input to evaluate
         * @return {@code exp(logEvaluate(vector))}, boxed
         */
        @Override // gov.sandia.cognition.evaluator.Evaluator
        public Double evaluate(Vector vector) {
            final double logMass = logEvaluate(vector);
            final double mass = Math.exp(logMass);
            return Double.valueOf(mass);
        }

        /**
         * Computes the log of the probability mass at the given input.
         *
         * <p>Each input element is truncated to an {@code int} count. Each parameter element is divided
         * by the 1-norm of the parameter vector to obtain that class's probability. The result is
         * {@code logFactorial(numTrials)} minus the log-factorials of the non-zero counts, plus the sum
         * of {@code count * log(probability)} over the non-zero counts.
         *
         * <p>Negative infinity is returned as soon as a non-zero count is found for a class whose
         * probability is exactly zero; this happens before the count total is checked.
         *
         * @param vector input whose dimensionality must equal {@link #getInputDimensionality()}
         * @return the log probability mass
         * @throws IllegalArgumentException if a normalized parameter is negative, or if the truncated
         *         counts do not sum to the number of trials
         */
        @Override // gov.sandia.cognition.statistics.ProbabilityFunction
        public double logEvaluate(Vector vector) {
            final int numClasses = getInputDimensionality();
            vector.assertDimensionalityEquals(numClasses);
            final Vector weights = getParameters();
            final double weightSum = weights.norm1();
            double logCoefficient = MathUtil.logFactorial(getNumTrials());
            double logLikelihood = 0.0d;
            int countSum = 0;
            for (int k = 0; k < numClasses; k++) {
                final int count = (int) vector.getElement(k);
                countSum += count;
                final double probability = weights.getElement(k) / weightSum;
                if (probability < 0.0d) {
                    throw new IllegalArgumentException("pi < 0.0" + weights);
                }
                if (count == 0) {
                    // A zero count contributes nothing to either term.
                    continue;
                }
                if (probability == 0.0d) {
                    // An impossible class was observed: the mass is zero, so its log is -infinity.
                    return Double.NEGATIVE_INFINITY;
                }
                logCoefficient -= MathUtil.logFactorial(count);
                logLikelihood += count * Math.log(probability);
            }
            if (countSum != getNumTrials()) {
                throw new IllegalArgumentException("Integer input sum != num trials");
            }
            return logCoefficient + logLikelihood;
        }

        /**
         * Computes the entropy of this PMF by delegating to
         * {@code ProbabilityMassFunctionUtil.getEntropy}.
         *
         * @return the entropy reported by the utility for this PMF
         */
        @Override // gov.sandia.cognition.statistics.ProbabilityMassFunction
        public double getEntropy() {
            final double entropy = ProbabilityMassFunctionUtil.getEntropy(this);
            return entropy;
        }

        /**
         * Returns this object, since a PMF is its own probability function.
         *
         * @return {@code this}
         */
        @Override // gov.sandia.cognition.statistics.distribution.MultinomialDistribution, gov.sandia.cognition.statistics.ComputableDistribution
        public PMF getProbabilityFunction() {
            return this;
        }

        /**
         * Bridge overload of {@code getMean()} with an {@code Object} return type, marked as
         * bridge/synthetic by the decompiler. It only forwards to the superclass implementation.
         */
        @Override // gov.sandia.cognition.statistics.distribution.MultinomialDistribution, gov.sandia.cognition.statistics.DistributionWithMean
        public /* bridge */ /* synthetic */ Object getMean() {
            return super.getMean();
        }

        /**
         * Bridge overload of {@code clone()} returning {@code CloneableSerializable}, marked as
         * bridge/synthetic by the decompiler (renamed to {@code mo539clone} to avoid a name collision).
         * It only forwards to the superclass implementation.
         */
        @Override // gov.sandia.cognition.statistics.distribution.MultinomialDistribution, gov.sandia.cognition.util.AbstractCloneableSerializable, gov.sandia.cognition.util.CloneableSerializable
        /* renamed from: clone */
        public /* bridge */ /* synthetic */ CloneableSerializable mo539clone() {
            return super.mo539clone();
        }

        /**
         * Bridge overload of {@code clone()} returning {@code Vectorizable}, marked as
         * bridge/synthetic by the decompiler (renamed to {@code mo539clone} to avoid a name collision).
         * It only forwards to the superclass implementation.
         */
        @Override // gov.sandia.cognition.statistics.distribution.MultinomialDistribution, gov.sandia.cognition.util.AbstractCloneableSerializable, gov.sandia.cognition.util.CloneableSerializable
        /* renamed from: clone */
        public /* bridge */ /* synthetic */ Vectorizable mo539clone() {
            return super.mo539clone();
        }

        /**
         * Bridge overload of {@code getDomain()} with a raw {@code Collection} return type, marked as
         * bridge/synthetic by the decompiler. It only forwards to the superclass implementation.
         */
        @Override // gov.sandia.cognition.statistics.distribution.MultinomialDistribution, gov.sandia.cognition.statistics.DiscreteDistribution
        public /* bridge */ /* synthetic */ Collection getDomain() {
            return super.getDomain();
        }

        /**
         * Bridge overload of {@code clone()} returning {@code Object} and declaring
         * {@code CloneNotSupportedException}, marked as bridge/synthetic by the decompiler (renamed to
         * {@code mo539clone} to avoid a name collision). It only forwards to the superclass implementation.
         */
        @Override // gov.sandia.cognition.statistics.distribution.MultinomialDistribution, gov.sandia.cognition.util.AbstractCloneableSerializable
        /* renamed from: clone */
        public /* bridge */ /* synthetic */ Object mo539clone() throws CloneNotSupportedException {
            return super.mo539clone();
        }
    }

    /**
     * Creates a distribution with the default number of classes ({@link #DEFAULT_NUM_CLASSES})
     * and the default number of trials ({@link #DEFAULT_NUM_TRIALS}).
     */
    public MultinomialDistribution() {
        this(DEFAULT_NUM_CLASSES, DEFAULT_NUM_TRIALS);
    }

    /**
     * Creates a distribution whose parameter vector comes from
     * {@code VectorFactory.getDefault().createVector(i, 1.0d)}.
     * NOTE(review): presumably an {@code i}-dimensional vector with every element 1.0, i.e. equal
     * weight per class — confirm against the VectorFactory API.
     *
     * @param i number of classes
     * @param i2 number of trials
     */
    public MultinomialDistribution(int i, int i2) {
        this(VectorFactory.getDefault().createVector(i, 1.0d), i2);
    }

    /**
     * Creates a distribution from a parameter vector and a number of trials.
     * Both values go through their setters, so the setters' validation applies; the number of
     * trials is validated first. The vector is stored by reference, not copied.
     *
     * @param vector parameter vector (see {@link #setParameters(Vector)})
     * @param i number of trials (see {@link #setNumTrials(int)})
     */
    public MultinomialDistribution(Vector vector, int i) {
        setNumTrials(i);
        setParameters(vector);
    }

    /**
     * Copy constructor. The other distribution's parameter vector is passed through
     * {@code ObjectUtil.cloneSafe} before being stored, and its number of trials is reused.
     *
     * @param multinomialDistribution distribution to copy
     */
    public MultinomialDistribution(MultinomialDistribution multinomialDistribution) {
        this((Vector) ObjectUtil.cloneSafe(multinomialDistribution.getParameters()), multinomialDistribution.getNumTrials());
    }

    /**
     * Clones this distribution. The superclass clone is taken first, then its parameter vector is
     * replaced with the result of {@code ObjectUtil.cloneSafe} applied to this object's parameters.
     *
     * @return the cloned distribution
     */
    @Override // gov.sandia.cognition.util.AbstractCloneableSerializable, gov.sandia.cognition.util.CloneableSerializable
    /* renamed from: clone */
    public MultinomialDistribution mo539clone() {
        final MultinomialDistribution copy = (MultinomialDistribution) super.mo539clone();
        final Vector clonedParameters = (Vector) ObjectUtil.cloneSafe(getParameters());
        copy.setParameters(clonedParameters);
        return copy;
    }

    /**
     * Gets the parameter vector of the distribution.
     *
     * @return the stored parameter vector itself (not a copy)
     */
    public Vector getParameters() {
        return this.parameters;
    }

    /**
     * Sets the parameter vector of the distribution. The vector is stored by reference.
     *
     * @param vector new parameters; must have at least two elements, none of them negative
     * @throws IllegalArgumentException if the dimensionality is below 2 or any element is negative
     */
    public void setParameters(Vector vector) {
        final int size = vector.getDimensionality();
        if (size < 2) {
            throw new IllegalArgumentException("Dimensionality must be >= 2");
        }
        int index = 0;
        while (index < size) {
            final double weight = vector.getElement(index);
            if (weight < 0.0d) {
                throw new IllegalArgumentException("All parameter elements must be >= 0.0");
            }
            index++;
        }
        this.parameters = vector;
    }

    /**
     * Converts this distribution to a vector, namely the result of {@code ObjectUtil.cloneSafe}
     * applied to the parameter vector.
     *
     * @return the cloned parameter vector
     */
    @Override // gov.sandia.cognition.math.matrix.Vectorizable
    public Vector convertToVector() {
        final Vector current = getParameters();
        return (Vector) ObjectUtil.cloneSafe(current);
    }

    /**
     * Replaces the parameters with a clone of the given vector, after asserting that it has the same
     * dimensionality as the current parameters. The validation in {@link #setParameters(Vector)}
     * also applies.
     *
     * @param vector new parameter values
     */
    @Override // gov.sandia.cognition.math.matrix.Vectorizable
    public void convertFromVector(Vector vector) {
        final Vector current = getParameters();
        vector.assertSameDimensionality(current);
        final Vector replacement = (Vector) ObjectUtil.cloneSafe(vector);
        setParameters(replacement);
    }

    /**
     * Gets the number of trials.
     *
     * @return the stored number of trials
     */
    public int getNumTrials() {
        return this.numTrials;
    }

    /**
     * Sets the number of trials.
     *
     * @param i number of trials; must be at least 1
     * @throws IllegalArgumentException if {@code i} is zero or negative
     */
    public void setNumTrials(int i) {
        final boolean valid = i >= 1;
        if (!valid) {
            throw new IllegalArgumentException("numTrials must be > 0");
        }
        this.numTrials = i;
    }

    /**
     * Computes the mean of the distribution: the parameter vector scaled by the number of trials
     * divided by the parameters' 1-norm.
     *
     * @return the result of {@code parameters.scale(numTrials / parameters.norm1())}
     */
    @Override // gov.sandia.cognition.statistics.DistributionWithMean
    public Vector getMean() {
        final Vector weights = this.parameters;
        final double factor = this.numTrials / this.parameters.norm1();
        return weights.scale(factor);
    }

    /**
     * Draws samples from the distribution.
     *
     * <p>Each sample is a vector of per-class counts built from {@code numTrials} independent draws.
     * Every draw takes one uniform value from {@code random} and walks the class probabilities
     * (parameters divided by their 1-norm), subtracting each probability until the remaining value
     * falls strictly inside a class's share.
     *
     * <p>The comparison is strict so that a class with probability zero can never be chosen, even
     * when the uniform draw is exactly 0.0. If floating-point round-off lets the walk run past the
     * final class, the draw is credited to the last class with positive probability rather than
     * being dropped, so every sample's counts sum to {@code numTrials}.
     *
     * <p>If no class has a positive probability (for example, all parameters are zero), no counts
     * are recorded and the returned vectors are all zeros.
     *
     * @param random source of uniform random numbers
     * @param i number of samples to draw
     * @return a list of {@code i} count vectors
     */
    @Override // gov.sandia.cognition.statistics.Distribution
    public ArrayList<Vector> sample(Random random, int i) {
        final int numClasses = this.parameters.getDimensionality();
        final double weightSum = this.parameters.norm1();
        final double[] probabilities = new double[numClasses];
        // Index of the last class that can actually occur; used as the round-off fallback.
        int lastSupported = -1;
        for (int k = 0; k < numClasses; k++) {
            probabilities[k] = this.parameters.getElement(k) / weightSum;
            if (probabilities[k] > 0.0d) {
                lastSupported = k;
            }
        }
        ArrayList<Vector> samples = new ArrayList<>(i);
        for (int s = 0; s < i; s++) {
            final double[] counts = new double[numClasses];
            for (int t = 0; t < this.numTrials; t++) {
                double remaining = random.nextDouble();
                int chosen = lastSupported;
                for (int k = 0; k < numClasses; k++) {
                    // Strict comparison: remaining is never negative, so a zero-probability class
                    // cannot satisfy this test.
                    if (remaining < probabilities[k]) {
                        chosen = k;
                        break;
                    }
                    remaining -= probabilities[k];
                }
                if (chosen >= 0) {
                    counts[chosen] += 1.0d;
                }
            }
            samples.add(VectorFactory.getDefault().copyArray(counts));
        }
        return samples;
    }

    /**
     * Builds the domain of the distribution from the parameter dimensionality (number of classes)
     * and the number of trials.
     *
     * @return a new {@code Domain} instance
     */
    @Override // gov.sandia.cognition.statistics.DiscreteDistribution
    public Domain getDomain() {
        final int numClasses = getParameters().getDimensionality();
        final int trials = getNumTrials();
        return new Domain(numClasses, trials);
    }

    /**
     * Reports the number of elements in the domain.
     *
     * @return {@code getDomain().size()}
     */
    @Override // gov.sandia.cognition.statistics.DiscreteDistribution
    public int getDomainSize() {
        final Domain domain = getDomain();
        return domain.size();
    }

    /**
     * Creates a probability mass function for this distribution using the PMF copy constructor.
     *
     * @return a new {@code PMF} built from this distribution
     */
    @Override // gov.sandia.cognition.statistics.ComputableDistribution
    public PMF getProbabilityFunction() {
        final PMF massFunction = new PMF(this);
        return massFunction;
    }
}
