/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeLearner;
import gov.sandia.cognition.learning.algorithm.tree.AbstractDecisionTreeNode;
import gov.sandia.cognition.learning.algorithm.tree.CategorizationTree;
import gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeNode;
import gov.sandia.cognition.learning.algorithm.tree.DeciderLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;
import java.util.HashSet;

/**
 * Batch learner that builds a {@link CategorizationTree} by recursively
 * creating {@link CategorizationTreeNode}s from labeled examples.
 *
 * <p>Every node is labeled with the maximum value of the output histogram of
 * the data that reaches it. A node is split further, using the decider
 * learner held by the superclass, unless one of the leaf conditions checked
 * in {@link #learnNode(Collection, AbstractDecisionTreeNode)} applies.
 *
 * @param <InputType> the type of the inputs the tree categorizes
 * @param <OutputType> the type of the output categories
 */
public class CategorizationTreeLearner<InputType, OutputType>
extends AbstractDecisionTreeLearner<InputType, OutputType>
implements SupervisedBatchLearner<InputType, OutputType, CategorizationTree<InputType, OutputType>> {
    /** Default value for the leaf count threshold. */
    public static final int DEFAULT_LEAF_COUNT_THRESHOLD = 1;

    /** Default maximum depth; values that are not positive disable the depth limit. */
    public static final int DEFAULT_MAX_DEPTH = -1;

    /** A node whose data has at most this many examples becomes a leaf. */
    protected int leafCountThreshold;

    /** Maximum node depth at which splitting stops; only enforced when greater than zero. */
    protected int maxDepth;

    /**
     * Creates a learner with a {@code null} decider learner and the default
     * leaf count threshold and maximum depth.
     *
     * <p>NOTE(review): {@code learnNode} dereferences the decider learner
     * whenever a node is not a leaf, so one presumably has to be supplied
     * before learning on data that needs a split.
     */
    public CategorizationTreeLearner() {
        this(null);
    }

    /**
     * Creates a learner with the given decider learner and the default leaf
     * count threshold and maximum depth.
     *
     * @param deciderLearner the learner used to find the decision function at each interior node
     */
    public CategorizationTreeLearner(DeciderLearner<? super InputType, OutputType, ?, ?> deciderLearner) {
        this(deciderLearner, DEFAULT_LEAF_COUNT_THRESHOLD, DEFAULT_MAX_DEPTH);
    }

    /**
     * Creates a learner with the given decider learner and stopping parameters.
     *
     * @param deciderLearner the learner used to find the decision function at each interior node
     * @param leafCountThreshold a node with at most this many examples becomes a leaf;
     *        validated by {@link #setLeafCountThreshold(int)}
     * @param maxDepth the maximum depth of the tree; only enforced when greater than zero
     */
    public CategorizationTreeLearner(DeciderLearner<? super InputType, OutputType, ?, ?> deciderLearner, int leafCountThreshold, int maxDepth) {
        super(deciderLearner);
        this.setLeafCountThreshold(leafCountThreshold);
        this.setMaxDepth(maxDepth);
    }

    /**
     * Learns a categorization tree from the given labeled examples.
     *
     * <p>The tree's category set is a copy of the domain of the output
     * histogram of {@code data}. If {@code data} is empty, the root node
     * passed to the tree is {@code null} because no node is created for
     * empty data.
     *
     * @param data the input-output pairs to learn from
     * @return the learned tree, or {@code null} if {@code data} is {@code null}
     */
    @Override
    public CategorizationTree<InputType, OutputType> learn(Collection<? extends InputOutputPair<? extends InputType, OutputType>> data) {
        if (data == null) {
            return null;
        }
        final MapBasedDataHistogram<OutputType> rootCounts = CategorizationTreeLearner.getOutputCounts(data);
        return new CategorizationTree<InputType, OutputType>(
                this.learnNode(data, null),
                new HashSet<OutputType>(rootCounts.getDomain()));
    }

    /**
     * Recursively learns the node for the given data.
     *
     * <p>The node becomes a leaf when any of the following holds: all outputs
     * in {@code data} are equal, {@code data} has at most
     * {@code leafCountThreshold} examples, {@code maxDepth} is positive and
     * the node's depth has reached it, or the decider learner returns
     * {@code null}.
     *
     * @param data the examples that reach this node
     * @param parent the parent of the node to create; {@code learn} passes {@code null} for the root
     * @return the new node, or {@code null} if {@code data} is {@code null} or empty
     */
    @Override
    protected CategorizationTreeNode<InputType, OutputType, ?> learnNode(Collection<? extends InputOutputPair<? extends InputType, OutputType>> data, AbstractDecisionTreeNode<InputType, OutputType, ?> parent) {
        if (data == null || data.isEmpty()) {
            // No data reaches this position, so no node is created.
            return null;
        }

        // Every node, interior or leaf, carries the histogram's maximum value
        // as its output category.
        final OutputType mostCommonOutput = CategorizationTreeLearner.getOutputCounts(data).getMaximumValue();
        final CategorizationTreeNode<InputType, OutputType, Object> node =
                new CategorizationTreeNode<InputType, OutputType, Object>(parent, mostCommonOutput);

        final boolean isLeaf = this.areAllOutputsEqual(data)
                || data.size() <= this.leafCountThreshold
                || (this.maxDepth > 0 && node.getDepth() >= this.maxDepth);
        if (isLeaf) {
            return node;
        }

        final Categorizer<? super InputType, ? extends Object> decider = this.getDeciderLearner().learn(data);
        if (decider != null) {
            node.setDecider(decider);
            super.learnChildNodes(node, data, decider);
        }
        // A null decider means no split was found: the node is returned
        // without a decider or children, i.e. as a leaf.
        return node;
    }

    /**
     * Builds a histogram of the output values in the given data.
     *
     * @param <OutputType> the type of the outputs being counted
     * @param data the input-output pairs whose outputs are counted; may be {@code null}
     * @return a new histogram containing the output of every example; empty
     *         if {@code data} is {@code null} or has no elements
     */
    public static <OutputType> MapBasedDataHistogram<OutputType> getOutputCounts(Collection<? extends InputOutputPair<?, OutputType>> data) {
        final MapBasedDataHistogram<OutputType> counts = new MapBasedDataHistogram<OutputType>();
        if (data == null) {
            return counts;
        }
        for (InputOutputPair<?, OutputType> example : data) {
            counts.add(example.getOutput());
        }
        return counts;
    }

    /**
     * Gets the leaf count threshold.
     *
     * @return the number of examples at or below which a node becomes a leaf
     */
    public int getLeafCountThreshold() {
        return this.leafCountThreshold;
    }

    /**
     * Sets the leaf count threshold.
     *
     * <p>The value is checked with {@code ArgumentChecker.assertIsNonNegative};
     * a negative value is presumably rejected with an exception (confirm
     * against {@code ArgumentChecker}).
     *
     * @param leafCountThreshold the number of examples at or below which a node becomes a leaf
     */
    public void setLeafCountThreshold(int leafCountThreshold) {
        ArgumentChecker.assertIsNonNegative("leafCountThreshold", leafCountThreshold);
        this.leafCountThreshold = leafCountThreshold;
    }

    /**
     * Gets the maximum depth.
     *
     * @return the maximum depth; values that are not positive mean no depth limit is enforced
     */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /**
     * Sets the maximum depth. No validation is performed.
     *
     * @param maxDepth the maximum depth; only enforced while learning when greater than zero
     */
    public void setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }
}

