/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * A categorical distribution over "one-hot" vectors: each outcome is a vector
 * of dimensionality {@code N} with exactly one element equal to 1.0 and all
 * other elements equal to 0.0. The distribution is parameterized by a vector
 * of {@code N} non-negative (not necessarily normalized) weights; the
 * probability of class {@code i} is its weight divided by the sum of all
 * weights.
 */
@PublicationReference(author={"Wikipedia"}, title="Categorical Distribution", type=PublicationType.WebPage, year=2011, url="http://en.wikipedia.org/wiki/Categorical_distribution")
public class CategoricalDistribution
extends AbstractDistribution<Vector>
implements ClosedFormComputableDiscreteDistribution<Vector> {
    /** Number of classes used by the no-argument constructor. */
    public static final int DEFAULT_NUM_CLASSES = 2;

    /**
     * Non-negative, possibly unnormalized, class weights. Its dimensionality
     * is the number of classes, which {@link #setParameters} requires to be
     * at least 2.
     */
    protected Vector parameters;

    /**
     * Creates a distribution over {@link #DEFAULT_NUM_CLASSES} classes, each
     * with weight 1.0.
     */
    public CategoricalDistribution() {
        this(DEFAULT_NUM_CLASSES);
    }

    /**
     * Creates a distribution over {@code numClasses} classes, each with
     * weight 1.0.
     *
     * @param numClasses number of classes; values below 2 are rejected by
     *        {@link #setParameters}
     */
    public CategoricalDistribution(int numClasses) {
        this(VectorFactory.getDefault().createVector(numClasses, 1.0));
    }

    /**
     * Creates a distribution with the given class weights. The vector is
     * stored by reference, not copied.
     *
     * @param parameters class weights; see {@link #setParameters}
     */
    public CategoricalDistribution(Vector parameters) {
        this.setParameters(parameters);
    }

    /**
     * Copy constructor. The other distribution's parameters are cloned, so
     * the two instances do not share state.
     *
     * @param other distribution to copy
     */
    public CategoricalDistribution(CategoricalDistribution other) {
        this(ObjectUtil.cloneSafe(other.getParameters()));
    }

    /**
     * Returns a copy of this distribution with its own clone of the
     * parameter vector.
     */
    @Override
    public CategoricalDistribution clone() {
        CategoricalDistribution clone = (CategoricalDistribution)super.clone();
        clone.setParameters(ObjectUtil.cloneSafe(this.getParameters()));
        return clone;
    }

    /**
     * Returns the class-weight vector. This is the internal vector, not a
     * copy.
     *
     * @return the class weights
     */
    public Vector getParameters() {
        return this.parameters;
    }

    /**
     * Sets the class weights. The vector is stored by reference.
     *
     * @param parameters class weights; must have dimensionality of at least 2
     *        and every element must be {@code >= 0.0} (NaN is rejected)
     * @throws IllegalArgumentException if the dimensionality is below 2 or
     *         any element is negative or NaN
     */
    public void setParameters(Vector parameters) {
        int N = parameters.getDimensionality();
        if (N < 2) {
            throw new IllegalArgumentException("Dimensionality must be >= 2");
        }
        for (int i = 0; i < N; ++i) {
            // Written as !(x >= 0) rather than (x < 0) so NaN weights are
            // rejected too: every comparison involving NaN is false.
            if (!(parameters.getElement(i) >= 0.0)) {
                throw new IllegalArgumentException("All parameter elements must be >= 0.0");
            }
        }
        this.parameters = parameters;
    }

    /**
     * Draws {@code numSamples} one-hot vectors, choosing class {@code n} with
     * probability proportional to its weight.
     *
     * @param random source of randomness
     * @param numSamples number of samples to draw
     * @return the sampled one-hot vectors
     */
    @Override
    public ArrayList<Vector> sample(Random random, int numSamples) {
        ArrayList<Vector> domain = this.getDomain();
        int N = domain.size();
        // Running (unnormalized) cumulative weights; the total is passed
        // alongside so the sampler can normalize.
        double[] cumulativeWeights = new double[N];
        double sum = 0.0;
        for (int n = 0; n < N; ++n) {
            sum += this.parameters.getElement(n);
            cumulativeWeights[n] = sum;
        }
        return ProbabilityMassFunctionUtil.sampleMultiple(cumulativeWeights, sum, domain, random, numSamples);
    }

    /**
     * Returns the mean of the distribution, which for one-hot outcomes is
     * the vector of class probabilities: each weight divided by the sum of
     * all weights.
     *
     * @return a new vector holding the normalized weights
     */
    @Override
    public Vector getMean() {
        // The weights are non-negative (enforced by setParameters), so the L1
        // norm equals their sum; dividing by it normalizes them.
        return (Vector)this.parameters.scale(1.0 / this.parameters.norm1());
    }

    /**
     * Returns a copy of the class-weight vector.
     *
     * @return a clone of the parameters
     */
    @Override
    public Vector convertToVector() {
        return this.parameters.clone();
    }

    /**
     * Replaces the class weights with the given vector. The vector is stored
     * by reference.
     *
     * @param parameters new class weights; must have the same dimensionality
     *        as the current weights and satisfy {@link #setParameters}
     */
    @Override
    public void convertFromVector(Vector parameters) {
        this.parameters.assertSameDimensionality(parameters);
        this.setParameters(parameters);
    }

    /**
     * Returns the expected dimensionality of an input, which equals the
     * number of classes.
     *
     * @return the dimensionality of the parameter vector
     */
    public int getInputDimensionality() {
        return this.getParameters().getDimensionality();
    }

    /**
     * Returns all possible outcomes: one one-hot vector per class, in class
     * order. A fresh list of fresh vectors is built on every call.
     *
     * @return the one-hot domain vectors
     */
    @Override
    public ArrayList<Vector> getDomain() {
        int N = this.getInputDimensionality();
        ArrayList<Vector> domain = new ArrayList<Vector>(N);
        for (int n = 0; n < N; ++n) {
            Vector x = VectorFactory.getDefault().createVector(N);
            x.setElement(n, 1.0);
            domain.add(x);
        }
        return domain;
    }

    /**
     * Returns the number of outcomes, which equals the number of classes.
     *
     * @return the domain size
     */
    @Override
    public int getDomainSize() {
        return this.getInputDimensionality();
    }

    /**
     * Returns a probability mass function whose parameters are a clone of
     * this distribution's parameters.
     *
     * @return a new {@link PMF}
     */
    public PMF getProbabilityFunction() {
        return new PMF(this);
    }

    /**
     * Probability mass function of a {@link CategoricalDistribution}.
     */
    public static class PMF
    extends CategoricalDistribution
    implements ProbabilityMassFunction<Vector>,
    VectorInputEvaluator<Vector, Double> {
        /** Creates a PMF over the default number of classes with equal weights. */
        public PMF() {
        }

        /**
         * Creates a PMF over {@code numClasses} classes, each with weight 1.0.
         *
         * @param numClasses number of classes
         */
        public PMF(int numClasses) {
            super(numClasses);
        }

        /**
         * Creates a PMF with the given class weights, stored by reference.
         *
         * @param parameters class weights
         */
        public PMF(Vector parameters) {
            super(parameters);
        }

        /**
         * Creates a PMF from a clone of another distribution's parameters.
         *
         * @param other distribution to copy
         */
        public PMF(CategoricalDistribution other) {
            super(other);
        }

        /** Returns the entropy, as computed by {@link ProbabilityMassFunctionUtil#getEntropy}. */
        @Override
        public double getEntropy() {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }

        /**
         * Returns the natural logarithm of {@link #evaluate(Vector)}.
         *
         * @param input a one-hot vector
         * @return the log-probability of {@code input}
         */
        @Override
        public double logEvaluate(Vector input) {
            return Math.log(this.evaluate(input));
        }

        /**
         * Returns the probability of a one-hot input: the weight of the class
         * whose entry is 1.0, divided by the sum of all weights.
         *
         * @param input a vector with the same dimensionality as the parameters,
         *        exactly one entry equal to 1.0, and all others equal to 0.0
         * @return the probability of {@code input}
         * @throws IllegalArgumentException if an entry is neither 0.0 nor 1.0,
         *         or if the number of entries equal to 1.0 is not exactly one
         */
        @Override
        public Double evaluate(Vector input) {
            this.parameters.assertSameDimensionality(input);
            // -1 marks "no 1.0 entry seen yet"; real weights are >= 0.
            double pi = -1.0;
            int N = this.getInputDimensionality();
            double sum = 0.0;
            for (int n = 0; n < N; ++n) {
                double p = this.parameters.getElement(n);
                sum += p;
                double x = input.getElement(n);
                if (x == 1.0) {
                    if (pi < 0.0) {
                        pi = p;
                        continue;
                    }
                    throw new IllegalArgumentException("input must only have one entry equal to 1.0!");
                }
                if (x == 0.0) continue;
                throw new IllegalArgumentException("input entries must be either 0.0 or 1.0");
            }
            if (pi < 0.0) {
                throw new IllegalArgumentException("input must have one entry equal to 1.0!");
            }
            return pi / sum;
        }

        /** Returns this PMF itself. */
        @Override
        public PMF getProbabilityFunction() {
            return this;
        }
    }
}

