/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.CodeReviewResponse;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableDifferentiableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.Collection;

/**
 * A linear vector function {@code y = A * x}, where {@code A} is the internal
 * {@code M x N} matrix, {@code x} is an {@code N}-dimensional input, and
 * {@code y} is the {@code M}-dimensional output.
 * <p>
 * The parameters of this function, as exposed through
 * {@link #convertToVector()} and {@link #convertFromVector(Vector)}, are the
 * entries of the internal matrix, in whatever order
 * {@link Matrix#convertToVector()} uses.
 * {@link #computeParameterGradient(Matrix, Vector)} assumes that order is
 * column-major (NOTE(review): confirm against the Matrix implementations in
 * use).
 */
@CodeReview(reviewer={"Justin Basilico"}, date="2006-10-06", changesNeeded=true, comments={"Can you just add a comment for why the differentiation code is correct?", "Otherwise, class looks fine."}, response={@CodeReviewResponse(respondent="Kevin R. Dixon", date="2006-10-06", moreChangesNeeded=false, comments={"Added in-code comment describing the derivation of the differentiation formulae."})})
public class MatrixMultiplyVectorFunction
extends AbstractCloneableSerializable
implements VectorizableDifferentiableVectorFunction,
VectorInputEvaluator<Vector, Vector>,
VectorOutputEvaluator<Vector, Vector>,
GradientDescendable {
    /** The matrix {@code A} that premultiplies every input vector. */
    private Matrix internalMatrix;

    /**
     * Creates a function with one input and one output, whose internal matrix
     * comes from {@code MatrixFactory.getDefault().createIdentity(1, 1)}.
     */
    public MatrixMultiplyVectorFunction() {
        this(1, 1);
    }

    /**
     * Creates a function whose internal matrix is
     * {@code MatrixFactory.getDefault().createIdentity(numOutputs, numInputs)}.
     * NOTE(review): the exact contents for non-square sizes are defined by
     * MatrixFactory; presumably ones on the main diagonal.
     *
     * @param numInputs number of columns of the internal matrix (input size)
     * @param numOutputs number of rows of the internal matrix (output size)
     */
    public MatrixMultiplyVectorFunction(int numInputs, int numOutputs) {
        this(MatrixFactory.getDefault().createIdentity(numOutputs, numInputs));
    }

    /**
     * Creates a function that wraps the given matrix directly; no copy is
     * made, so later changes to {@code internalMatrix} are visible here.
     *
     * @param internalMatrix the {@code M x N} matrix {@code A}
     */
    public MatrixMultiplyVectorFunction(Matrix internalMatrix) {
        this.setInternalMatrix(internalMatrix);
    }

    /**
     * Copy constructor; the new function gets its own clone of
     * {@code other}'s internal matrix.
     *
     * @param other the function to copy
     */
    public MatrixMultiplyVectorFunction(MatrixMultiplyVectorFunction other) {
        this(other.getInternalMatrix().clone());
    }

    /**
     * Returns a copy of this function with a cloned internal matrix, so the
     * copy's parameters can be changed independently of this one.
     *
     * @return a deep-enough copy of this function
     */
    @Override
    public MatrixMultiplyVectorFunction clone() {
        MatrixMultiplyVectorFunction clone = (MatrixMultiplyVectorFunction)super.clone();
        clone.setInternalMatrix(this.getInternalMatrix().clone());
        return clone;
    }

    /**
     * Returns the internal matrix itself (not a copy); mutating it changes
     * this function.
     *
     * @return the internal matrix {@code A}
     */
    public Matrix getInternalMatrix() {
        return this.internalMatrix;
    }

    /**
     * Replaces the internal matrix with the given reference (no copy).
     *
     * @param internalMatrix the new matrix {@code A}
     */
    protected void setInternalMatrix(Matrix internalMatrix) {
        this.internalMatrix = internalMatrix;
    }

    /**
     * Returns the function parameters, i.e. the internal matrix flattened by
     * {@link Matrix#convertToVector()}.
     *
     * @return the parameter vector
     */
    @Override
    public Vector convertToVector() {
        return this.internalMatrix.convertToVector();
    }

    /**
     * Loads the function parameters into the internal matrix via
     * {@link Matrix#convertFromVector(Vector)}.
     *
     * @param parameters the parameter vector
     */
    @Override
    public void convertFromVector(Vector parameters) {
        this.internalMatrix.convertFromVector(parameters);
    }

    /**
     * Evaluates {@code A * input}.
     *
     * @param input the {@code N}-dimensional input
     * @return the {@code M}-dimensional output
     */
    @Override
    public Vector evaluate(Vector input) {
        return this.internalMatrix.times(input);
    }

    /**
     * Returns the Jacobian of {@code y = A * x} with respect to {@code x},
     * which is {@code A} itself for every input. The internal matrix is
     * returned directly (not a copy), so callers must not modify it.
     *
     * @param input ignored, since the Jacobian is constant
     * @return the internal matrix {@code A}
     */
    @Override
    public Matrix differentiate(Vector input) {
        return this.getInternalMatrix();
    }

    /**
     * Computes the Jacobian of the output with respect to this function's
     * parameters at the given input.
     *
     * @param input the {@code N}-dimensional input
     * @return an {@code M x (M*N)} gradient matrix
     */
    @Override
    public Matrix computeParameterGradient(Vector input) {
        return MatrixMultiplyVectorFunction.computeParameterGradient(this.internalMatrix, input);
    }

    /**
     * Computes the Jacobian of {@code y = matrix * input} with respect to the
     * entries of {@code matrix}, flattened column-major.
     * <p>
     * Derivation: {@code y_i = sum_j A(i,j) * x_j}, so
     * {@code dy_i / dA(k,j) = (i == k) ? x_j : 0}. With column-major
     * flattening, entry {@code A(k,j)} sits at parameter index
     * {@code j*M + k}. Therefore row {@code i} of the gradient holds
     * {@code x_j} at column {@code j*M + i} for every {@code j}, and zero
     * elsewhere.
     *
     * @param matrix the {@code M x N} matrix {@code A}
     * @param input the {@code N}-dimensional input {@code x}
     * @return an {@code M x (M*N)} gradient matrix
     * @throws IllegalArgumentException if {@code input} does not have
     *     {@code N} elements
     */
    public static Matrix computeParameterGradient(Matrix matrix, Vector input) {
        int M = matrix.getNumRows();
        int N = matrix.getNumColumns();
        // Reject mismatched inputs up front instead of silently ignoring
        // extra elements (evaluate() would also fail on such an input).
        if (input.getDimensionality() != N) {
            throw new IllegalArgumentException(
                "Input dimensionality " + input.getDimensionality()
                + " does not match the matrix column count " + N);
        }
        Matrix gradient = MatrixFactory.getDefault().createMatrix(M, M * N);
        for (int j = 0; j < N; ++j) {
            double inputValue = input.getElement(j);
            for (int i = 0; i < M; ++i) {
                // dy_i / dA(i,j) = x_j, located at column-major index j*M + i.
                gradient.setElement(i, j * M + i, inputValue);
            }
        }
        return gradient;
    }

    /**
     * Returns the string form of the internal matrix.
     *
     * @return {@code getInternalMatrix().toString()}
     */
    @Override
    public String toString() {
        return this.getInternalMatrix().toString();
    }

    /**
     * Returns the input size, i.e. the number of columns of the internal
     * matrix.
     *
     * @return the input dimensionality {@code N}
     */
    @Override
    public int getInputDimensionality() {
        return this.getInternalMatrix().getNumColumns();
    }

    /**
     * Returns the output size, i.e. the number of rows of the internal
     * matrix.
     *
     * @return the output dimensionality {@code M}
     */
    @Override
    public int getOutputDimensionality() {
        return this.getInternalMatrix().getNumRows();
    }

    /**
     * Learns the matrix {@code A} of a {@link MatrixMultiplyVectorFunction}
     * in closed form from input/output pairs, by solving the linear system
     * {@code X^T * A^T = Y^T} with {@link Matrix#solve(Matrix)}.
     * NOTE(review): this is a least-squares fit only insofar as
     * {@code Matrix.solve} performs least squares for non-square systems;
     * confirm for the Matrix implementation in use.
     */
    public static class ClosedFormSolver
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner<Vector, Vector, MatrixMultiplyVectorFunction> {
        /**
         * Learns {@code A} from (optionally weighted) input/output pairs.
         * <p>
         * Each pair with weight {@code w} has its input and output scaled by
         * {@code sqrt(w)}. The unweighted solve then minimizes
         * {@code sum_n w_n * ||y_n - A x_n||^2}, the standard weighted
         * least-squares objective, so a weight of 2 is equivalent to
         * duplicating the sample.
         *
         * @param data the training pairs; all inputs and all outputs are
         *     expected to share the dimensionality of the first pair
         * @return the learned function
         * @throws IllegalArgumentException if {@code data} is empty or any
         *     weight is negative or NaN
         */
        @Override
        public MatrixMultiplyVectorFunction learn(Collection<? extends InputOutputPair<? extends Vector, Vector>> data) {
            if (data.isEmpty()) {
                throw new IllegalArgumentException(
                    "Cannot learn a MatrixMultiplyVectorFunction from empty data");
            }
            InputOutputPair<? extends Vector, Vector> first = data.iterator().next();
            int M = first.getOutput().getDimensionality();
            int N = first.getInput().getDimensionality();
            int num = data.size();
            // Samples are stored as columns: X is N x num, Y is M x num.
            Matrix Y = MatrixFactory.getDefault().createMatrix(M, num);
            Matrix X = MatrixFactory.getDefault().createMatrix(N, num);
            int n = 0;
            for (InputOutputPair<? extends Vector, Vector> inputOutputPair : data) {
                Vector x = inputOutputPair.getInput();
                Vector y = inputOutputPair.getOutput();
                double weight = DatasetUtil.getWeight(inputOutputPair);
                // Written as !(w >= 0) so that NaN is rejected as well.
                if (!(weight >= 0.0)) {
                    throw new IllegalArgumentException(
                        "Weights must be non-negative, got " + weight);
                }
                if (weight != 1.0) {
                    // Scale by sqrt(w) so the squared residual carries weight w
                    // (scaling by w itself would weight the residual by w^2).
                    double sqrtWeight = Math.sqrt(weight);
                    x = (Vector)x.scale(sqrtWeight);
                    y = (Vector)y.scale(sqrtWeight);
                }
                X.setColumn(n, x);
                Y.setColumn(n, y);
                ++n;
            }
            return ClosedFormSolver.learn(X, Y);
        }

        /**
         * Solves for {@code A} such that {@code A * X} approximates {@code Y},
         * by solving {@code X^T * A^T = Y^T}.
         *
         * @param X an {@code N x num} matrix whose columns are the inputs
         * @param Y an {@code M x num} matrix whose columns are the outputs
         * @return a function wrapping the {@code M x N} solution {@code A}
         */
        public static MatrixMultiplyVectorFunction learn(Matrix X, Matrix Y) {
            Matrix Xt = X.transpose();
            Matrix Yt = Y.transpose();
            Matrix A = Xt.solve(Yt).transpose();
            return new MatrixMultiplyVectorFunction(A);
        }
    }
}

