/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerBFGS;
import gov.sandia.cognition.learning.algorithm.minimization.MinimizationStoppingCriterion;
import gov.sandia.cognition.learning.algorithm.minimization.line.DirectionalVectorToDifferentiableScalarFunction;
import gov.sandia.cognition.learning.algorithm.minimization.line.LineMinimizer;
import gov.sandia.cognition.learning.algorithm.minimization.line.LineMinimizerDerivativeFree;
import gov.sandia.cognition.learning.algorithm.regression.LeastSquaresEstimator;
import gov.sandia.cognition.learning.algorithm.regression.ParameterDifferentiableCostMinimizer;
import gov.sandia.cognition.learning.data.WeightedInputOutputPair;
import gov.sandia.cognition.learning.function.cost.DifferentiableCostFunction;
import gov.sandia.cognition.learning.function.cost.SumSquaredErrorCostFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Nonlinear least-squares estimator implementing the Fletcher-Xu hybrid method.
 * <p>
 * Each iteration performs a line search along the current direction and then
 * inspects how much the sum-squared-error cost decreased relative to the
 * previous cost. This decides how the next search direction is built:
 * <ul>
 * <li>If the decrease is at least {@code reductionTest} times the previous
 * cost, the next direction {@code d} solves the damped system
 * {@code -(JtJ + dampingFactor * I) d = Jte}. Its 2-norm is capped, and the
 * damping factor is then divided by {@code dampingFactorDivisor}.</li>
 * <li>Otherwise, a BFGS update is applied to an inverse-Hessian approximation.
 * The next direction is that approximation times {@code Jte}, and the damping
 * factor is multiplied by {@code dampingFactorDivisor}.</li>
 * </ul>
 */
