/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.math.matrix.algorithm;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.openimaj.citation.annotation.Reference;
import org.openimaj.citation.annotation.ReferenceType;
import org.openimaj.math.matrix.GeneralisedEigenvalueProblem;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

/**
 * Fisher's Linear Discriminant Analysis (LDA).
 * <p>
 * Given training vectors grouped by class, this learns a linear basis that
 * separates the classes. It builds the within-class scatter matrix {@code Sw}
 * and the between-class scatter matrix {@code Sb}, then solves the symmetric
 * generalised eigenvalue problem on {@code (Sb, Sw)}. For {@code c} classes at
 * most {@code c - 1} components are retained.
 * <p>
 * Learning overwrites the instance fields read by the accessor and projection
 * methods, and no synchronisation is performed.
 */
@Reference(type=ReferenceType.Article, author={"Fisher, Ronald A."}, title="{The use of multiple measurements in taxonomic problems}", year="1936", journal="Annals Eugen.", pages={"179", "", "188"}, volume="7", customData={"citeulike-article-id", "764226", "keywords", "classification", "posted-at", "2006-09-18 14:06:16", "priority", "2"})
public class LinearDiscriminantAnalysis {
    /** Requested number of discriminant components; clamped to {@code c - 1} when learning. */
    protected int numComponents;
    /** Learnt basis: one discriminant direction per column. */
    protected Matrix eigenvectors;
    /** Generalised eigenvalues matching the columns of {@link #eigenvectors}. */
    protected double[] eigenvalues;
    /** Mean of all training vectors across every class. */
    protected double[] mean;

    /**
     * Create an LDA learner.
     *
     * @param numComponents the number of discriminant components to keep. When
     *            learning from {@code c} classes, values that are non-positive
     *            or not smaller than {@code c} are replaced by {@code c - 1}.
     */
    public LinearDiscriminantAnalysis(int numComponents) {
        this.numComponents = numComponents;
    }

    /**
     * Compute the per-class means, the overall mean and the total instance count.
     *
     * @param data one {@code double[][]} per class, each row a feature vector
     * @return the computed means
     * @throws IllegalArgumentException if any class contains no vectors
     */
    private MeanData computeMeans(List<double[][]> data) {
        final int numClasses = data.size();
        for (int i = 0; i < numClasses; ++i) {
            if (data.get(i) == null || data.get(i).length == 0) {
                throw new IllegalArgumentException("Class " + i + " contains no instances");
            }
        }

        final int cols = data.get(0)[0].length;
        final MeanData md = new MeanData();
        md.overallMean = new double[cols];
        md.classMeans = new double[numClasses][];
        md.numInstances = 0;

        for (int i = 0; i < numClasses; ++i) {
            final double[][] classData = data.get(i);
            final int classSize = classData.length;
            final double[] classSum = this.computeSum(classData);
            for (int j = 0; j < cols; ++j) {
                // accumulate the raw sum first, then turn it into this class's mean
                md.overallMean[j] += classSum[j];
                classSum[j] /= classSize;
            }
            md.classMeans[i] = classSum;
            md.numInstances += classSize;
        }

        for (int j = 0; j < cols; ++j) {
            md.overallMean[j] /= md.numInstances;
        }
        return md;
    }

    /**
     * Element-wise sum of the rows of {@code data}.
     *
     * @param data a non-empty array of equal-length rows
     * @return a new array holding the column sums
     */
    private double[] computeSum(double[][] data) {
        final double[] sum = new double[data[0].length];
        for (final double[] row : data) {
            for (int i = 0; i < sum.length; ++i) {
                sum[i] += row[i];
            }
        }
        return sum;
    }

    /**
     * Learn the basis from a flat list of (class label, feature vector) pairs.
     * Vectors are grouped by label ({@code equals}/{@code hashCode} of the
     * first object) before learning.
     *
     * @param data labelled feature vectors
     * @throws IllegalArgumentException if fewer than two distinct labels are present
     */
    public void learnBasisIP(List<? extends IndependentPair<?, double[]>> data) {
        final Map<Object, List<double[]>> mapData = new HashMap<Object, List<double[]>>();
        for (final IndependentPair<?, double[]> item : data) {
            final Object label = item.firstObject();
            List<double[]> fvs = mapData.get(label);
            if (fvs == null) {
                fvs = new ArrayList<double[]>();
                mapData.put(label, fvs);
            }
            fvs.add(item.getSecondObject());
        }
        this.learnBasisML(mapData);
    }

    /**
     * Learn the basis from a map of class label to that class's feature vectors.
     *
     * @param data the feature vectors of each class; keys are only used for grouping
     * @throws IllegalArgumentException if fewer than two classes are given or a class is empty
     */
    public void learnBasisML(Map<?, List<double[]>> data) {
        final List<double[][]> list = new ArrayList<double[][]>(data.size());
        for (final List<double[]> classVectors : data.values()) {
            list.add(classVectors.toArray(new double[classVectors.size()][]));
        }
        this.learnBasis(list);
    }

    /**
     * Learn the basis from a list with one list of feature vectors per class.
     *
     * @param data the feature vectors of each class
     * @throws IllegalArgumentException if fewer than two classes are given or a class is empty
     */
    public void learnBasisLL(List<List<double[]>> data) {
        final List<double[][]> list = new ArrayList<double[][]>(data.size());
        for (final List<double[]> classVectors : data) {
            list.add(classVectors.toArray(new double[classVectors.size()][]));
        }
        this.learnBasis(list);
    }

