/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.ScalarFunctionToBinaryCategorizerAdapter;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.method.ReceiverOperatingCharacteristic;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;

/**
 * A binary categorizer over {@link Vector} inputs that wraps a
 * {@link LinearDiscriminant} (the Fisher projection direction) together with a
 * scalar threshold. Thresholding of the projected value is delegated to
 * {@link ScalarFunctionToBinaryCategorizerAdapter}.
 */
@PublicationReference(author={"Wikipedia"}, title="Linear discriminant analysis", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Linear_discriminant_analysis#Fisher.27s_linear_discriminant")
public class FisherLinearDiscriminantBinaryCategorizer
extends ScalarFunctionToBinaryCategorizerAdapter<Vector> {

    /**
     * Creates a categorizer with a {@code null} weight vector and a threshold
     * of 0.0. NOTE(review): a weight vector presumably has to be supplied
     * before this categorizer is evaluated — confirm against LinearDiscriminant.
     */
    public FisherLinearDiscriminantBinaryCategorizer() {
        // The cast selects the (Vector, double) overload rather than the
        // (LinearDiscriminant, double) one, which would otherwise be ambiguous.
        this((Vector) null, 0.0);
    }

    /**
     * Creates a categorizer from a raw weight vector.
     *
     * @param weightVector weights wrapped in a new {@link LinearDiscriminant}
     * @param threshold threshold applied to the discriminant's output
     */
    public FisherLinearDiscriminantBinaryCategorizer(Vector weightVector, double threshold) {
        this(new LinearDiscriminant(weightVector), threshold);
    }

    /**
     * Creates a categorizer from an existing discriminant.
     *
     * @param discriminant the scalar projection of the input vector
     * @param threshold threshold applied to the discriminant's output
     */
    public FisherLinearDiscriminantBinaryCategorizer(LinearDiscriminant discriminant, double threshold) {
        super(discriminant, threshold);
    }

    @Override
    public FisherLinearDiscriminantBinaryCategorizer clone() {
        return (FisherLinearDiscriminantBinaryCategorizer) super.clone();
    }

    /**
     * Learns a {@link FisherLinearDiscriminantBinaryCategorizer} in closed form.
     * The weight vector is
     * {@code w = (C0 + C1 + defaultCovariance * I)^-1 (m1 - m0)}, where
     * {@code mK}/{@code CK} are the per-class mean and covariance. The
     * threshold is the ROC-optimal threshold over the projected training data.
     */
    public static class ClosedFormSolver
    extends AbstractCloneableSerializable
    implements SupervisedBatchLearner<Vector, Boolean, FisherLinearDiscriminantBinaryCategorizer> {

        /** Default diagonal regularization used by the no-argument constructor. */
        public static final double DEFAULT_DEFAULT_COVARIANCE = 1.0E-5;

        /**
         * Amount added to the diagonal of the summed class covariances before
         * inversion; 0.0 disables the regularization.
         */
        private double defaultCovariance;

        /** Creates a solver using {@link #DEFAULT_DEFAULT_COVARIANCE}. */
        public ClosedFormSolver() {
            this(DEFAULT_DEFAULT_COVARIANCE);
        }

        /**
         * Creates a solver with the given diagonal regularization.
         *
         * @param defaultCovariance value added to the covariance diagonal;
         *     0.0 disables the regularization
         */
        public ClosedFormSolver(double defaultCovariance) {
            this.defaultCovariance = defaultCovariance;
        }

        @Override
        public FisherLinearDiscriminantBinaryCategorizer learn(Collection<? extends InputOutputPair<? extends Vector, Boolean>> data) {
            return ClosedFormSolver.learn(data, this.defaultCovariance);
        }

        /**
         * Learns a Fisher linear discriminant categorizer from labelled data.
         *
         * @param data labelled training vectors; must contain at least one
         *     example of each class
         * @param defaultCovariance value added to the diagonal of the summed
         *     class covariances before inversion; 0.0 disables it
         * @return the learned categorizer, with its threshold set to the
         *     ROC-optimal threshold on {@code data}
         * @throws IllegalArgumentException if either class has no examples
         */
        public static FisherLinearDiscriminantBinaryCategorizer learn(Collection<? extends InputOutputPair<? extends Vector, Boolean>> data, double defaultCovariance) {
            // d1 is the first list from splitDatasets (presumably the
            // true-labelled inputs — confirm against DatasetUtil); d0 the second.
            final DefaultPair<LinkedList<Vector>, LinkedList<Vector>> split =
                DatasetUtil.splitDatasets(data);
            final LinkedList<Vector> d1 = split.getFirst();
            final LinkedList<Vector> d0 = split.getSecond();
            if (d1.isEmpty() || d0.isEmpty()) {
                // Both class means are needed; an empty class leaves them undefined.
                throw new IllegalArgumentException(
                    "Fisher linear discriminant requires at least one example of each class");
            }

            final Pair<Vector, Matrix> r1 = MultivariateStatisticsUtil.computeMeanAndCovariance(d1);
            final Vector m1 = r1.getFirst();
            final Matrix c1 = r1.getSecond();
            final Pair<Vector, Matrix> r0 = MultivariateStatisticsUtil.computeMeanAndCovariance(d0);
            final Vector m0 = r0.getFirst();
            final Matrix c0 = r0.getSecond();

            final Matrix cinverse;
            if (defaultCovariance != 0.0) {
                // Diagonal loading keeps the summed covariance invertible when
                // it is singular or ill-conditioned.
                final int dimensionality = m0.getDimensionality();
                final Matrix ci = MatrixFactory.getDefault()
                    .createIdentity(dimensionality, dimensionality)
                    .scale(defaultCovariance);
                cinverse = c0.plus(c1.plus(ci)).inverse();
            } else {
                cinverse = c0.plus(c1).inverse();
            }

            final Vector weightVector = cinverse.times(m1.minus(m0));
            final LinearDiscriminant discriminant = new LinearDiscriminant(weightVector);
            final double threshold = computeOptimalThreshold(discriminant, data);
            return new FisherLinearDiscriminantBinaryCategorizer(discriminant, threshold);
        }

        /**
         * Projects every training input through the discriminant and returns
         * the ROC-optimal threshold for separating the two labels.
         *
         * @param discriminant the learned projection
         * @param data the labelled training data
         * @return the optimal threshold reported by the ROC statistics
         */
        private static double computeOptimalThreshold(LinearDiscriminant discriminant, Collection<? extends InputOutputPair<? extends Vector, Boolean>> data) {
            final ArrayList<DefaultInputOutputPair<Double, Boolean>> projected =
                new ArrayList<DefaultInputOutputPair<Double, Boolean>>(data.size());
            for (InputOutputPair<? extends Vector, Boolean> example : data) {
                final Double value = discriminant.evaluate(example.getInput());
                projected.add(new DefaultInputOutputPair<Double, Boolean>(value, example.getOutput()));
            }
            final ReceiverOperatingCharacteristic roc = ReceiverOperatingCharacteristic.create(projected);
            final ReceiverOperatingCharacteristic.Statistic statistic = roc.computeStatistics();
            return statistic.getOptimalThreshold().getClassifier().getThreshold();
        }
    }
}

