/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.math.model;

import gov.sandia.cognition.learning.algorithm.bayes.VectorNaiveBayesCategorizer;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.math.matrix.Vector1D;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.openimaj.math.model.EstimatableModel;
import org.openimaj.util.pair.IndependentPair;

/**
 * A naive Bayes classifier over a single {@code double} feature, backed by a
 * Cognitive Foundry {@link VectorNaiveBayesCategorizer} whose per-class
 * conditional distributions are {@link UnivariateGaussian.PDF}s.
 *
 * @param <T> the class (category) label type
 */
public class UnivariateGaussianNaiveBayesModel<T>
implements EstimatableModel<Double, T> {
    /** The wrapped categorizer; {@code null} until {@link #estimate} is called or one is supplied. */
    private VectorNaiveBayesCategorizer<T, UnivariateGaussian.PDF> model;

    /**
     * Creates an untrained model. {@link #estimate(List)} must be called
     * before {@link #predict(Double)} or any of the accessors.
     */
    public UnivariateGaussianNaiveBayesModel() {
    }

    /**
     * Creates a model that wraps an existing categorizer.
     *
     * @param model the categorizer to wrap
     */
    public UnivariateGaussianNaiveBayesModel(VectorNaiveBayesCategorizer<T, UnivariateGaussian.PDF> model) {
        this.model = model;
    }

    /**
     * Trains a new categorizer from labelled scalar observations using Foundry's
     * {@code BatchGaussianLearner}, replacing any previously held categorizer.
     *
     * @param data pairs of (feature value, class label)
     * @return always {@code true}
     */
    @Override
    public boolean estimate(List<? extends IndependentPair<Double, T>> data) {
        VectorNaiveBayesCategorizer.BatchGaussianLearner<T> learner =
                new VectorNaiveBayesCategorizer.BatchGaussianLearner<T>();
        List<DefaultInputOutputPair<Vector1D, T>> cfdata =
                new ArrayList<DefaultInputOutputPair<Vector1D, T>>(data.size());
        for (IndependentPair<Double, T> d : data) {
            cfdata.add(new DefaultInputOutputPair<Vector1D, T>(toVector(d.firstObject()), d.secondObject()));
        }
        this.model = learner.learn(cfdata);
        return true;
    }

    /**
     * Classifies a single feature value.
     *
     * @param data the feature value (auto-unboxed, so must be non-null)
     * @return the class label chosen by the wrapped categorizer
     * @throws IllegalStateException if no categorizer has been estimated or supplied
     */
    @Override
    public T predict(Double data) {
        return requireModel().evaluate(toVector(data));
    }

    /**
     * Returns 0.
     * NOTE(review): presumably this signals that no fixed minimum number of
     * samples is required for estimation — confirm against the
     * {@code EstimatableModel} contract.
     *
     * @return 0
     */
    @Override
    public int numItemsToEstimate() {
        return 0;
    }

    /**
     * Creates a shallow copy: the clone shares the same categorizer instance
     * until either copy is re-estimated (which replaces, rather than mutates,
     * the field).
     *
     * @return a shallow copy of this model
     */
    @Override
    @SuppressWarnings("unchecked") // Object.clone() returns Object; the runtime class is this class
    public UnivariateGaussianNaiveBayesModel<T> clone() {
        try {
            return (UnivariateGaussianNaiveBayesModel<T>) super.clone();
        }
        catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the conditional distribution learned for one class.
     *
     * @param clz the class label
     * @return the first (and, for this one-feature model, expected only)
     *         conditional distribution for {@code clz}, or {@code null} if the
     *         categorizer has no conditionals for that label
     * @throws IllegalStateException if no categorizer has been estimated or supplied
     */
    public UnivariateGaussian getClassDistribution(T clz) {
        VectorNaiveBayesCategorizer<T, UnivariateGaussian.PDF> m = requireModel();
        if (!m.getConditionals().containsKey(clz) || m.getConditionals().get(clz).isEmpty()) {
            return null;
        }
        return m.getConditionals().get(clz).get(0);
    }

    /**
     * Returns the conditional distribution for every category known to the
     * categorizer.
     *
     * @return a new mutable map from class label to its distribution (a value
     *         may be {@code null} if a category has no conditionals)
     * @throws IllegalStateException if no categorizer has been estimated or supplied
     */
    public Map<T, UnivariateGaussian> getClassDistribution() {
        // Declared with the supertype as value type: Map<T, PDF> is not a Map<T, UnivariateGaussian>.
        Map<T, UnivariateGaussian> clzs = new HashMap<T, UnivariateGaussian>();
        for (T c : requireModel().getCategories()) {
            clzs.put(c, getClassDistribution(c));
        }
        return clzs;
    }

    /**
     * Returns the class prior histogram held by the categorizer.
     *
     * @return the categorizer's priors (not a copy)
     * @throws IllegalStateException if no categorizer has been estimated or supplied
     */
    public DataHistogram<T> getClassPriors() {
        return requireModel().getPriors();
    }

    /** Returns the wrapped categorizer, failing with a clear message if there is none yet. */
    private VectorNaiveBayesCategorizer<T, UnivariateGaussian.PDF> requireModel() {
        if (this.model == null) {
            throw new IllegalStateException(
                    "No categorizer available: call estimate() first or construct with a categorizer");
        }
        return this.model;
    }

    /** Wraps a scalar as the one-dimensional vector the Foundry categorizer consumes. */
    private static Vector1D toVector(double value) {
        return VectorFactory.getDefault().createVector1D(value);
    }

    /**
     * Small demo: trains on two well-separated clusters and prints a prediction,
     * the learned conditionals and the priors.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        UnivariateGaussianNaiveBayesModel<Boolean> model = new UnivariateGaussianNaiveBayesModel<Boolean>();
        List<IndependentPair<Double, Boolean>> data = new ArrayList<IndependentPair<Double, Boolean>>();
        data.add(IndependentPair.pair(0.0, true));
        data.add(IndependentPair.pair(0.1, true));
        data.add(IndependentPair.pair(-0.1, true));
        data.add(IndependentPair.pair(9.9, false));
        data.add(IndependentPair.pair(10.0, false));
        data.add(IndependentPair.pair(10.1, false));
        model.estimate(data);
        System.out.println(model.predict(5.1));
        System.out.println(model.model.getConditionals().get(true));
        System.out.println(model.model.getConditionals().get(false));
        UnivariateGaussian positive = model.getClassDistribution(true);
        UnivariateGaussian negative = model.getClassDistribution(false);
        System.out.println(positive.getMean());
        System.out.println(positive.getVariance());
        System.out.println(negative.getMean());
        System.out.println(negative.getVariance());
        System.out.println(model.getClassPriors());
    }
}

