/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant;
import gov.sandia.cognition.learning.function.vector.ScalarBasisSet;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.distribution.ChiSquareDistribution;
import gov.sandia.cognition.statistics.method.AbstractConfidenceStatistic;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * Batch learner that fits a linear combination of basis functions to
 * scalar targets.
 * <p>
 * Each input is mapped to a {@link Vector} by {@link #getInputToVectorMap()};
 * those vectors (scaled by the per-sample weight from
 * {@link DatasetUtil#getWeight}) form the rows of a design matrix. The
 * coefficients are then obtained either from the matrix pseudo-inverse
 * (the default) or from {@link Matrix#solve(Vector)}.
 *
 * @param <InputType> Type of input accepted by the basis mapping.
 */
@CodeReview(reviewer={"Kevin R. Dixon"}, date="2008-09-02", changesNeeded=false, comments={"Made minor changes to javadoc", "Looks fine."})
@PublicationReference(author={"Wikipedia"}, title="Linear regression", type=PublicationType.WebPage, year=2008, url="http://en.wikipedia.org/wiki/Linear_regression")
public class LinearRegression<InputType>
extends AbstractCloneableSerializable
implements SupervisedBatchLearner<InputType, Double, VectorFunctionLinearDiscriminant<InputType>> {

    /** Tolerance handed to {@link Matrix#pseudoInverse(double)} by {@link #learn}. */
    public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1.0E-10;

    /** Result of the most recent call to {@link #learn}, or null if none / cleared. */
    private VectorFunctionLinearDiscriminant<InputType> learned;

    /** Maps each input to the feature vector used as a design-matrix row. */
    private Evaluator<? super InputType, Vector> inputToVectorMap;

    /** When true, solve via the pseudo-inverse; otherwise via {@link Matrix#solve(Vector)}. */
    private boolean usePseudoInverse;

    /**
     * Creates a regression over the given scalar basis functions.
     *
     * @param basisFunctions Basis functions whose outputs form the feature vector.
     */
    public LinearRegression(Evaluator<? super InputType, Double> ... basisFunctions) {
        this(Arrays.asList(basisFunctions));
    }

    /**
     * Creates a regression over the given scalar basis functions.
     *
     * @param basisFunctions Basis functions whose outputs form the feature vector.
     */
    public LinearRegression(Collection<? extends Evaluator<? super InputType, Double>> basisFunctions) {
        this(new ScalarBasisSet<InputType>(basisFunctions));
    }

    /**
     * Creates a regression that uses the given basis set as its input-to-vector map.
     *
     * @param inputToVectorMap Basis set producing the feature vector.
     */
    public LinearRegression(ScalarBasisSet<InputType> inputToVectorMap) {
        // The cast selects the Evaluator overload below instead of recursing into this one.
        this((Evaluator<InputType, Vector>)inputToVectorMap);
    }

    /**
     * Creates a regression using an arbitrary input-to-vector map.
     * Pseudo-inverse solving is enabled by default.
     *
     * @param inputToVectorMap Function mapping each input to its feature vector.
     */
    public LinearRegression(Evaluator<? super InputType, Vector> inputToVectorMap) {
        this.setInputToVectorMap(inputToVectorMap);
        this.setUsePseudoInverse(true);
        this.setLearned(null);
    }

    /**
     * Clones this learner, smart-cloning the input map and safely cloning
     * the learned function.
     *
     * @return A copy of this learner.
     */
    @Override
    @SuppressWarnings("unchecked") // super.clone() returns an instance of this exact class
    public LinearRegression<InputType> clone() {
        final LinearRegression<InputType> clone = (LinearRegression<InputType>) super.clone();
        clone.setInputToVectorMap(ObjectUtil.cloneSmart(this.getInputToVectorMap()));
        clone.setLearned(ObjectUtil.cloneSafe(this.getLearned()));
        return clone;
    }

    /**
     * Gets the function produced by the most recent call to {@link #learn}.
     *
     * @return The learned function, or null if nothing has been learned.
     */
    public VectorFunctionLinearDiscriminant<InputType> getLearned() {
        return this.learned;
    }

    /**
     * Sets the learned function.
     *
     * @param learned The learned function (may be null).
     */
    protected void setLearned(VectorFunctionLinearDiscriminant<InputType> learned) {
        this.learned = learned;
    }

    /**
     * Fits the regression coefficients to the given labeled data.
     *
     * @param data Input/target pairs; must contain at least one element.
     * @return The learned linear function (also stored via {@link #setLearned}).
     * @throws IllegalArgumentException if {@code data} is null or empty.
     */
    @Override
    public VectorFunctionLinearDiscriminant<InputType> learn(Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
        this.setLearned(null);
        if (data == null || data.isEmpty()) {
            throw new IllegalArgumentException("Data must contain at least one sample");
        }

        // Evaluate the first input only to discover the number of coefficients.
        final InputOutputPair<? extends InputType, Double> first = data.iterator().next();
        final int numCoefficients = this.inputToVectorMap.evaluate(first.getInput()).getDimensionality();
        final int numSamples = data.size();

        final Matrix X = MatrixFactory.getDefault().createMatrix(numSamples, numCoefficients);
        final Vector y = VectorFactory.getDefault().createVector(numSamples);
        int row = 0;
        for (InputOutputPair<? extends InputType, Double> pair : data) {
            // NOTE(review): both the target and the feature row are multiplied by the raw
            // weight, so (assuming a least-squares solve) each squared residual is
            // effectively weighted by weight^2 rather than weight.
            final double weight = DatasetUtil.getWeight(pair);
            y.setElement(row, pair.getOutput() * weight);
            final Vector features = this.inputToVectorMap.evaluate(pair.getInput());
            X.setRow(row, (Vector) features.scale(weight));
            row++;
        }

        final Vector coefficients;
        if (this.getUsePseudoInverse()) {
            final Matrix pseudoInverse = X.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE);
            coefficients = pseudoInverse.times(y);
        } else {
            coefficients = X.solve(y);
        }
        this.setLearned(new VectorFunctionLinearDiscriminant<InputType>(this.inputToVectorMap, coefficients));
        return this.getLearned();
    }

    /**
     * Gets the function that maps inputs to feature vectors.
     *
     * @return The input-to-vector map.
     */
    public Evaluator<? super InputType, Vector> getInputToVectorMap() {
        return this.inputToVectorMap;
    }

    /**
     * Sets the function that maps inputs to feature vectors.
     *
     * @param inputToVectorMap The input-to-vector map.
     */
    public void setInputToVectorMap(Evaluator<? super InputType, Vector> inputToVectorMap) {
        this.inputToVectorMap = inputToVectorMap;
    }

    /**
     * Gets whether {@link #learn} solves via the pseudo-inverse.
     *
     * @return True for pseudo-inverse, false for {@link Matrix#solve(Vector)}.
     */
    public boolean getUsePseudoInverse() {
        return this.usePseudoInverse;
    }

    /**
     * Sets whether {@link #learn} solves via the pseudo-inverse.
     *
     * @param usePseudoInverse True for pseudo-inverse, false for {@link Matrix#solve(Vector)}.
     */
    public void setUsePseudoInverse(boolean usePseudoInverse) {
        this.usePseudoInverse = usePseudoInverse;
    }

    /**
     * Goodness-of-fit statistics comparing regression targets with estimates.
     * <p>
     * The per-sample error is {@code weight * (target - estimate)}. The
     * chi-square value, RMS error and mean L1 error are computed from these
     * weighted errors. The correlation is computed from the raw (unweighted)
     * targets and estimates.
     */
    public static class Statistic
    extends AbstractConfidenceStatistic {
        private double chiSquare;
        private double rootMeanSquaredError;
        private double meanL1Error;
        private double targetEstimateCorrelation;
        private double unpredictedErrorFraction;
        private int numSamples;
        private int numParameters;
        private double degreesOfFreedom;

        /**
         * Computes statistics with every sample weighted by 1.
         *
         * @param targets Desired outputs.
         * @param estimates Model outputs, in the same order as {@code targets}.
         * @param numParameters Number of fitted parameters (reduces degrees of freedom).
         * @throws IllegalArgumentException if the collections differ in size.
         */
        public Statistic(Collection<Double> targets, Collection<Double> estimates, int numParameters) {
            super(0.0);
            final List<Double> weights = Collections.nCopies(targets.size(), Double.valueOf(1.0));
            this.computeStatistics(targets, estimates, weights, numParameters);
        }

        /**
         * Computes statistics using per-sample weights.
         *
         * @param targets Desired outputs.
         * @param estimates Model outputs, in the same order as {@code targets}.
         * @param weights Per-sample weights, in the same order as {@code targets}.
         * @param numParameters Number of fitted parameters (reduces degrees of freedom).
         * @throws IllegalArgumentException if the collections differ in size.
         */
        public Statistic(Collection<Double> targets, Collection<Double> estimates, Collection<Double> weights, int numParameters) {
            super(0.0);
            this.computeStatistics(targets, estimates, weights, numParameters);
        }

        /**
         * Copies every statistic from another instance.
         *
         * @param other Statistic to copy.
         */
        private Statistic(Statistic other) {
            super(other.getNullHypothesisProbability());
            this.setChiSquare(other.getChiSquare());
            this.setDegreesOfFreedom(other.getDegreesOfFreedom());
            this.setMeanL1Error(other.getMeanL1Error());
            this.setNumParameters(other.getNumParameters());
            this.setNumSamples(other.getNumSamples());
            this.setRootMeanSquaredError(other.getRootMeanSquaredError());
            this.setTargetEstimateCorrelation(other.getTargetEstimateCorrelation());
            this.setUnpredictedErrorFraction(other.getUnpredictedErrorFraction());
        }

        @Override
        public Statistic clone() {
            return (Statistic)super.clone();
        }

        /**
         * Computes and stores every statistic.
         *
         * @throws IllegalArgumentException if {@code estimates} or {@code weights}
         *         does not have the same size as {@code targets}.
         */
        private void computeStatistics(Collection<Double> targets, Collection<Double> estimates, Collection<Double> weights, int numParameters) {
            final int num = targets.size();
            // Every collection must match; a single mismatch would desynchronize the iterators.
            if (estimates.size() != num || weights.size() != num) {
                throw new IllegalArgumentException("Targets, Estimates, and Weights must be the same size!");
            }

            final List<Double> errors = new ArrayList<Double>(num);
            double sumAbsoluteError = 0.0;
            double weightSum = 0.0;
            final Iterator<Double> it = targets.iterator();
            final Iterator<Double> ie = estimates.iterator();
            final Iterator<Double> iw = weights.iterator();
            for (int n = 0; n < num; ++n) {
                final double target = it.next();
                final double estimate = ie.next();
                final double weight = iw.next();
                final double error = weight * (target - estimate);
                errors.add(error);
                sumAbsoluteError += Math.abs(error);
                weightSum += weight;
            }
            final double averageL1Error = (weightSum > 0.0) ? (sumAbsoluteError / weightSum) : 0.0;

            // Clamp so the chi-square distribution always has at least one degree of freedom.
            double dofs = num - numParameters;
            if (dofs < 1.0) {
                dofs = 1.0;
            }

            final double chi2 = UnivariateStatisticsUtil.computeSumSquaredDifference(errors, 0.0);
            final double pvalue = 1.0 - ChiSquareDistribution.CDF.evaluate(chi2, dofs);
            final double rmsError = UnivariateStatisticsUtil.computeRootMeanSquaredError(errors, 0.0);
            final double correlation = UnivariateStatisticsUtil.computeCorrelation(targets, estimates);
            final double unpredictedFraction = 1.0 - correlation * correlation;

            this.setNullHypothesisProbability(pvalue);
            this.setChiSquare(chi2);
            this.setDegreesOfFreedom(dofs);
            this.setMeanL1Error(averageL1Error);
            this.setNumSamples(num);
            this.setRootMeanSquaredError(rmsError);
            this.setTargetEstimateCorrelation(correlation);
            this.setUnpredictedErrorFraction(unpredictedFraction);
            this.setNumParameters(numParameters);
        }

        /** @return RMS of the weighted errors. */
        public double getRootMeanSquaredError() {
            return this.rootMeanSquaredError;
        }

        protected void setRootMeanSquaredError(double rootMeanSquaredError) {
            this.rootMeanSquaredError = rootMeanSquaredError;
        }

        /** @return Correlation between the unweighted targets and estimates. */
        public double getTargetEstimateCorrelation() {
            return this.targetEstimateCorrelation;
        }

        protected void setTargetEstimateCorrelation(double targetEstimateCorrelation) {
            this.targetEstimateCorrelation = targetEstimateCorrelation;
        }

        /** @return {@code 1 - correlation^2}. */
        public double getUnpredictedErrorFraction() {
            return this.unpredictedErrorFraction;
        }

        protected void setUnpredictedErrorFraction(double unpredictedErrorFraction) {
            this.unpredictedErrorFraction = unpredictedErrorFraction;
        }

        /** @return Number of samples the statistics were computed from. */
        public int getNumSamples() {
            return this.numSamples;
        }

        protected void setNumSamples(int numSamples) {
            this.numSamples = numSamples;
        }

        /** @return {@code max(1, numSamples - numParameters)}. */
        public double getDegreesOfFreedom() {
            return this.degreesOfFreedom;
        }

        protected void setDegreesOfFreedom(double degreesOfFreedom) {
            this.degreesOfFreedom = degreesOfFreedom;
        }

        /** @return Sum of absolute weighted errors divided by the weight sum (0 if that sum is not positive). */
        public double getMeanL1Error() {
            return this.meanL1Error;
        }

        protected void setMeanL1Error(double meanL1Error) {
            this.meanL1Error = meanL1Error;
        }

        /** @return Number of fitted parameters supplied at construction. */
        public int getNumParameters() {
            return this.numParameters;
        }

        public void setNumParameters(int numParameters) {
            this.numParameters = numParameters;
        }

        /** @return Sum of squared weighted errors, used as the chi-square test statistic. */
        public double getChiSquare() {
            return this.chiSquare;
        }

        public void setChiSquare(double chiSquare) {
            this.chiSquare = chiSquare;
        }

        @Override
        public double getTestStatistic() {
            return this.getChiSquare();
        }
    }
}

