/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.CompositeEvaluatorPair;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant;
import gov.sandia.cognition.learning.function.scalar.SigmoidFunction;
import gov.sandia.cognition.math.matrix.DiagonalMatrix;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Fits a logistic regression model, {@code sigmoid(w . x)}, by Newton-Raphson /
 * iteratively reweighted least squares (IRLS).
 * <p>
 * Each {@link #step()} solves {@code (X W R X^T) delta = X W (y - yhat)} and moves the
 * parameters by {@code delta}. The terms are:
 * <ul>
 *   <li>{@code X} holds the training inputs column-wise.</li>
 *   <li>{@code W} is the diagonal matrix of per-sample weights from
 *       {@link DatasetUtil#getWeight}.</li>
 *   <li>{@code R} is the diagonal of {@code yhat * (1 - yhat)}.</li>
 * </ul>
 * Iteration continues while the L2 norm of {@code delta} exceeds {@link #getTolerance()},
 * up to the {@code maxIterations} limit handed to the superclass.
 * <p>
 * The model's parameter vector has the same dimensionality as the inputs, so to fit an
 * intercept, append a constant feature to every input. Targets are presumably 0/1 labels
 * (or probabilities) — TODO confirm; the code itself only requires {@code Double} outputs.
 */
@PublicationReferences(references={@PublicationReference(author={"Tommi S. Jaakkola"}, title="Machine learning: lecture 5", type=PublicationType.WebPage, year=2004, url="http://www.ai.mit.edu/courses/6.867-f04/lectures/lecture-5-ho.pdf", notes={"Good formulation of logistic regression on slides 15-20"}), @PublicationReference(author={"Paul Komarek", "Andrew Moore"}, title="Making Logistic Regression A Core Data Mining Tool With TR-IRLS", publication="Proceedings of the 5th International Conference on Data Mining Machine Learning", type=PublicationType.Conference, year=2005, url="http://www.autonlab.org/autonweb/14717.html", notes={"Good practical overview of logistic regression"}), @PublicationReference(author={"Christopher M. Bishop"}, title="Pattern Recognition and Machine Learning", type=PublicationType.Book, year=2006, pages={207, 208}, notes={"Section 4.3.3"})})
public class LogisticRegression
extends AbstractAnytimeSupervisedBatchLearner<Vector, Double, Function> {
    /** Default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Default tolerance on the L2 norm of one parameter update, {@value}. */
    public static final double DEFAULT_TOLERANCE = 1.0E-10;

    /** Starting model; when null, a zero-weight model is created at initialization. */
    private Function objectToOptimize;

    /** Model refined by learning; a clone of {@link #objectToOptimize} at start. */
    private Function result;

    /** Learning stops once a parameter update's L2 norm is no longer above this value. */
    private double tolerance;

    /** Diagonal (N x N) matrix of per-sample weights; fixed for one learning run. */
    private transient DiagonalMatrix W;

    /** Diagonal (N x N) matrix of {@code yhat * (1 - yhat)}; recomputed every step. */
    private transient DiagonalMatrix R;

    /** Training inputs stored column-wise (M x N). */
    private transient Matrix X;

    /** Transpose of {@link #X} (N x M), cached for the per-step system matrix. */
    private transient Matrix Xt;

    /** Per-sample residuals {@code y - yhat} (length N); recomputed every step. */
    private transient Vector err;

    /**
     * Creates a learner with {@link #DEFAULT_TOLERANCE} and
     * {@link #DEFAULT_MAX_ITERATIONS}.
     */
    public LogisticRegression() {
        this(DEFAULT_TOLERANCE);
    }

    /**
     * Creates a learner with the given tolerance and {@link #DEFAULT_MAX_ITERATIONS}.
     *
     * @param tolerance stopping threshold on the L2 norm of a parameter update
     */
    public LogisticRegression(double tolerance) {
        this(tolerance, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a learner.
     *
     * @param tolerance stopping threshold on the L2 norm of a parameter update
     * @param maxIterations iteration limit, passed through to the superclass
     */
    public LogisticRegression(double tolerance, int maxIterations) {
        super(maxIterations);
        this.setTolerance(tolerance);
    }

    /**
     * Returns a copy in which the starting model and the result are cloned with
     * {@link ObjectUtil#cloneSafe}. All other fields are copied by {@code super.clone()}.
     *
     * @return the clone
     */
    @Override
    public LogisticRegression clone() {
        LogisticRegression clone = (LogisticRegression)super.clone();
        clone.setObjectToOptimize(ObjectUtil.cloneSafe(this.getObjectToOptimize()));
        clone.setResult(ObjectUtil.cloneSafe(this.getResult()));
        return clone;
    }

    /**
     * Returns the training data with its element type spelled out. This is the single
     * place where the (decompiled) generic type of {@code data} is asserted, so the loops
     * below can use typed samples instead of raw casts.
     *
     * @return the training data, possibly null if none has been set
     */
    @SuppressWarnings("unchecked")
    private Collection<? extends InputOutputPair<? extends Vector, Double>> typedData() {
        return (Collection<? extends InputOutputPair<? extends Vector, Double>>) this.data;
    }

    /**
     * Sizes the model from the first sample's input dimensionality and fills the
     * working matrices {@code X}, {@code Xt}, {@code W}, {@code R}, and {@code err}.
     *
     * @return true to start iterating; false (with a null result) when there is no data
     */
    @Override
    protected boolean initializeAlgorithm() {
        final Collection<? extends InputOutputPair<? extends Vector, Double>> samples =
            this.typedData();
        if (samples == null || samples.isEmpty()) {
            // Nothing to fit and no dimensionality to size a model from: decline to
            // iterate rather than throwing NoSuchElementException from iterator().next().
            this.setResult(null);
            return false;
        }

        final int dimensionality = samples.iterator().next().getInput().getDimensionality();
        final int count = samples.size();
        if (this.getObjectToOptimize() == null) {
            this.setObjectToOptimize(new Function(dimensionality));
        }
        // Learn on a copy so the caller's starting model is left untouched.
        this.setResult(this.getObjectToOptimize().clone());

        this.R = MatrixFactory.getDiagonalDefault().createMatrix(count, count);
        this.X = MatrixFactory.getDefault().createMatrix(dimensionality, count);
        this.err = VectorFactory.getDefault().createVector(count);
        this.W = MatrixFactory.getDiagonalDefault().createMatrix(count, count);
        int n = 0;
        for (InputOutputPair<? extends Vector, Double> sample : samples) {
            this.X.setColumn(n, sample.getInput());
            this.W.setElement(n, DatasetUtil.getWeight(sample));
            ++n;
        }
        this.Xt = this.X.transpose();
        return true;
    }

    /**
     * Performs one Newton-Raphson (IRLS) update of the result's parameters.
     * <p>
     * The update is {@code w <- w + (X W R X^T)^-1 X W (y - yhat)}. This is algebraically
     * the textbook IRLS solve with working response {@code z = X^T w + R^-1 (y - yhat)},
     * but it never forms {@code R^-1}. That keeps it finite when a prediction saturates
     * to exactly 0 or 1 (giving {@code R_nn = 0}). If every prediction saturates, the
     * system matrix may be singular; the outcome then depends on {@code Matrix.solve}.
     *
     * @return true while the L2 norm of the update exceeds the tolerance
     */
    @Override
    protected boolean step() {
        final Function f = this.getResult();
        int n = 0;
        for (InputOutputPair<? extends Vector, Double> sample : this.typedData()) {
            final double yhat = f.evaluate(sample.getInput());
            final double y = sample.getOutput();
            this.err.setElement(n, y - yhat);
            // Variance of the Bernoulli prediction: the IRLS reweighting term.
            this.R.setElement(n, yhat * (1.0 - yhat));
            ++n;
        }

        final Vector w = f.convertToVector();
        final Matrix hessian = this.X.times(this.W.times(this.R.times(this.Xt)));
        final Vector gradient = this.X.times(this.W.times(this.err));
        final Vector delta = hessian.solve(gradient);
        f.convertFromVector(w.plus(delta));
        // Use the solved step directly. Comparing against w after convertFromVector
        // would read zero if convertToVector returned a live, in-place-updated vector.
        return delta.norm2() > this.getTolerance();
    }

    /** Releases the per-run working matrices and vectors. */
    @Override
    protected void cleanupAlgorithm() {
        this.X = null;
        this.Xt = null;
        this.err = null;
        this.R = null;
        this.W = null;
    }

    /**
     * Gets the starting model.
     *
     * @return the starting model, or null to have one created from the data
     */
    public Function getObjectToOptimize() {
        return this.objectToOptimize;
    }

    /**
     * Sets the starting model. It is cloned, not modified, when learning begins.
     *
     * @param objectToOptimize the starting model, or null for a zero-weight model
     */
    public void setObjectToOptimize(Function objectToOptimize) {
        this.objectToOptimize = objectToOptimize;
    }

    /**
     * Gets the learned model.
     *
     * @return the current result; null if learning found no data
     */
    @Override
    public Function getResult() {
        return this.result;
    }

    /**
     * Sets the learned model.
     *
     * @param result the new result
     */
    public void setResult(Function result) {
        this.result = result;
    }

    /**
     * Gets the stopping tolerance.
     *
     * @return threshold on the L2 norm of a parameter update
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Sets the stopping tolerance. The value is not validated.
     *
     * @param tolerance threshold on the L2 norm of a parameter update
     */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * Logistic model: a {@link LinearDiscriminant} followed by a {@link SigmoidFunction}.
     * Its parameter vector is the discriminant's weight vector.
     */
    public static class Function
    extends CompositeEvaluatorPair<Vector, Double, Double>
    implements Vectorizable {
        /**
         * Creates a model with a zero weight vector of the given dimensionality.
         *
         * @param dimensionality number of input features
         */
        public Function(int dimensionality) {
            super(new LinearDiscriminant(VectorFactory.getDefault().createVector(dimensionality)), new SigmoidFunction());
        }

        /** {@inheritDoc} */
        @Override
        public Function clone() {
            return (Function)super.clone();
        }

        /**
         * Returns the discriminant's weight vector, as given by
         * {@link LinearDiscriminant#convertToVector()}.
         *
         * @return the parameters
         */
        @Override
        public Vector convertToVector() {
            return ((LinearDiscriminant)this.getFirst()).convertToVector();
        }

        /**
         * Sets the discriminant's weight vector.
         *
         * @param parameters new weights, same dimensionality as the inputs
         */
        @Override
        public void convertFromVector(Vector parameters) {
            ((LinearDiscriminant)this.getFirst()).convertFromVector(parameters);
        }
    }
}

