/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.algorithm.minimization.MinimizationStoppingCriterion;
import gov.sandia.cognition.learning.algorithm.minimization.line.DirectionalVectorToDifferentiableScalarFunction;
import gov.sandia.cognition.learning.algorithm.minimization.line.LineMinimizer;
import gov.sandia.cognition.learning.algorithm.minimization.line.LineMinimizerDerivativeFree;
import gov.sandia.cognition.learning.algorithm.regression.LeastSquaresEstimator;
import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.cost.DifferentiableCostFunction;
import gov.sandia.cognition.learning.function.cost.SumSquaredErrorCostFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

@PublicationReference(author={"Wikipedia"}, title="Gauss-Newton algorithm", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm")
public class GaussNewtonAlgorithm
extends LeastSquaresEstimator {
    /**
     * Line minimizer used when none is supplied. The no-argument constructor
     * uses a clone of this instance (via {@link ObjectUtil#cloneSafe}) rather
     * than the shared object itself.
     */
    public static final LineMinimizer<?> DEFAULT_LINE_MINIMIZER = new LineMinimizerDerivativeFree();

    /**
     * Upper bound on the {@code norm2()} of the Gauss-Newton search direction.
     * Longer directions are rescaled to this length before the line search.
     */
    public static final double STEP_MAX = 100.0;

    /** Line minimizer that picks the step length along each search direction. */
    private LineMinimizer<?> lineMinimizer;

    /**
     * Restriction of the cost function to a line in parameter space. Its
     * vector offset tracks the current parameters between iterations.
     * Created in {@link #initializeAlgorithm()}.
     */
    private DirectionalVectorToDifferentiableScalarFunction lineFunction;

    /**
     * Creates a Gauss-Newton minimizer.
     * It uses a clone of {@link #DEFAULT_LINE_MINIMIZER}, 2000 maximum
     * iterations, and a tolerance of 1.0E-9.
     */
    public GaussNewtonAlgorithm() {
        this(ObjectUtil.cloneSafe(DEFAULT_LINE_MINIMIZER));
    }

    /**
     * Creates a Gauss-Newton minimizer with the given line minimizer.
     * It uses 2000 maximum iterations and a tolerance of 1.0E-9.
     *
     * @param lineMinimizer line minimizer used to choose the step length
     */
    public GaussNewtonAlgorithm(LineMinimizer<?> lineMinimizer) {
        this(lineMinimizer, 2000, 1.0E-9);
    }

    /**
     * Creates a Gauss-Newton minimizer.
     *
     * @param lineMinimizer line minimizer used to choose the step length
     * @param maxIterations iteration limit, passed to the base class
     * @param tolerance convergence tolerance, passed to the base class and
     *        later used by {@link MinimizationStoppingCriterion#convergence}
     */
    public GaussNewtonAlgorithm(LineMinimizer<?> lineMinimizer, int maxIterations, double tolerance) {
        super(maxIterations, tolerance);
        this.setLineMinimizer(lineMinimizer);
    }

    /**
     * Prepares for minimization. This method:
     * <ul>
     * <li>clones the object to optimize into the result;</li>
     * <li>hands the data to the cost function;</li>
     * <li>builds the line function anchored at the initial parameters.</li>
     * </ul>
     *
     * @return always {@code true}
     */
    @Override
    protected boolean initializeAlgorithm() {
        // Work on a copy so the caller's object is not modified in place.
        this.setResult(((GradientDescendable)this.getObjectToOptimize()).clone());
        ((SumSquaredErrorCostFunction)this.getCostFunction()).setCostParameters((Collection)this.getData());

        GradientDescendable current = (GradientDescendable)this.getResult();
        Vector parameters = current.convertToVector();
        SumSquaredErrorCostFunction.Cache cost = SumSquaredErrorCostFunction.Cache.compute(current, (Collection)this.getData());
        ParameterDifferentiableCostMinimizer.ParameterCostEvaluatorDerivativeBased f =
            new ParameterDifferentiableCostMinimizer.ParameterCostEvaluatorDerivativeBased(
                current, (DifferentiableCostFunction)this.getCostFunction());

        // Jte is only a placeholder direction here.
        // step() replaces it before every line search.
        this.lineFunction = new DirectionalVectorToDifferentiableScalarFunction(f, parameters, cost.Jte);
        return true;
    }

    /**
     * Performs one Gauss-Newton iteration. It solves
     * {@code JtJ * direction = Jte}, caps the direction length at
     * {@link #STEP_MAX}, line-searches along it, and stores the new
     * parameters and cost in the result.
     *
     * @return {@code true} to keep iterating, or {@code false} once
     *         {@link MinimizationStoppingCriterion#convergence} reports
     *         convergence
     */
    @Override
    protected boolean step() {
        SumSquaredErrorCostFunction.Cache cost = SumSquaredErrorCostFunction.Cache.compute((GradientDescendable)this.getResult(), (Collection)this.getData());
        Vector lastParameters = this.lineFunction.getVectorOffset();

        Vector direction = cost.JtJ.solve(cost.Jte);

        // Cap very long steps, e.g. when JtJ is nearly singular, so the
        // line search starts from a bounded direction.
        double directionNorm = direction.norm2();
        if (directionNorm > STEP_MAX) {
            direction.scaleEquals(STEP_MAX / directionNorm);
        }
        this.lineFunction.setDirection(direction);

        // Pass the known cost and Jte so the line minimizer need not
        // re-evaluate the starting point.
        WeightedInputOutputPair<Vector, Double> result =
            this.getLineMinimizer().minimizeAlongDirection(this.lineFunction, cost.parameterCost, cost.Jte);
        Vector newParameters = result.getInput();
        Double newCost = result.getOutput();

        this.lineFunction.setVectorOffset(newParameters);
        this.setResultCost(newCost);
        Vector delta = newParameters.minus(lastParameters);
        ((GradientDescendable)this.getResult()).convertFromVector(newParameters);

        return !MinimizationStoppingCriterion.convergence(newParameters, newCost, cost.Jte, delta, this.getTolerance());
    }

    /**
     * Called after minimization ends.
     * This class holds nothing that needs releasing.
     */
    @Override
    protected void cleanupAlgorithm() {
    }

    /**
     * Gets the line minimizer used to pick step lengths.
     *
     * @return the line minimizer
     */
    public LineMinimizer<?> getLineMinimizer() {
        return this.lineMinimizer;
    }

    /**
     * Sets the line minimizer used to pick step lengths.
     *
     * @param lineMinimizer the line minimizer; it is stored as-is, not copied
     */
    public void setLineMinimizer(LineMinimizer<?> lineMinimizer) {
        this.lineMinimizer = lineMinimizer;
    }
}

