/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.confidence;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.function.categorization.DefaultConfidenceWeightedBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;

@PublicationReference(author={"Koby Crammer", "Alex Kulesza", "Mark Dredze"}, title="Adaptive Regularization of Weight Vectors", year=2009, type=PublicationType.Conference, publication="Advances in Neural Information Processing Systems", url="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.169.4127&rep=rep1&type=pdf")
/**
 * Incremental (and batch, via the base class) learner for a binary
 * categorizer that keeps a mean weight vector and a covariance matrix,
 * following the update rule of the publication referenced above.
 *
 * <p>Each example whose signed margin is below 1 moves the mean along
 * (covariance * input) and subtracts a scaled outer product of that same
 * vector from the covariance. Examples with margin of at least 1 leave the
 * categorizer untouched.
 */
public class AdaptiveRegularizationOfWeights
extends AbstractSupervisedBatchAndIncrementalLearner<Vectorizable, Boolean, DefaultConfidenceWeightedBinaryCategorizer> {
    /** Default value of the r parameter, {@value}. */
    public static final double DEFAULT_R = 0.001;

    /**
     * The r parameter; it is added to the margin variance in the denominator
     * of the update step. {@link #setR(double)} requires it to be positive.
     */
    protected double r;

    /**
     * Creates a learner that uses {@link #DEFAULT_R} as its r parameter.
     */
    public AdaptiveRegularizationOfWeights() {
        this(DEFAULT_R);
    }

    /**
     * Creates a learner with the given r parameter.
     *
     * @param r the r parameter; validated by {@link #setR(double)}, which
     *          asserts that it is positive
     */
    public AdaptiveRegularizationOfWeights(double r) {
        this.setR(r);
    }

    /**
     * Creates the empty categorizer that learning starts from. Its mean and
     * covariance are filled in by the first call to
     * {@link #update(DefaultConfidenceWeightedBinaryCategorizer, Vector, boolean)}.
     *
     * @return a new, default-constructed categorizer
     */
    @Override
    public DefaultConfidenceWeightedBinaryCategorizer createInitialLearnedObject() {
        return new DefaultConfidenceWeightedBinaryCategorizer();
    }

    /**
     * Updates the categorizer with one labeled example. Examples whose input
     * or label is {@code null} are skipped without any change to
     * {@code target}.
     *
     * @param target the categorizer to update in place
     * @param input  the example input; converted to a vector before use
     * @param output the example label
     */
    @Override
    public void update(DefaultConfidenceWeightedBinaryCategorizer target, Vectorizable input, Boolean output) {
        if (input == null || output == null) {
            return;
        }
        // The explicit unboxing cast is required: it forces overload
        // resolution to pick the (Vector, boolean) variant. Passing the boxed
        // Boolean would re-select this very method and recurse forever.
        this.update(target, input.convertToVector(), (boolean) output);
    }

    /**
     * Performs a single update step on {@code target} for one example.
     *
     * <p>If {@code target} reports that it is not initialized, it is first
     * given a default-created mean vector and an identity covariance matrix
     * sized to the input's dimensionality. The mean and covariance objects
     * held by {@code target} are then modified in place.
     *
     * @param target the categorizer to update in place
     * @param input  the example input vector
     * @param label  the example label; {@code true} is treated as +1 and
     *               {@code false} as -1
     */
    @Override
    public void update(DefaultConfidenceWeightedBinaryCategorizer target, Vector input, boolean label) {
        final Vector mean;
        final Matrix covariance;
        if (target.isInitialized()) {
            mean = target.getMean();
            covariance = target.getCovariance();
        } else {
            final int dimensionality = input.getDimensionality();
            mean = VectorFactory.getDenseDefault().createVector(dimensionality);
            covariance = MatrixFactory.getDenseDefault().createIdentity(dimensionality, dimensionality);
            target.setMean(mean);
            target.setCovariance(covariance);
        }

        final double predicted = input.dotProduct(mean);
        final double actual = label ? 1.0 : -1.0;
        final double margin = actual * predicted;

        // Written as (margin < 1.0) rather than an inverted early return so
        // that a NaN margin results in no update.
        if (margin < 1.0) {
            final Vector covarianceTimesInput = input.times(covariance);
            final double marginVariance = covarianceTimesInput.dotProduct(input);
            final double beta = 1.0 / (marginVariance + this.r);
            final double alpha = Math.max(0.0, 1.0 - margin) * beta;

            // Reuse the product computed above; scale() yields a new vector,
            // so covarianceTimesInput stays intact for the covariance step.
            mean.plusEquals(covarianceTimesInput.scale(alpha * actual));

            final Matrix covarianceUpdate = covarianceTimesInput.outerProduct(covarianceTimesInput);
            covarianceUpdate.scaleEquals(-beta);
            covariance.plusEquals(covarianceUpdate);
        }
    }

    /**
     * Gets the r parameter.
     *
     * @return the current value of r
     */
    public double getR() {
        return this.r;
    }

    /**
     * Sets the r parameter.
     *
     * @param r the new value; checked with
     *          {@code ArgumentChecker.assertIsPositive} before being stored
     */
    public void setR(double r) {
        ArgumentChecker.assertIsPositive("r", r);
        this.r = r;
    }
}

