/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Robust regression learner that repeatedly trains a wrapped "iteration
 * learner" on a re-weighted copy of the training data.
 * <p>
 * Each step does two things:
 * <ol>
 *   <li>It trains the iteration learner on the current weighted data.</li>
 *   <li>It replaces every example's weight with the kernel weighting function
 *       evaluated between the example's actual output and the new model's
 *       prediction for it.</li>
 * </ol>
 * Stepping continues while the mean absolute change in the weights exceeds
 * the tolerance. The maximum iteration count is handed to the superclass,
 * which presumably bounds the number of steps.
 *
 * @param <InputType> type of the inputs to the learned function
 * @param <OutputType> type of the outputs of the learned function
 */
public class KernelWeightedRobustRegression<InputType, OutputType>
extends AbstractAnytimeSupervisedBatchLearner<InputType, OutputType, Evaluator<? super InputType, ? extends OutputType>> {
    /** Function learned by the most recent step (null before learning). */
    private Evaluator<? super InputType, ? extends OutputType> result;
    /** Learner invoked on the weighted data at every step. */
    private SupervisedBatchLearner<InputType, OutputType, ?> iterationLearner;
    /** Kernel mapping (actual output, predicted output) to a new example weight. */
    private Kernel<? super OutputType> kernelWeightingFunction;
    /** Mean absolute weight change at or below which stepping stops; always > 0. */
    private double tolerance;
    /** Default maximum number of iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;
    /** Default stopping tolerance on the mean absolute weight change. */
    public static final double DEFAULT_TOLERANCE = 1.0E-5;
    /** Mutable weighted copy of the training data, rebuilt by initializeAlgorithm. */
    private ArrayList<DefaultWeightedInputOutputPair<InputType, OutputType>> weightedData;

    /**
     * Creates a learner using {@link #DEFAULT_MAX_ITERATIONS} and
     * {@link #DEFAULT_TOLERANCE}.
     *
     * @param iterationLearner learner trained on the weighted data each step
     * @param kernelWeightingFunction kernel producing new example weights
     */
    public KernelWeightedRobustRegression(SupervisedBatchLearner<InputType, OutputType, ?> iterationLearner, Kernel<? super OutputType> kernelWeightingFunction) {
        this(iterationLearner, kernelWeightingFunction, DEFAULT_MAX_ITERATIONS, DEFAULT_TOLERANCE);
    }

    /**
     * Creates a fully specified learner.
     *
     * @param iterationLearner learner trained on the weighted data each step
     * @param kernelWeightingFunction kernel producing new example weights
     * @param maxIterations maximum iteration count, passed to the superclass
     * @param tolerance stopping tolerance on the mean absolute weight change
     * @throws IllegalArgumentException if {@code tolerance <= 0.0}
     */
    public KernelWeightedRobustRegression(SupervisedBatchLearner<InputType, OutputType, ?> iterationLearner, Kernel<? super OutputType> kernelWeightingFunction, int maxIterations, double tolerance) {
        super(maxIterations);
        this.setLearned(null);
        this.setTolerance(tolerance);
        this.setKernelWeightingFunction(kernelWeightingFunction);
        this.setIterationLearner(iterationLearner);
    }

    /**
     * Builds the weighted working copy of the training data.
     * <p>
     * Each pair's starting weight comes from {@link DatasetUtil#getWeight}.
     * Presumably that is the pair's own weight, or a default for unweighted
     * pairs; verify against DatasetUtil.
     *
     * @return always {@code true}
     */
    @Override
    protected boolean initializeAlgorithm() {
        this.weightedData = new ArrayList<DefaultWeightedInputOutputPair<InputType, OutputType>>(this.data.size());
        for (InputOutputPair<? extends InputType, ? extends OutputType> pair : this.data) {
            double weight = DatasetUtil.getWeight(pair);
            this.weightedData.add(new DefaultWeightedInputOutputPair<InputType, OutputType>(pair.getInput(), pair.getOutput(), weight));
        }
        return true;
    }

    /**
     * Trains the iteration learner on the current weights, then re-weights
     * the examples using the freshly learned function.
     *
     * @return {@code true} to keep iterating, i.e. while the mean absolute
     *         weight change exceeds the tolerance
     */
    @Override
    protected boolean step() {
        this.result = this.iterationLearner.learn(this.weightedData);
        double change = this.updateWeights(this.result);
        return change > this.tolerance;
    }

    /** No cleanup is needed; the learned function is held in {@code result}. */
    @Override
    protected void cleanupAlgorithm() {
    }

    /**
     * Sets each example's weight to the kernel of its actual and predicted
     * outputs.
     *
     * @param f function used to predict each example's output
     * @return mean absolute change in weight across all examples, or 0.0 when
     *         there are no examples (avoids a 0/0 NaN)
     */
    private double updateWeights(Evaluator<? super InputType, ? extends OutputType> f) {
        if (this.weightedData.isEmpty()) {
            return 0.0;
        }
        double change = 0.0;
        for (DefaultWeightedInputOutputPair<InputType, OutputType> pair : this.weightedData) {
            OutputType yhat = f.evaluate(pair.getInput());
            double weightNew = this.kernelWeightingFunction.evaluate(pair.getOutput(), yhat);
            change += Math.abs(weightNew - pair.getWeight());
            pair.setWeight(weightNew);
        }
        return change / this.weightedData.size();
    }

    /**
     * @return kernel used to turn (actual, predicted) outputs into weights
     */
    public Kernel<? super OutputType> getKernelWeightingFunction() {
        return this.kernelWeightingFunction;
    }

    /**
     * @param kernelWeightingFunction kernel used to turn (actual, predicted)
     *        outputs into weights
     */
    public void setKernelWeightingFunction(Kernel<? super OutputType> kernelWeightingFunction) {
        this.kernelWeightingFunction = kernelWeightingFunction;
    }

    /**
     * @return stopping tolerance on the mean absolute weight change
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * @param tolerance stopping tolerance on the mean absolute weight change
     * @throws IllegalArgumentException if {@code tolerance <= 0.0}
     */
    public void setTolerance(double tolerance) {
        if (tolerance <= 0.0) {
            throw new IllegalArgumentException("Tolerance must be > 0.0");
        }
        this.tolerance = tolerance;
    }

    /**
     * Overrides the currently held learned function.
     *
     * @param result learned function to store (may be null)
     */
    public void setLearned(Evaluator<InputType, OutputType> result) {
        this.result = result;
    }

    /**
     * @return function learned by the most recent step, or null if none
     */
    @Override
    public Evaluator<? super InputType, ? extends OutputType> getResult() {
        return this.result;
    }

    /**
     * @return learner trained on the weighted data at every step
     */
    public SupervisedBatchLearner<InputType, OutputType, ?> getIterationLearner() {
        return this.iterationLearner;
    }

    /**
     * @param iterationLearner learner trained on the weighted data at every step
     */
    public void setIterationLearner(SupervisedBatchLearner<InputType, OutputType, ?> iterationLearner) {
        this.iterationLearner = iterationLearner;
    }
}