@PublicationReferences(references={@PublicationReference(author={"R. Fletcher"}, title="Practical Methods of Optimization, Second Edition", type=PublicationType.Book, year=1987, pages={116, 117}, notes={"Section 6.1 motivates the algorithm w.r.t. Gauss-Newton, BFGS, and Levenberg-Marquardt"}), @PublicationReference(author={"R. Fletcher", "C. Xu"}, title="Hybrid Methods for Nonlinear Least Squares", type=PublicationType.Journal, year=1987, pages={371, 389}, publication="Institute of Mathematics and its Applications Journal of Numerical Analysis")})
public class FletcherXuHybridEstimation
extends LeastSquaresEstimator {
    /** Default fraction of the previous cost that a step must remove to count as a sufficient reduction. */
    public static final double DEFAULT_REDUCTION_TEST = 0.2;
    /** Default factor by which the damping factor is divided (good step) or multiplied (poor step). */
    public static final double DEFAULT_DAMPING_DIVISOR = 2.0;
    /** Prototype line minimizer; the no-argument constructor uses a clone of it. */
    public static final LineMinimizer<?> DEFAULT_LINE_MINIMIZER = new LineMinimizerDerivativeFree();
    /** Upper bound on the 2-norm of a damped (Levenberg-Marquardt-style) search direction. */
    private static final double MAX_DIRECTION_NORM = 100.0;

    /** Line minimizer used to search along each direction. */
    private LineMinimizer<?> lineMinimizer;
    /** Fraction of the previous cost that must be removed to take the damped branch; in [0,1]. */
    private double reductionTest;
    /** Factor used to shrink or grow {@link #dampingFactor}; strictly positive. */
    private double dampingFactorDivisor;
    /** Cost information from the most recently evaluated parameters. */
    private SumSquaredErrorCostFunction.Cache lastCost;
    /** Parameter-space function restricted to the current offset and direction. */
    private DirectionalVectorToDifferentiableScalarFunction lineFunction;
    /** BFGS inverse-Hessian approximation; starts as 0.5 times the identity. */
    private Matrix hessianInverse;
    /** Current damping added to the diagonal of JtJ; starts at 1.0. */
    private double dampingFactor;

    /**
     * Creates an estimator with a clone of {@link #DEFAULT_LINE_MINIMIZER} and
     * default reduction test and damping divisor.
     */
    public FletcherXuHybridEstimation() {
        this(ObjectUtil.cloneSafe(DEFAULT_LINE_MINIMIZER));
    }

    /**
     * Creates an estimator with the given line minimizer and default
     * reduction test and damping divisor.
     *
     * @param lineMinimizer line minimizer used for each line search
     */
    public FletcherXuHybridEstimation(LineMinimizer<?> lineMinimizer) {
        this(lineMinimizer, DEFAULT_REDUCTION_TEST);
    }

    /**
     * Creates an estimator with the given line minimizer and reduction test
     * and the default damping divisor.
     *
     * @param lineMinimizer line minimizer used for each line search
     * @param reductionTest fraction of the previous cost that must be removed, in [0,1]
     * @throws IllegalArgumentException if {@code reductionTest} is outside [0,1] or NaN
     */
    public FletcherXuHybridEstimation(LineMinimizer<?> lineMinimizer, double reductionTest) {
        this(lineMinimizer, reductionTest, DEFAULT_DAMPING_DIVISOR);
    }

    /**
     * Creates an estimator with the given settings, 1000 maximum iterations,
     * and a tolerance of 1.0E-7.
     *
     * @param lineMinimizer line minimizer used for each line search
     * @param reductionTest fraction of the previous cost that must be removed, in [0,1]
     * @param dampingFactorDivisor factor used to shrink/grow the damping factor; must be &gt; 0
     * @throws IllegalArgumentException if {@code reductionTest} or
     *         {@code dampingFactorDivisor} is out of range
     */
    public FletcherXuHybridEstimation(LineMinimizer<?> lineMinimizer, double reductionTest, double dampingFactorDivisor) {
        this(lineMinimizer, reductionTest, dampingFactorDivisor, 1000, 1.0E-7);
    }

    /**
     * Creates an estimator with fully specified settings.
     *
     * @param lineMinimizer line minimizer used for each line search
     * @param reductionTest fraction of the previous cost that must be removed, in [0,1]
     * @param dampingFactorDivisor factor used to shrink/grow the damping factor; must be &gt; 0
     * @param maxIterations maximum number of iterations, passed to the superclass
     * @param tolerance convergence tolerance, passed to the superclass
     * @throws IllegalArgumentException if {@code reductionTest} or
     *         {@code dampingFactorDivisor} is out of range
     */
    public FletcherXuHybridEstimation(LineMinimizer<?> lineMinimizer, double reductionTest, double dampingFactorDivisor, int maxIterations, double tolerance) {
        super(maxIterations, tolerance);
        this.setLineMinimizer(lineMinimizer);
        this.setReductionTest(reductionTest);
        this.setDampingFactorDivisor(dampingFactorDivisor);
    }

    /**
     * Prepares a run.
     * <p>
     * This method:
     * <ul>
     * <li>clones the object to optimize into the result;</li>
     * <li>hands the data to the cost function;</li>
     * <li>resets the damping factor to 1.0;</li>
     * <li>evaluates the starting cost;</li>
     * <li>builds the line function at the current parameters, with the
     * initial direction set to the starting {@code Jte};</li>
     * <li>sets the inverse-Hessian approximation to 0.5 times the identity.</li>
     * </ul>
     *
     * @return always {@code true}
     */
    @Override
    protected boolean initializeAlgorithm() {
        this.setResult(((GradientDescendable)this.getObjectToOptimize()).clone());
        ((SumSquaredErrorCostFunction)this.getCostFunction()).setCostParameters((Collection)this.getData());
        this.dampingFactor = 1.0;
        this.lastCost = SumSquaredErrorCostFunction.Cache.compute((GradientDescendable)this.getResult(), (Collection)this.getData());
        ParameterDifferentiableCostMinimizer.ParameterCostEvaluatorDerivativeBased f = new ParameterDifferentiableCostMinimizer.ParameterCostEvaluatorDerivativeBased((GradientDescendable)this.getResult(), (DifferentiableCostFunction)this.getCostFunction());
        Vector parameters = ((GradientDescendable)this.getResult()).convertToVector();
        int M = parameters.getDimensionality();
        this.lineFunction = new DirectionalVectorToDifferentiableScalarFunction(f, parameters, this.lastCost.Jte);
        this.hessianInverse = (Matrix)MatrixFactory.getDefault().createIdentity(M, M).scale(0.5);
        return true;
    }

    /**
     * Performs one iteration: a line search along the current direction, then
     * selection of the next direction via the Fletcher-Xu reduction test.
     *
     * @return {@code false} once the convergence criterion is met, {@code true} otherwise
     */
    @Override
    protected boolean step() {
        WeightedInputOutputPair<Vector, Double> result = this.getLineMinimizer().minimizeAlongDirection(
            this.lineFunction, this.lastCost.parameterCost, this.lastCost.Jte);
        Vector lastParameters = this.lineFunction.getVectorOffset();
        // Step actually taken by the line search: new point minus previous offset.
        Vector delta = ((Vector)result.getInput()).minus(lastParameters);
        this.lineFunction.setVectorOffset((Vector)result.getInput());
        ((GradientDescendable)this.getResult()).convertFromVector((Vector)result.getInput());
        SumSquaredErrorCostFunction.Cache cache = SumSquaredErrorCostFunction.Cache.compute(
            (GradientDescendable)this.getResult(), (Collection)this.getData());
        this.setResultCost(cache.parameterCost);
        // Fletcher-Xu test: did this step remove at least reductionTest of the previous cost?
        if (this.getReductionTest() * this.lastCost.parameterCost <= this.lastCost.parameterCost - cache.parameterCost) {
            // Damped branch: build -(JtJ + dampingFactor * I) and solve it against Jte.
            Matrix JtJpI = (Matrix)cache.JtJ.scale(-1.0);
            Vector Jte = cache.Jte;
            int M = JtJpI.getNumRows();
            for (int i = 0; i < M; ++i) {
                double v = JtJpI.getElement(i, i);
                JtJpI.setElement(i, i, v - this.dampingFactor);
            }
            Vector direction = JtJpI.solve(Jte);
            // Good progress: trust the Gauss-Newton model more on the next damped step.
            this.dampingFactor /= this.getDampingFactorDivisor();
            // Cap the direction length so a near-singular JtJ cannot produce a huge step.
            double directionNorm = direction.norm2();
            if (directionNorm > MAX_DIRECTION_NORM) {
                direction.scaleEquals(MAX_DIRECTION_NORM / directionNorm);
            }
            this.lineFunction.setDirection(direction);
        } else {
            // Insufficient reduction: update the BFGS inverse-Hessian approximation instead.
            Vector gamma = cache.JtJ.times((Vector)this.lineFunction.getDirection().scale(2.0))
                .plus(cache.Jte.minus(this.lastCost.Jte));
            FunctionMinimizerBFGS.BFGSupdateRule(this.hessianInverse, delta, gamma, this.getTolerance());
            Vector direction = this.hessianInverse.times(cache.Jte);
            this.lineFunction.setDirection(direction);
            this.dampingFactor *= this.getDampingFactorDivisor();
        }
        this.lastCost = cache;
        return !MinimizationStoppingCriterion.convergence((Vector)result.getInput(), (Double)result.getOutput(), delta, cache.Jte, this.getTolerance());
    }

    /**
     * No cleanup is needed; every per-run field is reassigned by
     * {@link #initializeAlgorithm()}.
     */
    @Override
    protected void cleanupAlgorithm() {
    }

    /**
     * Gets the fraction of the previous cost that a step must remove for the
     * next direction to come from the damped Gauss-Newton system.
     *
     * @return the reduction test, in [0,1]
     */
    public double getReductionTest() {
        return this.reductionTest;
    }

    /**
     * Sets the fraction of the previous cost that a step must remove for the
     * next direction to come from the damped Gauss-Newton system.
     *
     * @param reductionTest value in [0,1]
     * @throws IllegalArgumentException if {@code reductionTest} is outside [0,1] or is NaN
     */
    public void setReductionTest(double reductionTest) {
        // Written as a negated range check so that NaN (which fails every comparison) is rejected.
        if (!(reductionTest >= 0.0 && reductionTest <= 1.0)) {
            throw new IllegalArgumentException("reductionTest must be [0,1]");
        }
        this.reductionTest = reductionTest;
    }

    /**
     * Gets the factor by which the damping factor is divided after a
     * sufficient reduction and multiplied otherwise.
     *
     * @return the damping factor divisor, strictly positive
     */
    public double getDampingFactorDivisor() {
        return this.dampingFactorDivisor;
    }

    /**
     * Sets the factor by which the damping factor is divided after a
     * sufficient reduction and multiplied otherwise.
     *
     * @param dampingFactorDivisor strictly positive factor
     * @throws IllegalArgumentException if {@code dampingFactorDivisor} is not
     *         strictly positive (including NaN); zero would divide the damping
     *         factor by zero and negative values would flip its sign
     */
    public void setDampingFactorDivisor(double dampingFactorDivisor) {
        if (!(dampingFactorDivisor > 0.0)) {
            throw new IllegalArgumentException("dampingFactorDivisor must be > 0");
        }
        this.dampingFactorDivisor = dampingFactorDivisor;
    }

    /**
     * Gets the line minimizer used for each line search.
     *
     * @return the line minimizer
     */
    public LineMinimizer<?> getLineMinimizer() {
        return this.lineMinimizer;
    }

    /**
     * Sets the line minimizer used for each line search.
     *
     * @param lineMinimizer the line minimizer
     */
    public void setLineMinimizer(LineMinimizer<?> lineMinimizer) {
        this.lineMinimizer = lineMinimizer;
    }
}

