/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.minimization;

import gov.sandia.cognition.learning.algorithm.minimization.AbstractAnytimeFunctionMinimizer;
import gov.sandia.cognition.learning.algorithm.minimization.MinimizationStoppingCriterion;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.math.DifferentiableEvaluator;
import gov.sandia.cognition.math.matrix.Vector;

/**
 * Minimizes a differentiable function of a {@code Vector} by gradient descent
 * with an optional momentum term.
 * <p>
 * Each iteration computes the update
 * {@code delta = -learningRate * gradient(x) + momentum * previousDelta}
 * (the momentum term is omitted on the first iteration and whenever the
 * momentum is zero), adds it in place to the current estimate, and stops once
 * {@link MinimizationStoppingCriterion#convergence} reports convergence for
 * the configured tolerance.
 */
public class FunctionMinimizerGradientDescent
extends AbstractAnytimeFunctionMinimizer<Vector, Double, DifferentiableEvaluator<? super Vector, Double, Vector>> {

    /** Default step size applied to the negative gradient. */
    public static final double DEFAULT_LEARNING_RATE = 0.1;

    /** Default fraction of the previous update carried into the next one. */
    public static final double DEFAULT_MOMENTUM = 0.8;

    /** Default convergence tolerance passed to the stopping criterion. */
    public static final double DEFAULT_TOLERANCE = 1.0E-7;

    /** Default cap on the number of iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000000;

    /** Step size multiplying the negative gradient; always in (0, 1]. */
    private double learningRate;

    /** Weight of the previous update in the current one; always in [0, 1). */
    private double momentum;

    /** Update applied on the previous iteration; null before the first step. */
    private Vector previousDelta;

    /**
     * Creates a minimizer using {@link #DEFAULT_LEARNING_RATE} and
     * {@link #DEFAULT_MOMENTUM}.
     */
    public FunctionMinimizerGradientDescent() {
        this(DEFAULT_LEARNING_RATE, DEFAULT_MOMENTUM);
    }

    /**
     * Creates a minimizer with the given learning rate and momentum, no
     * initial guess, {@link #DEFAULT_TOLERANCE} and
     * {@link #DEFAULT_MAX_ITERATIONS}.
     *
     * @param learningRate step size, must be in (0, 1]
     * @param momentum momentum weight, must be in [0, 1)
     * @throws IllegalArgumentException if either value is out of range or NaN
     */
    public FunctionMinimizerGradientDescent(double learningRate, double momentum) {
        this(learningRate, momentum, null, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a fully configured minimizer.
     *
     * @param learningRate step size, must be in (0, 1]
     * @param momentum momentum weight, must be in [0, 1)
     * @param initialGuess starting point; may be null here, but
     *        {@link #initializeAlgorithm()} clones it, so it must be non-null
     *        by the time minimization starts
     * @param tolerance convergence tolerance for the stopping criterion
     * @param maxIterations maximum number of iterations
     * @throws IllegalArgumentException if the learning rate or momentum is out
     *         of range or NaN
     */
    public FunctionMinimizerGradientDescent(double learningRate, double momentum, Vector initialGuess, double tolerance, int maxIterations) {
        super(initialGuess, tolerance, maxIterations);
        this.setLearningRate(learningRate);
        this.setMomentum(momentum);
    }

    /**
     * Resets the momentum history and starts from a copy of the initial guess,
     * so the caller's vector is never mutated by the in-place updates.
     *
     * @return always true
     */
    @Override
    protected boolean initializeAlgorithm() {
        this.previousDelta = null;
        this.result = new DefaultInputOutputPair<Vector, Double>(this.initialGuess.clone(), null);
        return true;
    }

    /**
     * Performs one gradient-descent (with momentum) update of the current
     * estimate, in place.
     *
     * @return false once the stopping criterion reports convergence, true
     *         otherwise (presumably telling the anytime loop to keep going)
     */
    @Override
    protected boolean step() {
        Vector xhat = this.result.getInput();
        Vector gradient = this.data.differentiate(xhat);
        // scale() yields a fresh vector, so mutating delta below never touches gradient.
        Vector delta = gradient.scale(-this.learningRate);
        if (this.previousDelta != null && this.momentum != 0.0) {
            delta.plusEquals(this.previousDelta.scale(this.momentum));
        }
        this.previousDelta = delta;
        xhat.plusEquals(delta);
        return !MinimizationStoppingCriterion.convergence(xhat, null, gradient, delta, this.getTolerance());
    }

    /**
     * Evaluates the function at the final estimate and stores the
     * (input, output) pair as the result.
     */
    @Override
    protected void cleanupAlgorithm() {
        double yhat = this.data.evaluate(this.result.getInput());
        this.result = DefaultInputOutputPair.create(this.result.getInput(), yhat);
    }

    /**
     * Gets the step size applied to the negative gradient.
     *
     * @return the learning rate, in (0, 1]
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the step size applied to the negative gradient.
     *
     * @param learningRate new learning rate, must be in (0, 1]
     * @throws IllegalArgumentException if the value is outside (0, 1] or NaN
     */
    public void setLearningRate(double learningRate) {
        // Negated in-range test: every comparison with NaN is false, so the
        // previous "out of range" disjunction silently accepted NaN.
        if (!(learningRate > 0.0 && learningRate <= 1.0)) {
            throw new IllegalArgumentException("Learning rate " + learningRate + " must be (0,1].");
        }
        this.learningRate = learningRate;
    }

    /**
     * Gets the weight of the previous update in the current one.
     *
     * @return the momentum, in [0, 1)
     */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Sets the weight of the previous update in the current one. Zero
     * disables momentum entirely.
     *
     * @param momentum new momentum, must be in [0, 1)
     * @throws IllegalArgumentException if the value is outside [0, 1) or NaN
     */
    public void setMomentum(double momentum) {
        // Negated in-range test so that NaN is rejected as well.
        if (!(momentum >= 0.0 && momentum < 1.0)) {
            throw new IllegalArgumentException("momentum must be 0.0 <= momentum < 1.0");
        }
        this.momentum = momentum;
    }
}

