/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.bayes;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.learning.function.categorization.DiscriminantCategorizer;
import gov.sandia.cognition.math.LogMath;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.UnivariateProbabilityDensityFunction;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A naive Bayes categorizer for vector inputs. Each category has a prior,
 * read from a {@link DataHistogram} of category counts, and one univariate
 * conditional distribution per input dimension. The score of a category is
 * the log of its prior fraction plus the sum of the conditional log densities
 * of each input element; the category with the largest score is chosen.
 *
 * @param <CategoryType>
 *      The type of the output categories.
 * @param <DistributionType>
 *      The type of the univariate conditional distribution used for each
 *      input dimension.
 */
public class VectorNaiveBayesCategorizer<CategoryType, DistributionType extends UnivariateProbabilityDensityFunction>
extends AbstractCloneableSerializable
implements Categorizer<Vectorizable, CategoryType>,
VectorInputEvaluator<Vectorizable, CategoryType>,
DiscriminantCategorizer<Vectorizable, CategoryType, Double> {

    /** Histogram of category counts; {@code getFraction} supplies each category's prior. */
    protected DataHistogram<CategoryType> priors;

    /**
     * For each category, its list of conditional distributions. Element
     * {@code i} of a list is evaluated on element {@code i} of the input vector.
     */
    protected Map<CategoryType, List<DistributionType>> conditionals;

    /**
     * Creates an empty categorizer backed by a {@link MapBasedDataHistogram}
     * for the priors and an insertion-ordered {@link LinkedHashMap} for the
     * conditionals.
     */
    public VectorNaiveBayesCategorizer() {
        this(new MapBasedDataHistogram<CategoryType>(),
            new LinkedHashMap<CategoryType, List<DistributionType>>());
    }

    /**
     * Creates a categorizer from the given priors and conditionals. Both are
     * stored by reference, not copied.
     *
     * @param priors
     *      Histogram of category counts used for the prior probabilities.
     * @param conditionals
     *      Map from category to its per-dimension conditional distributions.
     */
    public VectorNaiveBayesCategorizer(DataHistogram<CategoryType> priors, Map<CategoryType, List<DistributionType>> conditionals) {
        this.setPriors(priors);
        this.setConditionals(conditionals);
    }

    /**
     * Clones this categorizer. The priors are cloned, and each category's
     * list of conditionals is copied into a new list of cloned elements, so
     * the clone does not share mutable state with this object.
     *
     * @return A deep-ish copy of this categorizer.
     */
    @Override
    public VectorNaiveBayesCategorizer<CategoryType, DistributionType> clone() {
        // super.clone() returns this exact runtime class, so the cast is safe.
        @SuppressWarnings("unchecked")
        final VectorNaiveBayesCategorizer<CategoryType, DistributionType> clone =
            (VectorNaiveBayesCategorizer<CategoryType, DistributionType>) super.clone();
        clone.priors = ObjectUtil.cloneSafe(this.priors);
        clone.conditionals = new LinkedHashMap<CategoryType, List<DistributionType>>(
            this.conditionals.size());
        for (Map.Entry<CategoryType, List<DistributionType>> entry : this.conditionals.entrySet()) {
            clone.conditionals.put(entry.getKey(),
                ObjectUtil.cloneSmartElementsAsArrayList(entry.getValue()));
        }
        return clone;
    }

    /**
     * Returns the category with the largest log posterior for the input.
     * On ties, the category encountered first in {@link #getCategories()}
     * iteration order wins.
     *
     * @param input
     *      The input to categorize.
     * @return The best category, or {@code null} if there are no categories.
     */
    @Override
    public CategoryType evaluate(Vectorizable input) {
        final Vector vector = input.convertToVector();
        double maxLogPosterior = Double.NEGATIVE_INFINITY;
        CategoryType maxCategory = null;
        for (CategoryType category : this.getCategories()) {
            final double logPosterior = this.computeLogPosterior(vector, category);
            // Always take the first category; afterwards only a strictly better one.
            if (maxCategory == null || logPosterior > maxLogPosterior) {
                maxLogPosterior = logPosterior;
                maxCategory = category;
            }
        }
        return maxCategory;
    }

    /**
     * Returns the best category together with the log of its normalized
     * posterior. That value is the best log posterior minus the log of the
     * sum of all categories' posteriors; the sum is accumulated in log space
     * with {@link LogMath#add}.
     *
     * @param input
     *      The input to categorize.
     * @return The best category and its log normalized posterior. With no
     *      categories the category is {@code null} and the discriminant is NaN
     *      (negative infinity minus negative infinity).
     */
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(Vectorizable input) {
        final Vector vector = input.convertToVector();
        double maxLogPosterior = Double.NEGATIVE_INFINITY;
        double logDenominator = Double.NEGATIVE_INFINITY;
        CategoryType maxCategory = null;
        for (CategoryType category : this.getCategories()) {
            final double logPosterior = this.computeLogPosterior(vector, category);
            if (maxCategory == null || logPosterior > maxLogPosterior) {
                maxLogPosterior = logPosterior;
                maxCategory = category;
            }
            logDenominator = LogMath.add(logDenominator, logPosterior);
        }
        final double logMaximumLikelihood = maxLogPosterior - logDenominator;
        return DefaultWeightedValueDiscriminant.create(maxCategory, logMaximumLikelihood);
    }

    /**
     * Computes {@code exp(computeLogPosterior(input, category))}. This is the
     * unnormalized posterior (prior times product of conditional densities).
     *
     * @param input
     *      The input vector.
     * @param category
     *      The category to score.
     * @return The unnormalized posterior of the category.
     */
    public double computePosterior(Vector input, CategoryType category) {
        return Math.exp(this.computeLogPosterior(input, category));
    }

    /**
     * Computes the unnormalized log posterior of a category. It is the log of
     * the category's prior fraction plus, for each conditional {@code i} of
     * the category, the log density of input element {@code i}.
     *
     * @param input
     *      The input vector. It must have at least as many elements as the
     *      category has conditionals.
     * @param category
     *      The category to score.
     * @return The unnormalized log posterior, or negative infinity (posterior
     *      of zero) if the category has no conditionals in this model.
     */
    public double computeLogPosterior(Vector input, CategoryType category) {
        final List<DistributionType> probabilityFunctions = this.conditionals.get(category);
        if (probabilityFunctions == null) {
            // Category unknown to the model: previously this was a NullPointerException.
            return Double.NEGATIVE_INFINITY;
        }
        double logPosterior = Math.log(this.priors.getFraction(category));
        final int size = probabilityFunctions.size();
        for (int i = 0; i < size; ++i) {
            logPosterior += probabilityFunctions.get(i).logEvaluate(input.getElement(i));
        }
        return logPosterior;
    }

    /**
     * Returns the categories known to the model.
     *
     * @return The key set of the conditionals map. It is a live view backed
     *      by that map.
     */
    @Override
    public Set<CategoryType> getCategories() {
        return this.conditionals.keySet();
    }

    /**
     * Returns the input dimensionality, taken from the number of conditionals
     * of the first category in the conditionals map.
     *
     * @return The number of conditionals of the first category, or 0 if there
     *      are no categories.
     */
    @Override
    public int getInputDimensionality() {
        final List<DistributionType> first = CollectionUtil.getFirst(this.conditionals.values());
        return first == null ? 0 : first.size();
    }

    /**
     * Gets the histogram of category counts used for the priors.
     *
     * @return The priors histogram.
     */
    public DataHistogram<CategoryType> getPriors() {
        return this.priors;
    }

    /**
     * Sets the histogram of category counts used for the priors.
     *
     * @param priors
     *      The priors histogram, stored by reference.
     */
    public void setPriors(DataHistogram<CategoryType> priors) {
        this.priors = priors;
    }

    /**
     * Gets the map from category to its per-dimension conditional distributions.
     *
     * @return The conditionals map (not a copy).
     */
    public Map<CategoryType, List<DistributionType>> getConditionals() {
        return this.conditionals;
    }

    /**
     * Sets the map from category to its per-dimension conditional distributions.
     *
     * @param conditionals
     *      The conditionals map, stored by reference.
     */
    public void setConditionals(Map<CategoryType, List<DistributionType>> conditionals) {
        this.conditionals = conditionals;
    }

    /**
     * Incremental learner for {@link VectorNaiveBayesCategorizer}. Each example
     * increments its category's count in the priors and feeds each input
     * element to the matching per-dimension conditional via the wrapped
     * incremental distribution learner.
     *
     * @param <CategoryType>
     *      The type of the output categories.
     * @param <DistributionType>
     *      The type of the per-dimension conditional distribution.
     */
    public static class OnlineLearner<CategoryType, DistributionType extends UnivariateProbabilityDensityFunction>
    extends AbstractBatchAndIncrementalLearner<InputOutputPair<? extends Vectorizable, CategoryType>, VectorNaiveBayesCategorizer<CategoryType, DistributionType>> {

        /** Learner used to create and update each per-dimension conditional. */
        protected IncrementalLearner<? super Double, DistributionType> distributionLearner;

        /** Creates a learner with no distribution learner; one must be set before use. */
        public OnlineLearner() {
            this(null);
        }

        /**
         * Creates a learner with the given distribution learner.
         *
         * @param distributionLearner
         *      Learner used to create and update each per-dimension conditional.
         */
        public OnlineLearner(IncrementalLearner<? super Double, DistributionType> distributionLearner) {
            this.setDistributionLearner(distributionLearner);
        }

        /**
         * Creates an empty categorizer.
         *
         * @return A new, empty categorizer.
         */
        @Override
        public VectorNaiveBayesCategorizer<CategoryType, DistributionType> createInitialLearnedObject() {
            return new VectorNaiveBayesCategorizer<CategoryType, DistributionType>();
        }

        /**
         * Updates the categorizer with one example. The first time a category
         * is seen, one initial conditional per input dimension is created. The
         * input is validated before the target is mutated, so a rejected
         * example leaves the target unchanged.
         *
         * @param target
         *      The categorizer to update.
         * @param data
         *      The labeled example.
         * @throws IllegalArgumentException
         *      If the category already has conditionals and their number differs
         *      from the input's dimensionality.
         */
        @Override
        public void update(VectorNaiveBayesCategorizer<CategoryType, DistributionType> target, InputOutputPair<? extends Vectorizable, CategoryType> data) {
            final Vector input = data.getInput().convertToVector();
            final CategoryType category = data.getOutput();
            final int dimensionality = input.getDimensionality();

            List<DistributionType> conditionals = target.getConditionals().get(category);
            if (conditionals == null) {
                // Build the full list before publishing it into the target.
                conditionals = new ArrayList<DistributionType>(dimensionality);
                for (int i = 0; i < dimensionality; ++i) {
                    conditionals.add(this.distributionLearner.createInitialLearnedObject());
                }
                target.getConditionals().put(category, conditionals);
            } else if (conditionals.size() != dimensionality) {
                throw new IllegalArgumentException("Input dimensionality " + dimensionality
                    + " does not match the " + conditionals.size()
                    + " conditionals already learned for category " + category);
            }

            target.getPriors().add(category);
            for (int i = 0; i < dimensionality; ++i) {
                this.distributionLearner.update(conditionals.get(i), input.getElement(i));
            }
        }

        /**
         * Gets the learner used for the per-dimension conditionals.
         *
         * @return The distribution learner.
         */
        public IncrementalLearner<? super Double, DistributionType> getDistributionLearner() {
            return this.distributionLearner;
        }

        /**
         * Sets the learner used for the per-dimension conditionals.
         *
         * @param distributionLearner
         *      The distribution learner.
         */
        public void setDistributionLearner(IncrementalLearner<? super Double, DistributionType> distributionLearner) {
            this.distributionLearner = distributionLearner;
        }
    }

    /**
     * Batch learner that fits one univariate Gaussian per category and input
     * dimension. It uses the sample mean and the (count - 1)-denominator sample
     * variance; the denominator is 1 when a category has a single example.
     * Priors are the per-category example counts.
     *
     * @param <CategoryType>
     *      The type of the output categories.
     */
    public static class BatchGaussianLearner<CategoryType>
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner<Vectorizable, CategoryType, VectorNaiveBayesCategorizer<CategoryType, UnivariateGaussian.PDF>> {

        /**
         * Learns a Gaussian naive Bayes categorizer from the given data.
         *
         * @param data
         *      The labeled training examples.
         * @return The learned categorizer.
         */
        @Override
        public VectorNaiveBayesCategorizer<CategoryType, UnivariateGaussian.PDF> learn(Collection<? extends InputOutputPair<? extends Vectorizable, CategoryType>> data) {
            final int dimensionality = DatasetUtil.getInputDimensionality(data);
            final Map<CategoryType, List<Vectorizable>> examplesPerCategory =
                DatasetUtil.splitOnOutput(data);
            final VectorNaiveBayesCategorizer<CategoryType, UnivariateGaussian.PDF> result =
                new VectorNaiveBayesCategorizer<CategoryType, UnivariateGaussian.PDF>();

            for (Map.Entry<CategoryType, List<Vectorizable>> entry : examplesPerCategory.entrySet()) {
                final CategoryType category = entry.getKey();
                final List<Vectorizable> examples = entry.getValue();

                // Accumulate element-wise sums and sums of squares in one pass.
                final RingAccumulator<Vector> sumsAccumulator = new RingAccumulator<Vector>();
                final RingAccumulator<Vector> sumsOfSquaresAccumulator = new RingAccumulator<Vector>();
                for (Vectorizable input : examples) {
                    final Vector vector = input.convertToVector();
                    sumsAccumulator.accumulate(vector);
                    sumsOfSquaresAccumulator.accumulate(vector.dotTimes(vector));
                }
                final Vector sums = sumsAccumulator.getSum();
                final Vector sumsOfSquares = sumsOfSquaresAccumulator.getSum();

                final int count = examples.size();
                final long varianceDenominator = count > 1 ? (long) (count - 1) : 1L;
                final List<UnivariateGaussian.PDF> conditionals =
                    new ArrayList<UnivariateGaussian.PDF>(dimensionality);
                for (int i = 0; i < dimensionality; ++i) {
                    final double sum = sums.getElement(i);
                    final double sumOfSquares = sumsOfSquares.getElement(i);
                    final double mean = sum / (double) count;
                    // NOTE(review): this one-pass formula can lose precision (even go
                    // slightly negative) for near-constant features, and yields 0 for a
                    // single example or a constant feature. Confirm how
                    // UnivariateGaussian.PDF handles non-positive variances.
                    final double variance = (sumOfSquares - sum * mean) / (double) varianceDenominator;
                    conditionals.add(new UnivariateGaussian.PDF(mean, variance));
                }
                result.priors.add(category, count);
                result.conditionals.put(category, conditionals);
            }
            return result;
        }
    }

    /**
     * Batch learner that fits each per-category, per-dimension conditional
     * with a supplied {@link DistributionEstimator}. Priors are the
     * per-category example counts.
     *
     * @param <CategoryType>
     *      The type of the output categories.
     * @param <DistributionType>
     *      The type of the per-dimension conditional distribution.
     */
    public static class Learner<CategoryType, DistributionType extends UnivariateProbabilityDensityFunction>
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner<Vectorizable, CategoryType, VectorNaiveBayesCategorizer<CategoryType, DistributionType>> {

        /** Estimator applied to each category's values for one input dimension. */
        protected DistributionEstimator<? super Double, ? extends DistributionType> distributionEstimator;

        /** Creates a learner with no estimator; one must be set before use. */
        public Learner() {
            this(null);
        }

        /**
         * Creates a learner with the given estimator.
         *
         * @param distributionEstimator
         *      Estimator used for each per-dimension conditional.
         */
        public Learner(DistributionEstimator<? super Double, ? extends DistributionType> distributionEstimator) {
            this.setDistributionEstimator(distributionEstimator);
        }

        /**
         * Learns a naive Bayes categorizer. For each category and each input
         * dimension, the estimator is given the values of that dimension across
         * the category's examples.
         *
         * @param data
         *      The labeled training examples.
         * @return The learned categorizer.
         */
        @Override
        public VectorNaiveBayesCategorizer<CategoryType, DistributionType> learn(Collection<? extends InputOutputPair<? extends Vectorizable, CategoryType>> data) {
            final int dimensionality = DatasetUtil.getInputDimensionality(data);
            final Map<CategoryType, List<Vectorizable>> examplesPerCategory =
                DatasetUtil.splitOnOutput(data);
            final VectorNaiveBayesCategorizer<CategoryType, DistributionType> result =
                new VectorNaiveBayesCategorizer<CategoryType, DistributionType>();

            for (Map.Entry<CategoryType, List<Vectorizable>> entry : examplesPerCategory.entrySet()) {
                final CategoryType category = entry.getKey();
                final List<Vectorizable> examples = entry.getValue();
                final int count = examples.size();

                // Convert each example once, rather than once per dimension.
                final List<Vector> vectors = new ArrayList<Vector>(count);
                for (Vectorizable input : examples) {
                    vectors.add(input.convertToVector());
                }

                final List<DistributionType> conditionals = new ArrayList<DistributionType>(dimensionality);
                for (int i = 0; i < dimensionality; ++i) {
                    // A fresh list per dimension, so an estimator that keeps a
                    // reference to its input never sees it cleared or refilled.
                    final List<Double> values = new ArrayList<Double>(count);
                    for (Vector vector : vectors) {
                        values.add(vector.getElement(i));
                    }
                    conditionals.add(this.distributionEstimator.learn(values));
                }
                result.priors.add(category, count);
                result.conditionals.put(category, conditionals);
            }
            return result;
        }

        /**
         * Gets the estimator used for each per-dimension conditional.
         *
         * @return The distribution estimator.
         */
        public DistributionEstimator<? super Double, ? extends DistributionType> getDistributionEstimator() {
            return this.distributionEstimator;
        }

        /**
         * Sets the estimator used for each per-dimension conditional.
         *
         * @param distributionEstimator
         *      The distribution estimator.
         */
        public void setDistributionEstimator(DistributionEstimator<? super Double, ? extends DistributionType> distributionEstimator) {
            this.distributionEstimator = distributionEstimator;
        }
    }
}

