/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;
import java.util.Set;

/**
 * Batch learner that trains a {@link LinearMultiCategorizer} with a
 * multi-category perceptron update rule.
 *
 * <p>One {@link LinearBinaryCategorizer} prototype (weight vector plus bias)
 * is kept per category found in the training data. Each step makes one pass
 * over the data; for every example whose highest-scoring category differs
 * from its label, the input is added to the labeled category's prototype and
 * subtracted from the predicted category's prototype (biases move by +1 and
 * -1 respectively).
 *
 * @param <CategoryType> the type of the output categories.
 */
@PublicationReference(title="Ultraconservative Online Algorithms for Multiclass Problems", author={"Koby Crammer", "Yoram Singer"}, year=2003, type=PublicationType.Journal, publication="Journal of Machine Learning Research", pages={951, 991}, url="http://portal.acm.org/citation.cfm?id=944936")
public class BatchMultiPerceptron<CategoryType>
extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, CategoryType, LinearMultiCategorizer<CategoryType>>
implements MeasurablePerformanceAlgorithm,
VectorFactoryContainer {
    /** The default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default minimum margin, {@value}. */
    public static final double DEFAULT_MIN_MARGIN = 0.0;

    /** Margin subtracted from the true category's score; never negative. */
    protected double minMargin;

    /** Factory used to create the initial weight vector of each prototype. */
    protected VectorFactory<?> vectorFactory;

    /** The categorizer being learned; created in {@link #initializeAlgorithm()}. */
    protected transient LinearMultiCategorizer<CategoryType> result;

    /** Number of misclassified examples seen in the most recent step. */
    protected transient int errorCount;

    /**
     * Creates a learner with {@link #DEFAULT_MAX_ITERATIONS} and
     * {@link #DEFAULT_MIN_MARGIN}.
     */
    public BatchMultiPerceptron() {
        this(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a learner with the given iteration limit and
     * {@link #DEFAULT_MIN_MARGIN}.
     *
     * @param maxIterations the maximum number of iterations.
     */
    public BatchMultiPerceptron(int maxIterations) {
        this(maxIterations, DEFAULT_MIN_MARGIN);
    }

    /**
     * Creates a learner with the given iteration limit and margin, using the
     * default vector factory.
     *
     * @param maxIterations the maximum number of iterations.
     * @param minMargin the minimum margin; must be non-negative.
     */
    public BatchMultiPerceptron(int maxIterations, double minMargin) {
        this(maxIterations, minMargin, VectorFactory.getDefault());
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param maxIterations the maximum number of iterations.
     * @param minMargin the minimum margin; must be non-negative.
     * @param vectorFactory the factory used to create weight vectors.
     */
    public BatchMultiPerceptron(int maxIterations, double minMargin, VectorFactory<?> vectorFactory) {
        super(maxIterations);
        this.setMinMargin(minMargin);
        this.setVectorFactory(vectorFactory);
    }

    /**
     * Creates the result categorizer with one zero-bias prototype per unique
     * output category in the data.
     *
     * @return false if the data is reported empty by
     *     {@code CollectionUtil.isEmpty}; true otherwise.
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (CollectionUtil.isEmpty(this.getData())) {
            // Nothing to learn from.
            return false;
        }

        final int dimensionality = DatasetUtil.getInputDimensionality(this.getData());
        this.result = new LinearMultiCategorizer<CategoryType>();

        // Every category gets its own prototype whose weight vector has the
        // input dimensionality and whose bias starts at zero.
        final Set<CategoryType> categories = DatasetUtil.findUniqueOutputs(this.getData());
        for (CategoryType category : categories) {
            final LinearBinaryCategorizer prototype = new LinearBinaryCategorizer(
                this.getVectorFactory().createVector(dimensionality), 0.0);
            this.result.getPrototypes().put(category, prototype);
        }
        return true;
    }

    /**
     * Performs one pass over the training data, updating the prototypes for
     * every misclassified example and recording the number of mistakes.
     *
     * @return true if at least one example was misclassified in this pass;
     *     presumably this tells the base class to keep iterating.
     */
    @Override
    protected boolean step() {
        // The error count is per-pass, so reset it first.
        this.setErrorCount(0);

        for (InputOutputPair<? extends Vectorizable, CategoryType> example : this.getData()) {
            if (example == null) {
                // Null entries in the data are skipped.
                continue;
            }

            final Vector input = example.getInput().convertToVector();
            final CategoryType actual = example.getOutput();

            // Find the category with the strictly highest score.
            CategoryType predicted = null;
            double predictedScore = Double.NEGATIVE_INFINITY;
            for (CategoryType category : this.result.getCategories()) {
                double score = this.result.evaluateAsDouble(input, category);

                // Penalize the true category by the margin so that it must
                // win by more than minMargin to count as correct. The
                // null-safe comparison matches the correctness check below.
                if (this.minMargin != 0.0 && ObjectUtil.equalsSafe(actual, category)) {
                    score -= this.minMargin;
                }

                if (score > predictedScore) {
                    predicted = category;
                    predictedScore = score;
                }
            }

            if (ObjectUtil.equalsSafe(actual, predicted)) {
                // Correct prediction: no update.
                continue;
            }

            this.setErrorCount(this.getErrorCount() + 1);

            // Move the true category's prototype toward the input ...
            final LinearBinaryCategorizer actualPrototype =
                this.result.getPrototypes().get(actual);
            actualPrototype.getWeights().plusEquals(input);
            actualPrototype.setBias(actualPrototype.getBias() + 1.0);

            // ... and the wrongly predicted category's prototype away from it.
            // NOTE(review): predicted is null only if no score exceeded
            // negative infinity (e.g. NaN scores); the lookup then yields no
            // prototype and the update below would throw.
            final LinearBinaryCategorizer predictedPrototype =
                this.result.getPrototypes().get(predicted);
            predictedPrototype.getWeights().minusEquals(input);
            predictedPrototype.setBias(predictedPrototype.getBias() - 1.0);
        }

        return this.getErrorCount() > 0;
    }

    /**
     * Called after learning finishes; this learner has nothing to clean up.
     */
    @Override
    protected void cleanupAlgorithm() {
        // Nothing to clean up.
    }

    /**
     * Gets the categorizer learned so far.
     *
     * @return the current result; null if it has not been created or set.
     */
    @Override
    public LinearMultiCategorizer<CategoryType> getResult() {
        return this.result;
    }

    /**
     * Sets the categorizer being learned.
     *
     * @param result the new result object.
     */
    protected void setResult(LinearMultiCategorizer<CategoryType> result) {
        this.result = result;
    }

    /**
     * Gets the minimum margin.
     *
     * @return the minimum margin.
     */
    public double getMinMargin() {
        return this.minMargin;
    }

    /**
     * Sets the minimum margin.
     *
     * @param minMargin the minimum margin; checked with
     *     {@code ArgumentChecker.assertIsNonNegative}.
     */
    public void setMinMargin(double minMargin) {
        ArgumentChecker.assertIsNonNegative("minMargin", minMargin);
        this.minMargin = minMargin;
    }

    /**
     * Gets the vector factory used to create prototype weight vectors.
     *
     * @return the vector factory.
     */
    public VectorFactory<?> getVectorFactory() {
        return this.vectorFactory;
    }

    /**
     * Sets the vector factory used to create prototype weight vectors.
     *
     * @param vectorFactory the vector factory.
     */
    public void setVectorFactory(VectorFactory<?> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }

    /**
     * Gets the number of misclassified examples in the most recent step.
     *
     * @return the error count.
     */
    public int getErrorCount() {
        return this.errorCount;
    }

    /**
     * Sets the number of misclassified examples.
     *
     * @param errorCount the error count.
     */
    protected void setErrorCount(int errorCount) {
        this.errorCount = errorCount;
    }

    /**
     * Reports the current error count as the performance measure.
     *
     * @return a value named "error count" holding {@link #getErrorCount()}.
     */
    @Override
    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue<Integer>("error count", this.getErrorCount());
    }
}

