/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.vector.ElementWiseDifferentiableVectorFunction;
import gov.sandia.cognition.learning.function.vector.MatrixMultiplyVectorFunction;
import gov.sandia.cognition.learning.function.vector.SquashedMatrixMultiplyVectorFunction;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.DifferentiableVectorFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;

public class DifferentiableSquashedMatrixMultiplyVectorFunction
extends SquashedMatrixMultiplyVectorFunction
implements GradientDescendable,
DifferentiableVectorFunction {
    public DifferentiableSquashedMatrixMultiplyVectorFunction() {
        this(1, 1, new AtanFunction());
    }

    public DifferentiableSquashedMatrixMultiplyVectorFunction(int numInputs, int numOutputs, DifferentiableUnivariateScalarFunction scalarFunction) {
        this(new MatrixMultiplyVectorFunction(numInputs, numOutputs), new ElementWiseDifferentiableVectorFunction(scalarFunction));
    }

    public DifferentiableSquashedMatrixMultiplyVectorFunction(MatrixMultiplyVectorFunction matrixMultiply, DifferentiableVectorFunction squashingFunction) {
        super(matrixMultiply, squashingFunction);
    }

    public DifferentiableSquashedMatrixMultiplyVectorFunction(MatrixMultiplyVectorFunction matrixMultiply, DifferentiableUnivariateScalarFunction scalarSquashingFunction) {
        this(matrixMultiply, new ElementWiseDifferentiableVectorFunction(scalarSquashingFunction));
    }

    public DifferentiableSquashedMatrixMultiplyVectorFunction(DifferentiableSquashedMatrixMultiplyVectorFunction other) {
        super(other);
    }

    @Override
    public DifferentiableVectorFunction getSquashingFunction() {
        return (DifferentiableVectorFunction)super.getSquashingFunction();
    }

    @Override
    public Matrix computeParameterGradient(Vector input) {
        Matrix gradient = this.getMatrixMultiply().computeParameterGradient(input);
        Vector y = this.getMatrixMultiply().evaluate(input);
        Matrix derivative = this.getSquashingFunction().differentiate(y);
        return derivative.times(gradient);
    }

    @Override
    public DifferentiableSquashedMatrixMultiplyVectorFunction clone() {
        return (DifferentiableSquashedMatrixMultiplyVectorFunction)super.clone();
    }

    @Override
    public Matrix differentiate(Vector input) {
        Matrix dudx = this.getMatrixMultiply().differentiate(input);
        Vector u = this.getMatrixMultiply().evaluate(input);
        Matrix dydu = this.getSquashingFunction().differentiate(u);
        return dydu.times(dudx);
    }
}

