/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.statistics.distribution.DirichletDistribution;
import gov.sandia.cognition.statistics.distribution.MultinomialDistribution;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Random;

/**
 * Multivariate Polya distribution (also known as the Dirichlet-multinomial
 * distribution): a discrete distribution over count vectors that sum to a
 * fixed number of trials. As implemented by {@link #sample(Random, int)}, a
 * draw is produced by first drawing multinomial class probabilities from a
 * Dirichlet distribution with the stored parameters and then drawing counts
 * from a multinomial distribution with those probabilities.
 */
@PublicationReference(author={"Wikipedia"}, title="Multivariate Polya Distribution", type=PublicationType.WebPage, year=2010, url="http://en.wikipedia.org/wiki/Multivariate_Polya_distribution")
public class MultivariatePolyaDistribution
extends AbstractDistribution<Vector>
implements ClosedFormComputableDiscreteDistribution<Vector> {
    /** Default number of trials, {@value}. */
    public static final int DEFAULT_NUM_TRIALS = 1;

    /** Default dimensionality (number of classes), {@value}. */
    public static final int DEFAULT_DIMENSIONALITY = 2;

    /**
     * Dirichlet concentration parameters, one per class. {@link #setParameters(Vector)}
     * enforces at least two elements, each strictly positive.
     */
    protected Vector parameters;

    /** Number of trials; {@link #setNumTrials(int)} enforces a value greater than zero. */
    private int numTrials;

    /**
     * Creates a distribution with {@link #DEFAULT_DIMENSIONALITY} classes and
     * {@link #DEFAULT_NUM_TRIALS} trials.
     */
    public MultivariatePolyaDistribution() {
        this(DEFAULT_DIMENSIONALITY, DEFAULT_NUM_TRIALS);
    }

    /**
     * Creates a distribution whose parameters are all initialized to 1.0.
     *
     * @param dimensionality number of classes; must be at least 2
     * @param numTrials number of trials; must be greater than zero
     * @throws IllegalArgumentException if either argument is out of range
     */
    public MultivariatePolyaDistribution(int dimensionality, int numTrials) {
        this(VectorFactory.getDefault().createVector(dimensionality, 1.0), numTrials);
    }

    /**
     * Creates a distribution from the given parameters and trial count.
     *
     * @param parameters concentration parameters; stored by reference, not copied
     * @param numTrials number of trials; must be greater than zero
     * @throws IllegalArgumentException if the parameters or the trial count are invalid
     */
    public MultivariatePolyaDistribution(Vector parameters, int numTrials) {
        this.setParameters(parameters);
        this.setNumTrials(numTrials);
    }

    /**
     * Copy constructor. The parameter vector of {@code other} is cloned so the
     * two distributions do not share it.
     *
     * @param other distribution to copy
     */
    public MultivariatePolyaDistribution(MultivariatePolyaDistribution other) {
        this(ObjectUtil.cloneSafe(other.getParameters()), other.getNumTrials());
    }

    /**
     * Returns a copy of this distribution with its own clone of the parameter
     * vector.
     *
     * @return the cloned distribution
     */
    @Override
    public MultivariatePolyaDistribution clone() {
        MultivariatePolyaDistribution clone = (MultivariatePolyaDistribution)super.clone();
        clone.setParameters(ObjectUtil.cloneSafe(this.getParameters()));
        return clone;
    }

    /**
     * Computes the mean count vector: the parameters scaled by
     * {@code numTrials / parameters.norm1()}.
     *
     * @return the mean of the distribution
     */
    @Override
    public Vector getMean() {
        // The parameters are strictly positive, so their 1-norm is simply their sum.
        final double scaleFactor = (double)this.numTrials / this.parameters.norm1();
        return (Vector)this.parameters.scale(scaleFactor);
    }

    /**
     * Draws samples by compounding: for each sample, class probabilities are
     * taken from a Dirichlet draw and then used as the parameters of a
     * multinomial draw with this distribution's number of trials.
     *
     * @param random source of randomness
     * @param numSamples number of samples to draw
     * @return list containing {@code numSamples} sampled count vectors
     */
    @Override
    public ArrayList<Vector> sample(Random random, int numSamples) {
        DirichletDistribution prior = new DirichletDistribution(this.parameters);
        ArrayList<Vector> dirichletSamples = prior.sample(random, numSamples);

        final int dim = this.getInputDimensionality();
        final int trials = this.getNumTrials();

        // One multinomial instance is reused; only its parameters change per sample.
        MultinomialDistribution conditional = new MultinomialDistribution(dim, trials);
        // NOTE(review): presumably redundant with the constructor argument above;
        // kept because the constructor is not visible here - confirm before removing.
        conditional.setNumTrials(trials);

        ArrayList<Vector> samples = new ArrayList<Vector>(numSamples);
        for (int i = 0; i < numSamples; ++i) {
            conditional.setParameters(dirichletSamples.get(i));
            samples.add((Vector)conditional.sample(random));
        }
        return samples;
    }

    /**
     * Gets the number of trials.
     *
     * @return number of trials, always greater than zero
     */
    public int getNumTrials() {
        return this.numTrials;
    }

    /**
     * Sets the number of trials.
     *
     * @param numTrials number of trials; must be greater than zero
     * @throws IllegalArgumentException if {@code numTrials <= 0}
     */
    public void setNumTrials(int numTrials) {
        if (numTrials <= 0) {
            throw new IllegalArgumentException("numTrials must be > 0");
        }
        this.numTrials = numTrials;
    }

    /**
     * Creates a probability mass function that is a copy of this distribution.
     *
     * @return a new {@link PMF} built with the copy constructor
     */
    public PMF getProbabilityFunction() {
        return new PMF(this);
    }

    /**
     * Returns a clone of the parameter vector.
     *
     * @return copy of the parameters
     */
    @Override
    public Vector convertToVector() {
        return ObjectUtil.cloneSafe(this.getParameters());
    }

    /**
     * Replaces the parameters with a clone of the given vector.
     *
     * @param parameters new parameters; must have the same dimensionality as
     *        the current parameters and satisfy {@link #setParameters(Vector)}
     */
    @Override
    public void convertFromVector(Vector parameters) {
        parameters.assertSameDimensionality(this.getParameters());
        this.setParameters(ObjectUtil.cloneSafe(parameters));
    }

    /**
     * Gets the dimensionality of the vectors this distribution is defined over.
     *
     * @return dimensionality of the parameters, or 0 if the parameters are null
     */
    public int getInputDimensionality() {
        return this.parameters != null ? this.parameters.getDimensionality() : 0;
    }

    /**
     * Gets the parameter vector.
     *
     * @return the internal parameter vector (not a copy)
     */
    public Vector getParameters() {
        return this.parameters;
    }

    /**
     * Sets the parameter vector. The vector is stored by reference.
     *
     * @param parameters concentration parameters; must have at least two
     *        elements, each strictly greater than 0.0 (NaN is rejected)
     * @throws IllegalArgumentException if the dimensionality is below 2 or any
     *         element is not strictly positive
     */
    public void setParameters(Vector parameters) {
        final int dimensionality = parameters.getDimensionality();
        if (dimensionality < 2) {
            throw new IllegalArgumentException("Dimensionality must be >= 2");
        }
        for (int i = 0; i < dimensionality; ++i) {
            // Written as a negated ">" so that NaN, for which every comparison
            // is false, is rejected along with zero and negative values.
            if (!(parameters.getElement(i) > 0.0)) {
                throw new IllegalArgumentException("All parameter elements must be > 0.0");
            }
        }
        this.parameters = parameters;
    }

    /**
     * Gets the domain of the distribution, built as a multinomial domain with
     * this distribution's dimensionality and number of trials.
     *
     * @return the domain
     */
    public MultinomialDistribution.Domain getDomain() {
        return new MultinomialDistribution.Domain(this.getInputDimensionality(), this.getNumTrials());
    }

    /**
     * Gets the size of the domain as reported by {@link #getDomain()}.
     *
     * @return number of elements in the domain
     */
    @Override
    public int getDomainSize() {
        return this.getDomain().size();
    }

    @Override
    public String toString() {
        return "N = " + this.getNumTrials() + ", Parameters = " + this.getParameters();
    }

    /**
     * Probability mass function of the multivariate Polya distribution.
     */
    public static class PMF
    extends MultivariatePolyaDistribution
    implements ProbabilityMassFunction<Vector>,
    VectorInputEvaluator<Vector, Double> {
        /**
         * Creates a PMF with the default dimensionality and number of trials.
         */
        public PMF() {
        }

        /**
         * Creates a PMF whose parameters are all initialized to 1.0.
         *
         * @param dimensionality number of classes; must be at least 2
         * @param numTrials number of trials; must be greater than zero
         */
        public PMF(int dimensionality, int numTrials) {
            super(dimensionality, numTrials);
        }

        /**
         * Creates a PMF from the given parameters and trial count.
         *
         * @param parameters concentration parameters; stored by reference
         * @param numTrials number of trials; must be greater than zero
         */
        public PMF(Vector parameters, int numTrials) {
            super(parameters, numTrials);
        }

        /**
         * Creates a PMF that copies the given distribution.
         *
         * @param other distribution to copy
         */
        public PMF(MultivariatePolyaDistribution other) {
            super(other);
        }

        /**
         * Returns this object, since it already is the probability function.
         *
         * @return {@code this}
         */
        @Override
        public PMF getProbabilityFunction() {
            return this;
        }

        /**
         * Computes the natural logarithm of the probability mass at the input.
         * With {@code n} the number of trials, {@code A} the 1-norm of the
         * parameters and {@code B} the beta function, the value is
         * <pre>
         * log(n) + log B(A, n) - sum over i with x_i &gt; 0 of [ log(x_i) + log B(a_i, x_i) ]
         * </pre>
         * Inputs outside the support yield negative infinity: those with any
         * negative element, and those whose rounded 1-norm differs from the
         * number of trials.
         *
         * @param input count vector to evaluate; must match the dimensionality
         *        of the parameters
         * @return log-probability of the input, or
         *         {@link Double#NEGATIVE_INFINITY} if it is outside the support
         */
        @Override
        public double logEvaluate(Vector input) {
            final int dim = this.getInputDimensionality();
            input.assertDimensionalityEquals(dim);

            // norm1() sums absolute values, so it cannot detect negative counts;
            // they must be rejected explicitly or vectors such as (-1, n + 1)
            // would pass the trial-count check below.
            for (int i = 0; i < dim; ++i) {
                if (input.getElement(i) < 0.0) {
                    return Double.NEGATIVE_INFINITY;
                }
            }

            final int inputTotal = (int)Math.round(input.norm1());
            if (inputTotal != this.getNumTrials()) {
                return Double.NEGATIVE_INFINITY;
            }

            final double parameterSum = this.parameters.norm1();
            double logSum = Math.log(inputTotal);
            logSum += MathUtil.logBetaFunction(parameterSum, inputTotal);
            for (int i = 0; i < dim; ++i) {
                final double alpha = this.parameters.getElement(i);
                final double count = input.getElement(i);
                // Classes with a zero count contribute nothing to the sum.
                if (alpha > 0.0 && count > 0.0) {
                    logSum -= Math.log(count);
                    logSum -= MathUtil.logBetaFunction(alpha, count);
                }
            }
            return logSum;
        }

        /**
         * Computes the probability mass at the input as the exponential of
         * {@link #logEvaluate(Vector)}.
         *
         * @param input count vector to evaluate
         * @return probability of the input
         */
        @Override
        public Double evaluate(Vector input) {
            return Math.exp(this.logEvaluate(input));
        }

        /**
         * Computes the entropy by delegating to
         * {@code ProbabilityMassFunctionUtil.getEntropy}.
         *
         * @return entropy of this PMF
         */
        @Override
        public double getEntropy() {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }
    }
}

