/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.bayesian.AbstractBayesianRegression;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Bayesian linear regression with a Gaussian prior over the weight vector and
 * a fixed output (noise) variance.
 * <p>
 * Each input is mapped to a feature vector {@code phi} by the feature map. The
 * posterior over the weights is then computed in closed form:
 * <pre>
 *   precision  = priorPrecision + sum_i beta_i * phi_i * phi_i^T
 *   z          = priorPrecision * priorMean + sum_i beta_i * y_i * phi_i
 *   covariance = precision^-1
 *   mean       = covariance * z
 * </pre>
 * where {@code beta_i = weight_i / outputVariance} and {@code weight_i} is the
 * per-example weight reported by {@link DatasetUtil#getWeight}.
 *
 * @param <InputType> type of the raw inputs handed to the feature map
 */
@PublicationReferences(references={@PublicationReference(author={"Christopher M. Bishop"}, title="Pattern Recognition and Machine Learning", type=PublicationType.Book, year=2006, pages={152, 159}), @PublicationReference(author={"Hanna M. Wallach"}, title="Introduction to Gaussian Process Regression", type=PublicationType.Misc, year=2005, url="http://www.cs.umass.edu/~wallach/talks/gp_intro.pdf"), @PublicationReference(author={"Wikipedia"}, title="Bayesian linear regression", type=PublicationType.WebPage, year=2010, url="http://en.wikipedia.org/wiki/Bayesian_linear_regression")})
public class BayesianLinearRegression<InputType>
extends AbstractBayesianRegression<InputType, Double, MultivariateGaussian> {

    /** Output variance used by {@link #BayesianLinearRegression(int)}. */
    public static final double DEFAULT_OUTPUT_VARIANCE = 1.0;

    /** Per-weight variance of the isotropic prior built by {@link #BayesianLinearRegression(int)}. */
    public static final double DEFAULT_WEIGHT_VARIANCE = 1.0;

    /** Variance of the Gaussian output noise; {@link #setOutputVariance} requires it to be > 0. */
    protected double outputVariance;

    /** Gaussian prior over the weight vector. */
    protected MultivariateGaussian weightPrior;

    /**
     * Creates a regression with a zero-mean isotropic weight prior. The prior's
     * covariance is {@link #DEFAULT_WEIGHT_VARIANCE} times the identity, and the
     * output variance is {@link #DEFAULT_OUTPUT_VARIANCE}. A {@code null} feature
     * map is passed to the superclass.
     *
     * @param dimensionality number of weights (length of the feature vectors)
     */
    public BayesianLinearRegression(int dimensionality) {
        this(null, DEFAULT_OUTPUT_VARIANCE, new MultivariateGaussian(
            VectorFactory.getDefault().createVector(dimensionality),
            (Matrix) MatrixFactory.getDefault().createIdentity(dimensionality, dimensionality)
                .scale(DEFAULT_WEIGHT_VARIANCE)));
    }

    /**
     * Creates a regression with the given feature map, noise variance and prior.
     *
     * @param featureMap maps an input to its feature vector
     * @param outputVariance variance of the output noise; must be > 0
     * @param weightPrior Gaussian prior over the weights
     * @throws IllegalArgumentException if {@code outputVariance <= 0}
     */
    public BayesianLinearRegression(Evaluator<? super InputType, Vector> featureMap, double outputVariance, MultivariateGaussian weightPrior) {
        super(featureMap);
        this.setOutputVariance(outputVariance);
        this.setWeightPrior(weightPrior);
    }

    /**
     * Returns a copy of this regression. The weight prior is cloned via
     * {@link ObjectUtil#cloneSafe}; all other state is copied by
     * {@code super.clone()}.
     *
     * @return the copy
     */
    @Override
    @SuppressWarnings("unchecked")
    public BayesianLinearRegression<InputType> clone() {
        // super.clone() preserves the runtime class, so this cast is safe.
        final BayesianLinearRegression<InputType> clone =
            (BayesianLinearRegression<InputType>) super.clone();
        clone.setWeightPrior(ObjectUtil.cloneSafe(this.getWeightPrior()));
        return clone;
    }

    /**
     * Computes the weight posterior for the given data in one batch.
     * <p>
     * If the weight prior is {@code null}, only the data terms are used, which
     * mirrors the behavior of {@link IncrementalEstimator.SufficientStatistic}.
     *
     * @param data weighted input/output examples
     * @return the Gaussian posterior over the weights
     * @throws IllegalStateException if there is neither a prior nor any data
     */
    @Override
    public MultivariateGaussian.PDF learn(Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
        final MultivariateGaussian prior = this.getWeightPrior();
        final RingAccumulator<Matrix> precisionSum = new RingAccumulator<Matrix>();
        final RingAccumulator<Vector> zSum = new RingAccumulator<Vector>();
        if (prior != null) {
            // Clone so that accumulation can never write into the prior's own matrix.
            final Matrix priorPrecision = prior.getCovarianceInverse().clone();
            precisionSum.accumulate(priorPrecision);
            zSum.accumulate(priorPrecision.times(prior.getMean()));
        }
        for (InputOutputPair<? extends InputType, Double> pair : data) {
            final Vector x = (Vector) this.featureMap.evaluate(pair.getInput());
            // Scale a copy so that the vector returned by the feature map is never modified.
            final Vector scaled = x.clone();
            final double beta = DatasetUtil.getWeight(pair) / this.outputVariance;
            if (beta != 1.0) {
                scaled.scaleEquals(beta);
            }
            precisionSum.accumulate(x.outerProduct(scaled));
            final double y = pair.getOutput();
            if (y != 1.0) {
                scaled.scaleEquals(y);
            }
            zSum.accumulate(scaled);
        }
        final Matrix precision = (Matrix) precisionSum.getSum();
        final Vector z = (Vector) zSum.getSum();
        if (precision == null || z == null) {
            throw new IllegalStateException(
                "Cannot compute a posterior without a weight prior or any data");
        }
        final Matrix covariance = precision.inverse();
        return new MultivariateGaussian.PDF(covariance.times(z), covariance);
    }

    /**
     * Returns the output distribution for a fixed weight vector:
     * N(phi(input) . weights, outputVariance).
     *
     * @param input raw input, mapped through the feature map
     * @param weights weight vector
     * @return the conditional output distribution
     */
    public UnivariateGaussian createConditionalDistribution(InputType input, Vector weights) {
        final Vector features = (Vector) this.featureMap.evaluate(input);
        return new UnivariateGaussian(features.dotProduct(weights), this.getOutputVariance());
    }

    /**
     * @return the Gaussian prior over the weights
     */
    public MultivariateGaussian getWeightPrior() {
        return this.weightPrior;
    }

    /**
     * @param weightPrior the Gaussian prior over the weights
     */
    public void setWeightPrior(MultivariateGaussian weightPrior) {
        this.weightPrior = weightPrior;
    }

    /**
     * @return the variance of the output noise
     */
    public double getOutputVariance() {
        return this.outputVariance;
    }

    /**
     * @param outputVariance variance of the output noise; must be > 0
     * @throws IllegalArgumentException if {@code outputVariance <= 0}
     */
    public void setOutputVariance(double outputVariance) {
        if (outputVariance <= 0.0) {
            throw new IllegalArgumentException("outputVariance must be > 0.0");
        }
        this.outputVariance = outputVariance;
    }

    /**
     * Wraps a weight posterior into a predictive distribution over outputs.
     *
     * @param posterior weight posterior, for example as returned by {@link #learn}
     * @return the predictive distribution
     */
    public PredictiveDistribution createPredictiveDistribution(MultivariateGaussian posterior) {
        return new PredictiveDistribution(posterior);
    }

    /**
     * Incremental version of {@link BayesianLinearRegression}. The posterior
     * terms are accumulated one example at a time in a
     * {@link SufficientStatistic}.
     *
     * @param <InputType> type of the raw inputs handed to the feature map
     */
    public static class IncrementalEstimator<InputType>
    extends BayesianLinearRegression<InputType>
    implements IncrementalLearner<InputOutputPair<? extends InputType, Double>, SufficientStatistic> {

        /**
         * @param dimensionality number of weights
         * @see BayesianLinearRegression#BayesianLinearRegression(int)
         */
        public IncrementalEstimator(int dimensionality) {
            super(dimensionality);
        }

        /**
         * @param dimensionality number of weights
         * @param featureMap maps an input to its feature vector
         */
        public IncrementalEstimator(int dimensionality, Evaluator<? super InputType, Vector> featureMap) {
            this(dimensionality);
            this.setFeatureMap(featureMap);
        }

        /**
         * @param featureMap maps an input to its feature vector
         * @param outputVariance variance of the output noise; must be > 0
         * @param weightPrior Gaussian prior over the weights (may be {@code null})
         */
        public IncrementalEstimator(Evaluator<? super InputType, Vector> featureMap, double outputVariance, MultivariateGaussian weightPrior) {
            super(featureMap, outputVariance, weightPrior);
        }

        /**
         * @return a fresh statistic seeded with the current weight prior
         */
        @Override
        public SufficientStatistic createInitialLearnedObject() {
            return new SufficientStatistic(this.getWeightPrior());
        }

        /**
         * Computes the posterior by feeding all of {@code data} through a new
         * sufficient statistic.
         *
         * @param data weighted input/output examples
         * @return the Gaussian posterior over the weights
         */
        @Override
        public MultivariateGaussian.PDF learn(Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
            final SufficientStatistic target = this.createInitialLearnedObject();
            // Cast selects the Iterable overload of update explicitly.
            this.update(target, (Iterable<? extends InputOutputPair<? extends InputType, Double>>) data);
            return target.create();
        }

        @Override
        public void update(SufficientStatistic target, InputOutputPair<? extends InputType, Double> data) {
            target.update(data);
        }

        @Override
        public void update(SufficientStatistic target, Iterable<? extends InputOutputPair<? extends InputType, Double>> data) {
            target.update(data);
        }

        /**
         * Running posterior terms: the precision matrix and the vector
         * {@code z = precision * mean}. Both are {@code null} until a prior or
         * an example has contributed to them. A non-null prior counts as one
         * observation in {@code count}.
         */
        public class SufficientStatistic
        extends AbstractSufficientStatistic<InputOutputPair<? extends InputType, Double>, MultivariateGaussian> {

            /** priorPrecision * priorMean + sum_i beta_i * y_i * phi_i. */
            private Vector z;

            /** priorPrecision + sum_i beta_i * phi_i * phi_i^T. */
            private Matrix covarianceInverse;

            /**
             * @param prior weight prior to start from, or {@code null} to start empty
             */
            public SufficientStatistic(MultivariateGaussian prior) {
                if (prior != null) {
                    // Clone so that later plusEquals calls never modify the prior itself.
                    this.covarianceInverse = prior.getCovarianceInverse().clone();
                    this.z = this.covarianceInverse.times(prior.getMean());
                    this.count = 1L;
                } else {
                    this.covarianceInverse = null;
                    this.z = null;
                    this.count = 0L;
                }
            }

            /**
             * Adds one weighted example to the running sums.
             *
             * @param value the example to add
             */
            @Override
            public void update(InputOutputPair<? extends InputType, Double> value) {
                final Vector x = (Vector) IncrementalEstimator.this.featureMap.evaluate(value.getInput());
                final double y = value.getOutput();
                final double beta = DatasetUtil.getWeight(value) / IncrementalEstimator.this.outputVariance;
                // Scale a copy so that the vector returned by the feature map is never modified.
                final Vector scaled = x.clone();
                if (beta != 1.0) {
                    scaled.scaleEquals(beta);
                }
                final Matrix outer = x.outerProduct(scaled);
                if (this.covarianceInverse == null) {
                    this.covarianceInverse = outer;
                } else {
                    this.covarianceInverse.plusEquals(outer);
                }
                if (y != 1.0) {
                    scaled.scaleEquals(y);
                }
                if (this.z == null) {
                    this.z = scaled;
                } else {
                    this.z.plusEquals(scaled);
                }
                // Count only after the sums were updated, so a failing feature map leaves the count unchanged.
                ++this.count;
            }

            /**
             * @return a new posterior built from the current sums
             * @throws IllegalStateException if there is neither a prior nor any data
             */
            @Override
            public MultivariateGaussian.PDF create() {
                final MultivariateGaussian.PDF result = new MultivariateGaussian.PDF(this.getDimensionality());
                this.create(result);
                return result;
            }

            /**
             * Writes the current posterior into {@code distribution}. The
             * distribution receives copies, so later updates to this statistic
             * do not change it.
             *
             * @param distribution distribution to overwrite
             * @throws IllegalStateException if there is neither a prior nor any data
             */
            @Override
            public void create(MultivariateGaussian distribution) {
                distribution.setMean(this.getMean());
                // Hand out a copy: this statistic keeps accumulating into its own matrix.
                distribution.setCovarianceInverse(this.covarianceInverse.clone());
            }

            /**
             * @return the running precision matrix (internal reference, not a copy)
             */
            public Matrix getCovarianceInverse() {
                return this.covarianceInverse;
            }

            /**
             * @return the running vector {@code precision * mean} (internal reference)
             */
            public Vector getZ() {
                return this.z;
            }

            /**
             * @return the posterior mean, {@code precision^-1 * z}
             * @throws IllegalStateException if there is neither a prior nor any data
             */
            public Vector getMean() {
                this.checkInitialized();
                return this.covarianceInverse.inverse().times(this.z);
            }

            /**
             * @return number of weights
             * @throws IllegalStateException if there is neither a prior nor any data
             */
            public int getDimensionality() {
                this.checkInitialized();
                return this.z.getDimensionality();
            }

            /** Fails fast when nothing has been accumulated yet. */
            private void checkInitialized() {
                if (this.covarianceInverse == null || this.z == null) {
                    throw new IllegalStateException(
                        "No prior and no data have been added to the sufficient statistic");
                }
            }
        }
    }

    /**
     * Posterior predictive distribution of the output for a given input.
     * <p>
     * For a weight posterior N(m, S), the prediction is
     * N(phi^T m, phi^T S phi + outputVariance).
     */
    @PublicationReference(author={"Christopher M. Bishop"}, title="Pattern Recognition and Machine Learning", type=PublicationType.Book, year=2006, pages={156})
    public class PredictiveDistribution
    extends AbstractCloneableSerializable
    implements Evaluator<InputType, UnivariateGaussian.PDF> {

        /** Posterior over the weights. */
        private MultivariateGaussian posterior;

        /**
         * @param posterior posterior over the weights
         */
        public PredictiveDistribution(MultivariateGaussian posterior) {
            this.posterior = posterior;
        }

        /**
         * @param input raw input, mapped through the enclosing feature map
         * @return the predictive output distribution for {@code input}
         */
        @Override
        public UnivariateGaussian.PDF evaluate(InputType input) {
            final Vector x = (Vector) BayesianLinearRegression.this.featureMap.evaluate(input);
            final double mean = x.dotProduct(this.posterior.getMean());
            // Uncertainty from the weights plus the irreducible output noise.
            final double weightVariance = x.times(this.posterior.getCovariance()).dotProduct(x);
            final double variance = weightVariance + BayesianLinearRegression.this.outputVariance;
            return new UnivariateGaussian.PDF(mean, variance);
        }
    }
}

