/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.AbstractDiscriminantCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * A categorizer that runs a wrapped evaluator to obtain a vector of scores, one
 * per category, and returns the category whose score is largest ("winner take
 * all"). Element {@code i} of the evaluator's output vector is interpreted as
 * the score of the {@code i}-th category in the iteration order of the category
 * set.
 *
 * @param <InputType> The type of input to categorize.
 * @param <CategoryType> The type of category produced.
 */
public class WinnerTakeAllCategorizer<InputType, CategoryType>
extends AbstractDiscriminantCategorizer<InputType, CategoryType, Double> {
    /**
     * The evaluator producing the per-category score vector. It is {@code null}
     * when the no-argument constructor is used, until {@link #setEvaluator} is
     * called.
     */
    protected Evaluator<? super InputType, ? extends Vectorizable> evaluator;

    /**
     * Creates a categorizer with no evaluator and an empty, insertion-ordered
     * category set.
     */
    public WinnerTakeAllCategorizer() {
        this(null, new LinkedHashSet<CategoryType>());
    }

    /**
     * Creates a categorizer from an evaluator and a set of categories.
     *
     * @param evaluator The evaluator producing one score per category.
     * @param categories The categories. Their iteration order defines which
     *        element of the evaluator's output belongs to which category.
     */
    public WinnerTakeAllCategorizer(Evaluator<? super InputType, ? extends Vectorizable> evaluator, Set<CategoryType> categories) {
        super(categories);
        this.setEvaluator(evaluator);
    }

    /**
     * Evaluates the input and returns the highest-scoring category together
     * with its score.
     *
     * @param input The input to categorize.
     * @return The winning category and its score.
     */
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(InputType input) {
        final Vector output = this.evaluator.evaluate(input).convertToVector();
        return this.findBestCategory(output);
    }

    /**
     * Finds the category with the largest score in the given output vector.
     * Element {@code i} of {@code output} is the score of the {@code i}-th
     * category in iteration order.
     *
     * <p>Ties are broken in favor of the earliest category, because a later
     * score must be strictly greater to replace the current best. If there are
     * no categories, the result is a {@code null} category with score
     * {@link Double#NEGATIVE_INFINITY}.
     *
     * @param output The per-category score vector. Its dimensionality must
     *        equal the number of categories.
     * @return The winning category and its score.
     */
    public DefaultWeightedValueDiscriminant<CategoryType> findBestCategory(Vector output) {
        output.assertDimensionalityEquals(this.categories.size());
        CategoryType best = null;
        double bestValue = Double.NEGATIVE_INFINITY;
        int index = 0;
        for (CategoryType category : this.categories) {
            final double value = output.getElement(index);
            // Seed from the first element by position rather than "best == null",
            // so a null category is not overwritten by a later, lower score.
            if (index == 0 || value > bestValue) {
                best = category;
                bestValue = value;
            }
            ++index;
        }
        return new DefaultWeightedValueDiscriminant<CategoryType>(best, bestValue);
    }

    /**
     * Gets the evaluator that produces the per-category score vector.
     *
     * @return The evaluator; may be {@code null}.
     */
    public Evaluator<? super InputType, ? extends Vectorizable> getEvaluator() {
        return this.evaluator;
    }

    /**
     * Sets the evaluator that produces the per-category score vector.
     *
     * @param evaluator The evaluator.
     */
    public void setEvaluator(Evaluator<? super InputType, ? extends Vectorizable> evaluator) {
        this.evaluator = evaluator;
    }

    /**
     * Sets the categories. Their iteration order determines which element of
     * the evaluator's output vector corresponds to which category.
     *
     * @param categories The categories.
     */
    @Override
    public void setCategories(Set<CategoryType> categories) {
        super.setCategories(categories);
    }

    /**
     * Learns a {@link WinnerTakeAllCategorizer} by encoding each categorical
     * label as a target vector and delegating to a vector-output batch learner.
     * Each label becomes a vector filled with {@code -1.0}, except for a
     * {@code 1.0} at that label's index. Indices are assigned in first-seen
     * order over the training data.
     *
     * @param <InputType> The type of input.
     * @param <CategoryType> The type of category label.
     */
    public static class Learner<InputType, CategoryType>
    extends AbstractBatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Vector>>, ? extends Evaluator<? super InputType, ? extends Vectorizable>>>
    implements SupervisedBatchLearner<InputType, CategoryType, WinnerTakeAllCategorizer<InputType, CategoryType>>,
    VectorFactoryContainer {
        /** The factory used to create the encoded target vectors. */
        protected VectorFactory<?> vectorFactory;

        /**
         * Creates a learner with no wrapped learner and the default vector
         * factory.
         */
        public Learner() {
            this(null);
        }

        /**
         * Creates a learner that wraps the given vector-output learner and
         * uses the default vector factory.
         *
         * @param learner The learner trained on the vector-encoded data.
         */
        public Learner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Vector>>, Evaluator<? super InputType, ? extends Vectorizable>> learner) {
            super(learner);
            this.setVectorFactory(VectorFactory.getDefault());
        }

        /**
         * Learns a winner-take-all categorizer from labeled data.
         *
         * @param data The labeled training examples.
         * @return A categorizer wrapping the learned evaluator. Its categories
         *         are in first-seen order over {@code data}.
         */
        @Override
        public WinnerTakeAllCategorizer<InputType, CategoryType> learn(Collection<? extends InputOutputPair<? extends InputType, CategoryType>> data) {
            // Assign each distinct category an index in first-seen order.
            final LinkedHashMap<CategoryType, Integer> categoryIndices =
                new LinkedHashMap<CategoryType, Integer>();
            for (InputOutputPair<? extends InputType, ? extends CategoryType> example : data) {
                final CategoryType category = example.getOutput();
                if (!categoryIndices.containsKey(category)) {
                    categoryIndices.put(category, categoryIndices.size());
                }
            }

            // Encode each label as a -1/+1 target vector. The category count is
            // needed first, which is why this is a separate pass over the data.
            final int categoryCount = categoryIndices.size();
            final ArrayList<InputOutputPair<InputType, Vector>> vectorData =
                new ArrayList<InputOutputPair<InputType, Vector>>(data.size());
            for (InputOutputPair<? extends InputType, ? extends CategoryType> example : data) {
                final int index = categoryIndices.get(example.getOutput());
                final Vector output = this.getVectorFactory().createVector(categoryCount, -1.0);
                output.setElement(index, 1.0);
                vectorData.add(new DefaultInputOutputPair<InputType, Vector>(example.getInput(), output));
            }

            final Evaluator<? super InputType, ? extends Vectorizable> learned =
                this.getLearner().learn(vectorData);
            final LinkedHashSet<CategoryType> categories =
                new LinkedHashSet<CategoryType>(categoryIndices.keySet());
            return new WinnerTakeAllCategorizer<InputType, CategoryType>(learned, categories);
        }

        /**
         * Gets the factory used to create the encoded target vectors.
         *
         * @return The vector factory.
         */
        public VectorFactory<?> getVectorFactory() {
            return this.vectorFactory;
        }

        /**
         * Sets the factory used to create the encoded target vectors.
         *
         * @param vectorFactory The vector factory.
         */
        public void setVectorFactory(VectorFactory<?> vectorFactory) {
            this.vectorFactory = vectorFactory;
        }
    }
}

