/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.function.cost.AbstractParallelizableCostFunction;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import java.util.Collection;
import java.util.Iterator;

/**
 * Cost function measuring the weighted sum-squared error of a vector function
 * over a labeled dataset, normalized by twice the total weight:
 * <pre>
 *   cost = sum_i w_i * ||y_i - f(x_i)||^2 / (2 * sum_i w_i)
 * </pre>
 * where {@code w_i} is obtained from {@link DatasetUtil#getWeight}. If the total
 * weight is zero the cost is reported as {@code 0.0}.
 *
 * <p>Evaluation and gradient computation are each split into a "partial" step,
 * which reduces the pairs returned by {@code getCostParameters()} to a
 * (numerator, denominator) pair, and an "amalgamate" step, which sums a
 * collection of such pairs and divides.
 */
public class SumSquaredErrorCostFunction
extends AbstractParallelizableCostFunction {
    /**
     * Creates a cost function with no dataset (the cost parameters are {@code null}).
     */
    public SumSquaredErrorCostFunction() {
        this((Collection<? extends InputOutputPair<? extends Vector, Vector>>)null);
    }

    /**
     * Creates a cost function over the given dataset.
     *
     * @param dataset labeled (input, target) pairs; handed unchanged to the
     *        superclass constructor and may be {@code null}
     */
    public SumSquaredErrorCostFunction(Collection<? extends InputOutputPair<? extends Vector, Vector>> dataset) {
        super(dataset);
    }

    @Override
    public SumSquaredErrorCostFunction clone() {
        return (SumSquaredErrorCostFunction)super.clone();
    }

    /**
     * Computes the cost contribution of the pairs in {@code getCostParameters()}.
     *
     * @param evaluator function used to produce an estimate for each input
     * @return an {@code EvaluatePartialSSE} holding
     *         (sum of w * target.euclideanDistanceSquared(estimate), 2 * sum of w)
     */
    @Override
    public Object evaluatePartial(Evaluator<? super Vector, ? extends Vector> evaluator) {
        double sumSquaredError = 0.0;
        double weightSum = 0.0;
        for (InputOutputPair<? extends Vector, Vector> pair : this.getCostParameters()) {
            Vector target = pair.getOutput();
            Vector estimate = evaluator.evaluate(pair.getInput());
            double errorSquared = target.euclideanDistanceSquared(estimate);
            double weight = DatasetUtil.getWeight(pair);
            weightSum += weight;
            sumSquaredError += weight * errorSquared;
        }
        // The denominator already carries the factor of 2 from the cost definition.
        return new EvaluatePartialSSE(sumSquaredError, 2.0 * weightSum);
    }

    /**
     * Combines results from {@link #evaluatePartial} into the final cost.
     *
     * @param partialResults objects produced by {@link #evaluatePartial}
     * @return the summed numerators divided by the summed denominators, or
     *         {@code 0.0} when the summed denominator is zero
     */
    @Override
    public Double evaluateAmalgamate(Collection<Object> partialResults) {
        double numerator = 0.0;
        double denominator = 0.0;
        for (Object result : partialResults) {
            EvaluatePartialSSE sse = (EvaluatePartialSSE)result;
            numerator += sse.getFirst();
            denominator += sse.getSecond();
        }
        if (denominator == 0.0) {
            return 0.0;
        }
        return numerator / denominator;
    }

    /**
     * Computes the gradient contribution of the pairs in {@code getCostParameters()}.
     * For every pair the residual {@code f(x) - y} is scaled by the pair's weight
     * and multiplied by {@code function.computeParameterGradient(x)}; those
     * products are summed.
     *
     * <p>NOTE(review): the residual is formed with {@code minusEquals} directly on
     * the vector returned by {@code function.evaluate}, which assumes evaluate
     * returns a fresh vector rather than internal state — confirm.
     *
     * @param function the parameterized function being optimized
     * @return a {@link GradientPartialSSE} holding (summed products, sum of weights).
     *         The first element is whatever {@code RingAccumulator.getSum()} yields
     *         when nothing was accumulated if there are no pairs (presumably {@code null}).
     */
    @Override
    public Object computeParameterGradientPartial(GradientDescendable function) {
        RingAccumulator<Vector> parameterDelta = new RingAccumulator<Vector>();
        double denominator = 0.0;
        for (InputOutputPair<? extends Vector, Vector> pair : this.getCostParameters()) {
            Vector input = pair.getInput();
            Vector target = pair.getOutput();
            Vector negativeError = (Vector)function.evaluate(input);
            negativeError.minusEquals(target);
            double weight = DatasetUtil.getWeight(pair);
            if (weight != 1.0) {
                negativeError.scaleEquals(weight);
            }
            denominator += weight;
            Matrix gradient = function.computeParameterGradient(input);
            parameterDelta.accumulate(negativeError.times(gradient));
        }
        Vector negativeSum = (Vector)parameterDelta.getSum();
        return new GradientPartialSSE(negativeSum, denominator);
    }

    /**
     * Combines results from {@link #computeParameterGradientPartial} into the
     * gradient: the summed vectors divided by twice the summed weights. Partials
     * whose vector is {@code null} (a partition with no data) contribute only
     * their weight. They are not passed to the accumulator, which would otherwise
     * receive a {@code null} item.
     *
     * <p>NOTE(review): the exact derivative of the cost returned by
     * {@link #evaluateAmalgamate} would be divided by the total weight, not twice
     * it. The existing 1/(2W) scaling is kept unchanged here — confirm intent.
     *
     * @param partialResults objects produced by {@link #computeParameterGradientPartial}
     * @return the scaled gradient (left unscaled if the total weight is zero);
     *         presumably {@code null} if no partial carried a vector
     */
    @Override
    public Vector computeParameterGradientAmalgamate(Collection<Object> partialResults) {
        RingAccumulator<Vector> numerator = new RingAccumulator<Vector>();
        double denominator = 0.0;
        for (Object result : partialResults) {
            GradientPartialSSE sse = (GradientPartialSSE)result;
            Vector partialSum = sse.getFirst();
            if (partialSum != null) {
                numerator.accumulate(partialSum);
            }
            denominator += sse.getSecond();
        }
        Vector scaleSum = (Vector)numerator.getSum();
        if (scaleSum != null && denominator != 0.0) {
            scaleSum.scaleEquals(1.0 / (2.0 * denominator));
        }
        return scaleSum;
    }

    /**
     * Computes the same normalized weighted sum-squared error as
     * {@link #evaluateAmalgamate}, from precomputed target/estimate pairs.
     *
     * @param data target/estimate pairs to score
     * @return sum of w * target.euclideanDistanceSquared(estimate) divided by
     *         2 * sum of w, or {@code 0.0} when the total weight is zero
     */
    @Override
    public Double evaluatePerformance(Collection<? extends TargetEstimatePair<Vector, Vector>> data) {
        double sumSquaredError = 0.0;
        double weightSum = 0.0;
        for (TargetEstimatePair<Vector, Vector> targetEstimatePair : data) {
            Vector target = targetEstimatePair.getTarget();
            Vector estimate = targetEstimatePair.getEstimate();
            double errorSquared = target.euclideanDistanceSquared(estimate);
            double weight = DatasetUtil.getWeight(targetEstimatePair);
            weightSum += weight;
            sumSquaredError += weight * errorSquared;
        }
        double denominator = 2.0 * weightSum;
        if (denominator == 0.0) {
            return 0.0;
        }
        return sumSquaredError / denominator;
    }

    /**
     * Partial gradient result: (sum of weighted residual-times-gradient products,
     * sum of weights).
     */
    public static class GradientPartialSSE
    extends DefaultPair<Vector, Double> {
        public GradientPartialSSE(Vector numerator, Double denominator) {
            super(numerator, denominator);
        }
    }

    /**
     * Partial cost result: (weighted sum of squared errors, twice the sum of weights).
     */
    private static class EvaluatePartialSSE
    extends DefaultPair<Double, Double> {
        public EvaluatePartialSSE(Double numerator, Double denominator) {
            super(numerator, denominator);
        }
    }

    /**
     * Precomputed first-order quantities of the sum-squared error over a dataset.
     * Every quantity is divided by twice the total weight, or by 1 when the total
     * weight is zero.
     */
    public static class Cache
    extends AbstractCloneableSerializable {
        /** Unweighted sum of the per-pair parameter gradients, divided by the normalizer. */
        public final Matrix J;
        /** {@code J.transpose().times(J)}. */
        public final Matrix JtJ;
        /** Sum of weighted residual {@code f(x) - y} times parameter gradient, divided by the normalizer. */
        public final Vector Jte;
        /** Weighted sum of squared residual norms, divided by the normalizer. */
        public final double parameterCost;

        protected Cache(Matrix J, Matrix JtJ, Vector Jte, double parameterCost) {
            this.J = J;
            this.JtJ = JtJ;
            this.Jte = Jte;
            this.parameterCost = parameterCost;
        }

        /**
         * Computes the cached quantities for the given function and dataset.
         *
         * <p>NOTE(review): the residual is formed in place on the vector returned
         * by {@code objectToOptimize.evaluate}. This assumes evaluate returns a
         * fresh vector — confirm.
         *
         * @param objectToOptimize the parameterized function being optimized
         * @param data labeled (input, target) pairs; must be non-empty
         * @return the populated cache
         * @throws IllegalArgumentException if {@code data} is empty, since no
         *         gradient matrix could be formed
         */
        public static Cache compute(GradientDescendable objectToOptimize, Collection<? extends InputOutputPair<? extends Vector, Vector>> data) {
            if (data.isEmpty()) {
                throw new IllegalArgumentException("data must not be empty");
            }
            RingAccumulator<Matrix> gradientAverage = new RingAccumulator<Matrix>();
            RingAccumulator<Vector> gradientError = new RingAccumulator<Vector>();
            double weightSum = 0.0;
            double parameterCost = 0.0;
            for (InputOutputPair<? extends Vector, Vector> inputOutputPair : data) {
                Vector input = inputOutputPair.getInput();
                Vector negativeError = (Vector)objectToOptimize.evaluate(input);
                negativeError.minusEquals(inputOutputPair.getOutput());
                // Squared norm is taken before weighting so the cost uses w * ||e||^2.
                double norm2 = negativeError.norm2Squared();
                double weight = DatasetUtil.getWeight(inputOutputPair);
                if (weight != 1.0) {
                    negativeError.scaleEquals(weight);
                }
                weightSum += weight;
                parameterCost += norm2 * weight;
                Matrix gradient = objectToOptimize.computeParameterGradient(input);
                gradientAverage.accumulate(gradient);
                gradientError.accumulate(negativeError.times(gradient));
            }
            double normalizer = 2.0 * weightSum;
            if (normalizer == 0.0) {
                normalizer = 1.0;
            }
            Matrix J = (Matrix)gradientAverage.getSum();
            J.scaleEquals(1.0 / normalizer);
            Matrix JtJ = J.transpose().times(J);
            Vector Jte = (Vector)gradientError.getSum();
            Jte.scaleEquals(1.0 / normalizer);
            return new Cache(J, JtJ, Jte, parameterCost / normalizer);
        }
    }
}

