/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.svm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.DiscreteSamplingUtil;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Learns a linear binary categorizer (a linear SVM without a bias term) using
 * the Primal Estimated sub-GrAdient SOlver for SVM (Pegasos).
 * <p>
 * Each step draws a random subset of up to {@code sampleSize} training
 * examples without replacement. It takes a sub-gradient step on the
 * regularized hinge loss with step size {@code 1 / (lambda * t)}. It then
 * projects the weight vector onto the ball of radius {@code 1 / sqrt(lambda)}.
 * The bias of the learned categorizer is initialized to 0 and is never
 * modified by this learner.
 */
@PublicationReference(author={"Shai Shalev-Shwartz", "Yoram Singer", "Nathan Srebro"}, title="Pegasos: Primal Estimated sub-GrAdient SOlver for SVM", year=2007, type=PublicationType.Conference, publication="Proceedings of the 24th International Conference on Machine Learning", url="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.74.8513")
public class PrimalEstimatedSubGradient
extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, Boolean, LinearBinaryCategorizer>
implements Randomized {
    /** The default number of examples sampled per step: {@value}. */
    public static final int DEFAULT_SAMPLE_SIZE = 100;
    /** The default regularization weight (lambda): {@value}. */
    public static final double DEFAULT_REGULARIZATION_WEIGHT = 1.0E-4;
    /** The default maximum number of iterations: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 10000;
    /** The requested number of examples sampled per step (mini-batch size). Must be positive. */
    protected int sampleSize;
    /** The regularization weight (lambda). Must be positive. */
    protected double regularizationWeight;
    /** The random number generator used for weight initialization and sampling. */
    protected Random random;
    /** The number of training examples in the current run. */
    protected transient int dataSize;
    /** The training data copied into a list so it can be sampled. */
    protected transient ArrayList<? extends InputOutputPair<? extends Vectorizable, Boolean>> dataList;
    /** The input dimensionality of the training data. */
    protected transient int dimensionality;
    /** The effective mini-batch size: the smaller of {@code dataSize} and {@code sampleSize}. */
    protected transient int dataSampleSize;
    /** Work vector that accumulates the sub-gradient update within a step. */
    protected transient Vector update;
    /** The categorizer being learned. */
    protected transient LinearBinaryCategorizer result;

    /**
     * Creates a learner with the default sample size, regularization weight and
     * maximum number of iterations, using a new {@link Random}.
     */
    public PrimalEstimatedSubGradient() {
        this(DEFAULT_SAMPLE_SIZE, DEFAULT_REGULARIZATION_WEIGHT, DEFAULT_MAX_ITERATIONS, new Random());
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param sampleSize the number of examples sampled per step; must be positive
     * @param regularizationWeight the regularization weight (lambda); must be positive
     * @param maxIterations the maximum number of iterations, passed to the base class
     * @param random the random number generator used for initialization and sampling
     */
    public PrimalEstimatedSubGradient(int sampleSize, double regularizationWeight, int maxIterations, Random random) {
        super(maxIterations);
        this.setSampleSize(sampleSize);
        this.setRegularizationWeight(regularizationWeight);
        this.setRandom(random);
    }

    /**
     * Prepares the per-run state and creates the initial categorizer.
     *
     * @return {@code false} if there is no data to learn from, otherwise {@code true}
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (CollectionUtil.isEmpty(this.data)) {
            return false;
        }
        this.dataSize = this.data.size();
        this.dataList = CollectionUtil.asArrayList(this.data);
        this.dimensionality = DatasetUtil.getInputDimensionality(this.data);
        this.dataSampleSize = Math.min(this.dataSize, this.sampleSize);

        VectorFactory<Vector> vectorFactory = VectorFactory.getDenseDefault();
        final double sqrtLambda = Math.sqrt(this.regularizationWeight);
        final double radius = 1.0 / sqrtLambda;
        final double initializationRange = 1.0 / ((double) this.dimensionality * sqrtLambda);
        final Vector initialWeights = vectorFactory.createUniformRandom(
            this.dimensionality, -initializationRange, initializationRange, this.random);
        // Pegasos starts from a point inside the ball of radius 1/sqrt(lambda).
        // Each component is drawn from [-1/(d*sqrt(lambda)), 1/(d*sqrt(lambda))],
        // so the norm is at most 1/sqrt(d*lambda). In practice this branch always
        // runs and moves the random direction out to the boundary of that ball.
        if (initialWeights.norm2() < radius) {
            initialWeights.unitVectorEquals();
            initialWeights.scaleEquals(radius);
        }
        this.result = new LinearBinaryCategorizer(initialWeights, 0.0);
        this.update = vectorFactory.createVector(this.dimensionality);
        return true;
    }

    /**
     * Performs one Pegasos mini-batch step. It computes the hinge-loss
     * sub-gradient on a random sample, applies the regularization shrinkage and
     * the update, and projects the weights onto the 1/sqrt(lambda) ball.
     *
     * @return always {@code true}, so learning continues until the base class stops it
     */
    @Override
    protected boolean step() {
        final List<? extends InputOutputPair<? extends Vectorizable, Boolean>> batch =
            DiscreteSamplingUtil.sampleWithoutReplacement(this.random, this.dataList, this.dataSampleSize);

        final double lambda = this.regularizationWeight;
        // Step size eta_t = 1 / (lambda * t).
        // NOTE(review): assumes the base class has advanced this.iteration to >= 1
        // before calling step(); with 0 the step size would be infinite.
        final double learningRate = 1.0 / (lambda * (double) this.iteration);

        // Sum y * x over the sampled examples whose margin y * f(x) is below 1.
        // The margins are evaluated BEFORE the weights are shrunk, so the
        // violating set is taken with respect to the current weights w_t, as
        // the algorithm specifies.
        this.update.zero();
        for (InputOutputPair<? extends Vectorizable, Boolean> example : batch) {
            final boolean label = example.getOutput();
            final Vector input = example.getInput().convertToVector();
            final double margin = (label ? 1.0 : -1.0) * this.result.evaluateAsDouble(input);
            if (margin < 1.0) {
                if (label) {
                    this.update.plusEquals(input);
                } else {
                    this.update.minusEquals(input);
                }
            }
        }
        this.update.scaleEquals(learningRate / (double) this.dataSampleSize);

        // w <- (1 - eta_t * lambda) * w_t + (eta_t / k) * sum(y * x)
        final Vector weights = this.result.getWeights();
        weights.scaleEquals(1.0 - learningRate * lambda);
        weights.plusEquals(this.update);

        // Project onto the ball of radius 1/sqrt(lambda):
        // w <- min(1, (1/sqrt(lambda)) / ||w||) * w.
        // A zero vector yields an infinite factor, which correctly skips scaling.
        final double projection = 1.0 / Math.sqrt(lambda * weights.norm2Squared());
        if (projection < 1.0) {
            weights.scaleEquals(projection);
        }
        this.result.setWeights(weights);
        return true;
    }

    /**
     * Releases the per-run working data. The learned result is kept.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.dataList = null;
        this.update = null;
    }

    /**
     * Gets the categorizer being (or most recently) learned.
     *
     * @return the current categorizer, or {@code null} if learning has not been initialized
     */
    @Override
    public LinearBinaryCategorizer getResult() {
        return this.result;
    }

    /**
     * Gets the number of examples sampled per step.
     *
     * @return the sample size
     */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Sets the number of examples sampled per step. The value is checked with
     * {@link ArgumentChecker#assertIsPositive}.
     *
     * @param sampleSize the sample size; must be positive
     */
    public void setSampleSize(int sampleSize) {
        ArgumentChecker.assertIsPositive("sampleSize", sampleSize);
        this.sampleSize = sampleSize;
    }

    /**
     * Gets the regularization weight (lambda).
     *
     * @return the regularization weight
     */
    public double getRegularizationWeight() {
        return this.regularizationWeight;
    }

    /**
     * Sets the regularization weight (lambda). The value is checked with
     * {@link ArgumentChecker#assertIsPositive}.
     *
     * @param regularizationWeight the regularization weight; must be positive
     */
    public void setRegularizationWeight(double regularizationWeight) {
        ArgumentChecker.assertIsPositive("regularizationWeight", regularizationWeight);
        this.regularizationWeight = regularizationWeight;
    }

    /**
     * Gets the random number generator used for initialization and sampling.
     *
     * @return the random number generator
     */
    @Override
    public Random getRandom() {
        return this.random;
    }

    /**
     * Sets the random number generator used for initialization and sampling.
     *
     * @param random the random number generator
     */
    @Override
    public void setRandom(Random random) {
        this.random = random;
    }
}