    /**
     * Learn the basis from a map of class label to that class's feature vectors.
     *
     * @param data the feature vectors of each class (one row per vector)
     * @throws IllegalArgumentException if fewer than two classes are given or a class is empty
     */
    public void learnBasis(Map<?, double[][]> data) {
        // Delegate to the List overload; passing `data` itself would recurse forever.
        this.learnBasis(new ArrayList<double[][]>(data.values()));
    }

    /**
     * Learn the basis from one {@code double[][]} per class, each row a feature
     * vector. The training data is not modified.
     *
     * @param data the feature vectors of each class
     * @throws IllegalArgumentException if fewer than two classes are given or a class is empty
     */
    public void learnBasis(List<double[][]> data) {
        final int c = data.size();
        if (c < 2) {
            throw new IllegalArgumentException("LDA requires data from at least two classes, but got " + c);
        }
        // At most c - 1 discriminant directions exist; a non-positive request means "as many as possible".
        if (this.numComponents <= 0 || this.numComponents >= c) {
            this.numComponents = c - 1;
        }

        final MeanData meanData = this.computeMeans(data);
        this.mean = meanData.overallMean;
        final int dims = this.mean.length;

        final Matrix Sw = new Matrix(dims, dims);
        final Matrix Sb = new Matrix(dims, dims);
        for (int i = 0; i < c; ++i) {
            final double[][] classData = data.get(i);
            final double[] classMean = meanData.classMeans[i];

            // Within-class scatter. Jama's Matrix(double[][]) wraps the given array
            // without copying, so copy first in case minusRow centres in place;
            // otherwise the caller's training vectors could be overwritten.
            final Matrix zeroCentred = MatrixUtils.minusRow(Matrix.constructWithCopy(classData), classMean);
            MatrixUtils.plusEquals(Sw, zeroCentred.transpose().times(zeroCentred));

            // Between-class scatter: n_i * (mu_i - mu)(mu_i - mu)^T, weighted by this class's size.
            final double[] meanDiff = new double[dims];
            for (int j = 0; j < dims; ++j) {
                meanDiff[j] = classMean[j] - this.mean[j];
            }
            final Matrix diff = new Matrix(new double[][] { meanDiff });
            MatrixUtils.plusEquals(Sb, MatrixUtils.times(diff.transpose().times(diff), classData.length));
        }

        final IndependentPair<Matrix, double[]> evs =
                GeneralisedEigenvalueProblem.symmetricGeneralisedEigenvectorsSorted(Sb, Sw, this.numComponents);
        this.eigenvectors = evs.firstObject();
        this.eigenvalues = evs.secondObject();
    }

    /**
     * @return the learnt basis (one direction per column); the internal matrix, not a copy
     */
    public Matrix getBasis() {
        return this.eigenvectors;
    }

    /**
     * Extract one learnt basis direction.
     *
     * @param index the column of the basis to extract
     * @return a new array containing column {@code index} of the basis
     */
    public double[] getBasisVector(int index) {
        final double[][] basis = this.eigenvectors.getArray();
        final double[] pc = new double[this.eigenvectors.getRowDimension()];
        for (int r = 0; r < pc.length; ++r) {
            pc[r] = basis[r][index];
        }
        return pc;
    }

    /**
     * @return the learnt eigenvectors (same matrix as {@link #getBasis()})
     */
    public Matrix getEigenVectors() {
        return this.eigenvectors;
    }

    /**
     * @return the learnt generalised eigenvalues; the internal array, not a copy
     */
    public double[] getEigenValues() {
        return this.eigenvalues;
    }

    /**
     * @param i the component index
     * @return the eigenvalue of component {@code i}
     */
    public double getEigenValue(int i) {
        return this.eigenvalues[i];
    }

    /**
     * @return the overall mean of the training data; the internal array, not a copy
     */
    public double[] getMean() {
        return this.mean;
    }

    /**
     * Map component weights back into the original feature space:
     * {@code mean + basis * scalings}.
     *
     * @param scalings weight for each component; extra entries are ignored and
     *            missing entries are treated as zero
     * @return the reconstructed vector
     */
    public double[] generate(double[] scalings) {
        final Matrix scale = new Matrix(this.eigenvalues.length, 1);
        final int n = Math.min(this.eigenvalues.length, scalings.length);
        for (int i = 0; i < n; ++i) {
            scale.set(i, 0, scalings[i]);
        }
        final Matrix meanMatrix = new Matrix(new double[][] { this.mean }).transpose();
        return meanMatrix.plus(this.eigenvectors.times(scale)).getColumnPackedCopy();
    }

    /**
     * Project each row of {@code m} onto the learnt basis after subtracting the mean.
     *
     * @param m one feature vector per row; not modified
     * @return one projected vector per row
     */
    public Matrix project(Matrix m) {
        final Matrix vec = m.copy();
        final double[][] vecarr = vec.getArray();
        for (final double[] row : vecarr) {
            for (int c = 0; c < row.length; ++c) {
                row[c] -= this.mean[c];
            }
        }
        return vec.times(this.eigenvectors);
    }

    /**
     * Project a single feature vector onto the learnt basis after subtracting the mean.
     *
     * @param vector the feature vector; not modified
     * @return the projected vector (one value per component)
     */
    public double[] project(double[] vector) {
        final Matrix vec = new Matrix(1, vector.length);
        final double[] row = vec.getArray()[0];
        for (int i = 0; i < vector.length; ++i) {
            row[i] = vector[i] - this.mean[i];
        }
        return vec.times(this.eigenvectors).getColumnPackedCopy();
    }

    /** Holder for the means computed over the training data. */
    private static class MeanData {
        /** Mean over all instances of all classes. */
        double[] overallMean;
        /** Mean of each class, indexed like the input list. */
        double[][] classMeans;
        /** Total number of instances across all classes. */
        int numInstances;

        private MeanData() {
        }
    }
}

