/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.pca;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.function.kernel.DefaultKernelContainer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.math.matrix.DiagonalMatrix;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrix;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
import gov.sandia.cognition.math.matrix.mtj.decomposition.EigenDecompositionRightMTJ;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

@PublicationReferences(references={@PublicationReference(author={"Bernard Scholkopf", "Alexander Smola", "Klaus-Robert Muller"}, title="Nonlinear Component Analysis as a Kernel Eigenvalue Problem", year=1996, type=PublicationType.TechnicalReport, url="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.29.1366"), @PublicationReference(author={"John Shawe-Taylor", "Nello Cristianini"}, title="Kernel Methods for Pattern Analysis", year=2004, type=PublicationType.Book, pages={150, 153})})
/**
 * Batch learner for kernel principal components analysis (kernel PCA).
 * <p>
 * Learning works in four steps:
 * <ol>
 * <li>Build the Gram (kernel) matrix {@code K} over the training data.</li>
 * <li>Optionally center {@code K} in feature space.</li>
 * <li>Eigendecompose the (possibly centered) matrix.</li>
 * <li>Keep the first {@code min(componentCount, dataSize)} eigenvectors as
 * projection components. Each is scaled by {@code 1 / sqrt(|eigenvalue|)}.</li>
 * </ol>
 * Components are taken in the order the eigendecomposition returns them.
 * NOTE(review): this presumes {@link EigenDecompositionRightMTJ} orders them by
 * decreasing eigenvalue; confirm.
 *
 * @param <DataType> the type of data the kernel is evaluated on
 */
