/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.function.cost.AbstractSupervisedCostFunction;
import gov.sandia.cognition.learning.function.cost.DifferentiableCostFunction;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.Collection;
import java.util.Iterator;

/**
 * Cost function that measures the weighted mean squared error between vector
 * targets and vector estimates.
 * <p>
 * The per-example error is the squared Euclidean distance between the target
 * and the estimate. Each example contributes in proportion to the weight
 * returned by {@link DatasetUtil#getWeight}, and the total is normalized by
 * the sum of those weights.
 */
@CodeReview(reviewer={"Justin Basilico"}, date="2006-10-04", changesNeeded=false, comments={"Minor documentaMtion changes."})
public class MeanSquaredErrorCostFunction
extends AbstractSupervisedCostFunction<Vector, Vector>
implements DifferentiableCostFunction {

    /**
     * Creates a new cost function without a dataset ({@code null} is passed
     * to the superclass constructor).
     */
    public MeanSquaredErrorCostFunction() {
        // Only one single-argument constructor exists, so null is unambiguous.
        this(null);
    }

    /**
     * Creates a new cost function over the given dataset.
     *
     * @param dataset the input/target pairs the cost is computed over; passed
     *        unchanged to the superclass constructor
     */
    public MeanSquaredErrorCostFunction(Collection<? extends InputOutputPair<? extends Vector, Vector>> dataset) {
        super(dataset);
    }

    /**
     * Creates a copy of this cost function via the superclass clone.
     *
     * @return the cloned cost function
     */
    @Override
    public MeanSquaredErrorCostFunction clone() {
        return (MeanSquaredErrorCostFunction) super.clone();
    }

    /**
     * Computes the weighted mean squared error of the given target/estimate
     * pairs.
     *
     * @param data the target/estimate pairs to evaluate
     * @return {@code sum(w_i * ||target_i - estimate_i||^2) / sum(w_i)}, or
     *         {@code 0.0} when the total weight is zero (for example, when
     *         {@code data} is empty)
     */
    @Override
    public Double evaluatePerformance(Collection<? extends TargetEstimatePair<Vector, Vector>> data) {
        double sumSquaredError = 0.0;
        double denominator = 0.0;
        for (TargetEstimatePair<Vector, Vector> targetEstimatePair : data) {
            Vector target = targetEstimatePair.getTarget();
            Vector estimate = targetEstimatePair.getEstimate();
            double errorSquared = target.euclideanDistanceSquared(estimate);
            double weight = DatasetUtil.getWeight(targetEstimatePair);
            sumSquaredError += weight * errorSquared;
            denominator += weight;
        }
        // Avoid dividing by zero for empty data or all-zero weights.
        return (denominator != 0.0) ? sumSquaredError / denominator : 0.0;
    }

    /**
     * Computes the gradient of this cost with respect to the parameters of
     * the given function, over the current cost parameters (the dataset).
     * <p>
     * For each pair {@code (x, y)} with weight {@code w}, this accumulates
     * {@code w * (f(x) - y) * J(x)}, where {@code J(x)} is the matrix returned
     * by {@code function.computeParameterGradient(x)}. The sum is then divided
     * by the total weight when that weight is non-zero.
     * <p>
     * No factor of 2 is applied. Assuming {@code J(x)} is the Jacobian of the
     * function's output with respect to its parameters (TODO confirm), the
     * result is therefore the gradient of one half of the weighted mean
     * squared error.
     *
     * @param function the differentiable function whose parameters are being
     *        fit
     * @return the weight-normalized parameter gradient
     */
    @Override
    public Vector computeParameterGradient(GradientDescendable function) {
        RingAccumulator<Vector> parameterDelta = new RingAccumulator<Vector>();
        double denominator = 0.0;
        for (InputOutputPair<? extends Vector, Vector> pair : this.getCostParameters()) {
            Vector input = pair.getInput();
            Vector target = pair.getOutput();

            // NOTE(review): minusEquals/scaleEquals mutate the evaluated output in
            // place; this assumes evaluate() returns a fresh vector rather than
            // internal state of the function -- confirm.
            Vector negativeError = (Vector) function.evaluate(input);
            negativeError.minusEquals(target);

            double weight = DatasetUtil.getWeight(pair);
            if (weight != 1.0) {
                // Skip the scale for the common unit-weight case.
                negativeError.scaleEquals(weight);
            }
            denominator += weight;

            Matrix gradient = function.computeParameterGradient(input);
            parameterDelta.accumulate(negativeError.times(gradient));
        }

        // NOTE(review): with no cost parameters nothing is accumulated, so this
        // returns whatever RingAccumulator.getSum() yields in that state
        // (possibly null) -- confirm against RingAccumulator.
        Vector negativeSum = parameterDelta.getSum();
        if (denominator != 0.0) {
            negativeSum.scaleEquals(1.0 / denominator);
        }
        return negativeSum;
    }
}

