/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.confidence;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.confidence.ConfidenceWeightedDiagonalVariance;
import gov.sandia.cognition.learning.function.categorization.DiagonalConfidenceWeightedBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;

@PublicationReference(title="Confidence-Weighted Linear Classification", author={"Mark Dredze", "Koby Crammer", "Fernando Pereira"}, year=2008, type=PublicationType.Conference, publication="International Conference on Machine Learning", url="http://portal.acm.org/citation.cfm?id=1390190")
public class ConfidenceWeightedDiagonalVarianceProject
extends ConfidenceWeightedDiagonalVariance {

    /**
     * Creates a new learner with a confidence of 0.85 and a default variance
     * of 1.0.
     */
    public ConfidenceWeightedDiagonalVarianceProject() {
        this(0.85, 1.0);
    }

    /**
     * Creates a new learner with the given parameters.
     *
     * @param confidence
     *      The confidence parameter, forwarded to the superclass (which
     *      presumably derives {@code phi} from it -- confirm in the superclass).
     * @param defaultVariance
     *      The value every diagonal variance entry is initialized to when
     *      {@link #update} sees an uninitialized categorizer.
     */
    public ConfidenceWeightedDiagonalVarianceProject(
        final double confidence,
        final double defaultVariance) {
        super(confidence, defaultVariance);
    }

    /**
     * Performs one online confidence-weighted update of {@code target} on a
     * single labeled example.
     * <p>
     * With {@code y} = +1/-1, margin {@code M = y * (mean . x)} and margin
     * variance {@code V = x' diag(variance) x}, no update is made when
     * {@code V == 0} or {@code M > phi * V}. Otherwise a step size
     * {@code alpha} is computed in closed form. When it is positive, the
     * model is updated as:
     * <pre>
     *   mean         += alpha * y * (variance .* x)
     *   1/variance_i += 2 * alpha * phi * x_i^2     (for each nonzero x_i)
     * </pre>
     *
     * @param target
     *      The categorizer to update. If it is not yet initialized, it is
     *      given a dense zero mean and a dense variance filled with
     *      {@link #getDefaultVariance()}, both sized to {@code input}.
     * @param input
     *      The input vector of the example.
     * @param label
     *      The label of the example; {@code true} is treated as +1 and
     *      {@code false} as -1.
     */
    @Override
    public void update(
        final DiagonalConfidenceWeightedBinaryCategorizer target,
        final Vector input,
        final boolean label) {
        final Vector mean;
        final Vector variance;
        if (!target.isInitialized()) {
            // Lazily create the model parameters sized to the first input.
            final int dimensionality = input.getDimensionality();
            mean = VectorFactory.getDenseDefault().createVector(dimensionality);
            variance = VectorFactory.getDenseDefault().createVector(
                dimensionality, this.getDefaultVariance());
            target.setMean(mean);
            target.setVariance(variance);
        } else {
            mean = target.getMean();
            variance = target.getVariance();
        }

        final double actual = label ? 1.0 : -1.0;
        final double margin = actual * input.dotProduct(mean);
        final Vector varianceTimesInput = input.dotTimes(variance);
        final double marginVariance = input.dotProduct(varianceTimesInput);

        // Nothing to do when the example carries no variance (avoids dividing
        // by zero below) or already satisfies the margin constraint.
        if (marginVariance == 0.0 || margin > this.phi * marginVariance) {
            return;
        }

        // Closed-form step size. The discriminant equals
        // (1 - 2*phi*M)^2 + 8*phi^2*V, so it is never negative for finite
        // inputs.
        final double meanPart = 1.0 + 2.0 * this.phi * margin;
        final double variancePart = margin - this.phi * marginVariance;
        final double discriminant =
            meanPart * meanPart - 8.0 * this.phi * variancePart;
        final double alpha = (-meanPart + Math.sqrt(discriminant))
            / (4.0 * this.phi * marginVariance);

        // Written as !(alpha > 0) rather than alpha <= 0 so that a NaN step
        // size (e.g. from a NaN in the input) is rejected instead of being
        // written into the mean and variance.
        if (!(alpha > 0.0)) {
            return;
        }

        final Vector meanUpdate =
            (Vector) varianceTimesInput.scale(actual * alpha);
        mean.plusEquals(meanUpdate);

        // Increase the precision (inverse variance) of each dimension the
        // example touches. Zero entries are skipped: their increment is
        // exactly zero, and computing 1/(1/sigma) for them would only
        // introduce floating-point drift into untouched dimensions.
        final double twoAlphaPhi = 2.0 * alpha * this.phi;
        for (VectorEntry entry : input) {
            final double value = entry.getValue();
            if (value == 0.0) {
                continue;
            }
            final int index = entry.getIndex();
            final double precision =
                1.0 / variance.getElement(index) + twoAlphaPhi * value * value;
            variance.setElement(index, 1.0 / precision);
        }

        // Store the (possibly in-place modified) parameters back on the
        // target. NOTE(review): kept in case getMean()/getVariance() return
        // copies -- confirm against DiagonalConfidenceWeightedBinaryCategorizer.
        target.setMean(mean);
        target.setVariance(variance);
    }
}