public class KernelPrincipalComponentsAnalysis<DataType>
extends DefaultKernelContainer<DataType>
implements BatchLearner<Collection<? extends DataType>, Function<DataType>> {
    /** Default number of components to keep. */
    public static final int DEFAULT_COMPONENT_COUNT = 10;
    /** By default the kernel matrix is centered in feature space. */
    public static final boolean DEFAULT_CENTER_DATA = true;
    /** Maximum number of components to learn; must be positive. */
    protected int componentCount;
    /** Whether to center the data in feature space before decomposition. */
    protected boolean centerData;

    /**
     * Creates a learner with no kernel and the default settings.
     * The kernel must be set before calling {@link #learn}.
     */
    public KernelPrincipalComponentsAnalysis() {
        this(null, DEFAULT_COMPONENT_COUNT);
    }

    /**
     * Creates a learner that centers the data.
     *
     * @param kernel the kernel to use
     * @param componentCount the maximum number of components; must be positive
     */
    public KernelPrincipalComponentsAnalysis(Kernel<? super DataType> kernel, int componentCount) {
        this(kernel, componentCount, DEFAULT_CENTER_DATA);
    }

    /**
     * Creates a fully specified learner.
     *
     * @param kernel the kernel to use
     * @param componentCount the maximum number of components; must be positive
     * @param centerData whether to center the data in feature space
     */
    public KernelPrincipalComponentsAnalysis(Kernel<? super DataType> kernel, int componentCount, boolean centerData) {
        super(kernel);
        this.setComponentCount(componentCount);
        this.setCenterData(centerData);
    }

    /**
     * Learns a kernel PCA projection from the given data.
     *
     * @param data the training data; must not be empty
     * @return a function projecting inputs onto at most
     *     {@code componentCount} learned components
     * @throws IllegalArgumentException if {@code data} is empty
     */
    @Override
    public Function<DataType> learn(Collection<? extends DataType> data) {
        int dataSize = data.size();
        if (dataSize <= 0) {
            // An empty Gram matrix cannot be centered (division by N) or decomposed.
            throw new IllegalArgumentException("data must not be empty");
        }
        ArrayList<DataType> dataList = CollectionUtil.asArrayList(data);

        // Build the symmetric Gram matrix, evaluating each unordered pair once.
        DenseMatrix kernelMatrix = new DenseMatrixFactoryMTJ().createMatrix(dataSize, dataSize);
        for (int i = 0; i < dataSize; ++i) {
            DataType x = dataList.get(i);
            kernelMatrix.setElement(i, i, this.kernel.evaluate(x, x));
            for (int j = i + 1; j < dataSize; ++j) {
                double value = this.kernel.evaluate(x, dataList.get(j));
                kernelMatrix.setElement(i, j, value);
                kernelMatrix.setElement(j, i, value);
            }
        }

        DenseMatrix k = this.centerData ? centerKernelMatrix(kernelMatrix) : kernelMatrix;

        int realComponentCount = Math.min(this.componentCount, dataSize);
        EigenDecompositionRightMTJ decomposition = EigenDecompositionRightMTJ.create(k);
        // Loop-invariant: fetch the eigenvector matrix once.
        Matrix eigenVectors = decomposition.getEigenVectorsRealPart();
        Matrix components = MatrixFactory.getDenseDefault().createMatrix(realComponentCount, dataSize);
        for (int i = 0; i < realComponentCount; ++i) {
            double magnitude = Math.abs(decomposition.getEigenValue(i).getRealPart());
            if (magnitude == 0.0) {
                // A zero-variance direction (always present after centering).
                // Scaling by 1/sqrt(0) would yield Infinity/NaN, so the row is
                // left as created by createMatrix instead.
                continue;
            }
            Vector component = (Vector)eigenVectors.getColumn(i).scale(1.0 / Math.sqrt(magnitude));
            components.setRow(i, component);
        }
        // The uncentered Gram matrix is handed to the function so it can
        // center test-point kernel vectors consistently.
        return new Function<DataType>(this.kernel, dataList, components, this.centerData, kernelMatrix);
    }

    /**
     * Returns the feature-space-centered version of a square kernel matrix.
     * <p>
     * The result is {@code Kc = K - 1N*K - K*1N + 1N*K*1N}, where {@code 1N} is
     * the N-by-N matrix whose entries are all {@code 1/N}. Element-wise this is
     * {@code Kc[i][j] = K[i][j] - rowMean[i] - columnMean[j] + grandMean}.
     *
     * @param kernelMatrix the square, uncentered kernel matrix (not modified)
     * @return a new centered matrix
     */
    private static DenseMatrix centerKernelMatrix(DenseMatrix kernelMatrix) {
        int n = kernelMatrix.getNumRows();
        double[] rowMeans = new double[n];
        double[] columnMeans = new double[n];
        double grandMean = 0.0;
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double value = kernelMatrix.getElement(i, j);
                rowMeans[i] += value;
                columnMeans[j] += value;
                grandMean += value;
            }
        }
        for (int i = 0; i < n; ++i) {
            rowMeans[i] /= n;
            columnMeans[i] /= n;
        }
        grandMean /= (double)n * n;

        DenseMatrix centered = new DenseMatrixFactoryMTJ().createMatrix(n, n);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                centered.setElement(i, j,
                    kernelMatrix.getElement(i, j) - rowMeans[i] - columnMeans[j] + grandMean);
            }
        }
        return centered;
    }

    /** @return the maximum number of components to learn */
    public int getComponentCount() {
        return this.componentCount;
    }

    /**
     * Sets the maximum number of components to learn.
     *
     * @param componentCount the component count; rejected by
     *     {@code ArgumentChecker.assertIsPositive} when not positive
     */
    public void setComponentCount(int componentCount) {
        ArgumentChecker.assertIsPositive("componentCount", componentCount);
        this.componentCount = componentCount;
    }

    /** @return whether the data is centered in feature space */
    public boolean isCenterData() {
        return this.centerData;
    }

    /** @param centerData whether to center the data in feature space */
    public void setCenterData(boolean centerData) {
        this.centerData = centerData;
    }

    /**
     * A learned kernel PCA projection.
     * <p>
     * It maps an input to its scores on the learned components. To do so it:
     * <ol>
     * <li>Evaluates the kernel between the input and every training point.</li>
     * <li>Optionally centers that kernel vector.</li>
     * <li>Multiplies by the component matrix.</li>
     * </ol>
     *
     * @param <DataType> the type of data the kernel is evaluated on
     */
    public static class Function<DataType>
    extends DefaultKernelContainer<DataType>
    implements VectorOutputEvaluator<DataType, Vector> {
        /** The training data the kernel is evaluated against. */
        protected List<? extends DataType> data;
        /** Component matrix: one row per component, one column per training point. */
        protected Matrix components;
        /** Whether test kernel vectors are centered before projection. */
        protected boolean centerData;
        /** The uncentered training Gram matrix; needed for centering (may be null). */
        protected Matrix kernelMatrix;

        /** Creates an empty function; all fields must be set before use. */
        public Function() {
            this(null, null, null, KernelPrincipalComponentsAnalysis.DEFAULT_CENTER_DATA, null);
        }

        /**
         * Creates a projection function.
         *
         * @param kernel the kernel used during learning
         * @param data the training data
         * @param components the component matrix (components x data size)
         * @param centerData whether to center test kernel vectors
         * @param kernelMatrix the uncentered training Gram matrix
         */
        public Function(Kernel<? super DataType> kernel, List<? extends DataType> data, Matrix components, boolean centerData, Matrix kernelMatrix) {
            super(kernel);
            this.setData(data);
            this.setComponents(components);
            this.setCenterData(centerData);
            this.setKernelMatrix(kernelMatrix);
        }

        /**
         * Projects an input onto the learned components.
         * <p>
         * Centering is applied only when {@code centerData} is true and a
         * kernel matrix is available.
         *
         * @param input the input to project
         * @return the component scores, of length {@link #getComponentCount()}
         */
        @Override
        public Vector evaluate(DataType input) {
            int dataSize = this.data.size();
            Vector kernelVector = VectorFactory.getDenseDefault().createVector(dataSize);
            int index = 0;
            for (DataType other : this.data) {
                kernelVector.setElement(index, this.kernel.evaluate(input, other));
                ++index;
            }
            if (this.centerData && this.kernelMatrix != null) {
                this.centerKernelVector(kernelVector);
            }
            return this.components.times(kernelVector);
        }

        /**
         * Centers, in place, the kernel values between one test input and the
         * training data, consistently with the centered training Gram matrix.
         * <p>
         * The update is {@code k - (1/N)*1'K - mean(k)*1 + grandMean(K)*1}.
         * Element-wise:
         * {@code kc[j] = k[j] - columnMean_j(K) - mean(k) + grandMean(K)}.
         *
         * @param kernelVector kernel values against each training point; overwritten
         */
        private void centerKernelVector(Vector kernelVector) {
            int n = kernelVector.getDimensionality();

            double inputMean = 0.0;
            for (int j = 0; j < n; ++j) {
                inputMean += kernelVector.getElement(j);
            }
            inputMean /= n;

            double[] columnMeans = new double[n];
            double grandMean = 0.0;
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    double value = this.kernelMatrix.getElement(i, j);
                    columnMeans[j] += value;
                    grandMean += value;
                }
            }
            grandMean /= (double)n * n;

            for (int j = 0; j < n; ++j) {
                kernelVector.setElement(j,
                    kernelVector.getElement(j) - columnMeans[j] / n - inputMean + grandMean);
            }
        }

        /** @return the number of output components */
        @Override
        public int getOutputDimensionality() {
            return this.components.getNumRows();
        }

        /** @return the number of learned components */
        public int getComponentCount() {
            return this.components.getNumRows();
        }

        /** @return the training data */
        public List<? extends DataType> getData() {
            return this.data;
        }

        /** @param data the training data */
        public void setData(List<? extends DataType> data) {
            this.data = data;
        }

        /** @return the component matrix */
        public Matrix getComponents() {
            return this.components;
        }

        /** @param components the component matrix */
        public void setComponents(Matrix components) {
            this.components = components;
        }

        /** @return whether test kernel vectors are centered */
        public boolean isCenterData() {
            return this.centerData;
        }

        /** @param centerData whether test kernel vectors are centered */
        public void setCenterData(boolean centerData) {
            this.centerData = centerData;
        }

        /** @return the uncentered training Gram matrix */
        public Matrix getKernelMatrix() {
            return this.kernelMatrix;
        }

        /** @param kernelMatrix the uncentered training Gram matrix */
        public void setKernelMatrix(Matrix kernelMatrix) {
            this.kernelMatrix = kernelMatrix;
        }
    }
}

