/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.ensemble.Ensemble;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.function.categorization.AbstractCategorizer;
import gov.sandia.cognition.learning.function.categorization.DiscriminantCategorizer;
import gov.sandia.cognition.statistics.distribution.MapBasedPointMassDistribution;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * A categorizer ensemble in which every member casts a weighted vote for the
 * category it outputs for an input. The votes are accumulated into a
 * {@link MapBasedPointMassDistribution}, and the ensemble's answer is taken
 * from that distribution.
 *
 * @param <InputType> the type of input passed to each member
 * @param <CategoryType> the type of category produced by the members
 * @param <MemberType> the type of evaluator stored as an ensemble member
 */
public class WeightedVotingCategorizerEnsemble<InputType, CategoryType, MemberType extends Evaluator<? super InputType, ? extends CategoryType>>
extends AbstractCategorizer<InputType, CategoryType>
implements Ensemble<WeightedValue<MemberType>>,
DiscriminantCategorizer<InputType, CategoryType, Double> {
    /** The weight assigned to a member added through {@code add(categorizer)}. */
    public static final double DEFAULT_WEIGHT = 1.0;

    /** The members of the ensemble, each paired with the weight of its vote. */
    protected List<WeightedValue<MemberType>> members;

    /**
     * Creates an ensemble with an empty category set and an empty member list.
     */
    public WeightedVotingCategorizerEnsemble() {
        this(new HashSet<CategoryType>());
    }

    /**
     * Creates an ensemble over the given categories with an empty member list.
     *
     * @param categories the set of categories the ensemble can output
     */
    public WeightedVotingCategorizerEnsemble(Set<CategoryType> categories) {
        this(categories, new ArrayList<WeightedValue<MemberType>>());
    }

    /**
     * Creates an ensemble over the given categories with the given members.
     * The member list is stored as-is (not copied).
     *
     * @param categories the set of categories the ensemble can output
     * @param members the weighted members of the ensemble
     */
    public WeightedVotingCategorizerEnsemble(Set<CategoryType> categories, List<WeightedValue<MemberType>> members) {
        super(categories);
        this.setMembers(members);
    }

    /**
     * Adds a member with weight {@link #DEFAULT_WEIGHT}.
     *
     * @param categorizer the member to add; must not be null
     * @throws IllegalArgumentException if {@code categorizer} is null
     */
    public void add(MemberType categorizer) {
        this.add(categorizer, DEFAULT_WEIGHT);
    }

    /**
     * Adds a member whose vote counts with the given weight.
     *
     * @param categorizer the member to add; must not be null
     * @param weight the weight of the member's vote; must not be negative
     * @throws IllegalArgumentException if {@code categorizer} is null or
     *         {@code weight} is negative
     */
    public void add(MemberType categorizer, double weight) {
        if (categorizer == null) {
            throw new IllegalArgumentException("categorizer cannot be null");
        }
        if (weight < 0.0) {
            throw new IllegalArgumentException("weight cannot be negative");
        }
        DefaultWeightedValue<MemberType> weighted = new DefaultWeightedValue<MemberType>(categorizer, weight);
        this.getMembers().add(weighted);
    }

    /**
     * Evaluates the input by collecting the members' weighted votes and
     * returning the vote distribution's maximum value (presumably the
     * category with the largest total weight).
     *
     * @param input the input to categorize
     * @return the winning category, as reported by
     *         {@code getMaximumValue()} on the vote distribution
     */
    @Override
    public CategoryType evaluate(InputType input) {
        MapBasedPointMassDistribution<CategoryType> votes = this.evaluateAsVotes(input);
        return votes.getMaximumValue();
    }

    /**
     * Evaluates the input and pairs the winning category with its fraction of
     * the vote.
     *
     * @param input the input to categorize
     * @return the winning category together with its vote fraction (from
     *         {@code getFraction}), or null when the vote distribution
     *         yields no winning category
     */
    @Override
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(InputType input) {
        MapBasedPointMassDistribution<CategoryType> votes = this.evaluateAsVotes(input);
        CategoryType bestCategory = votes.getMaximumValue();
        if (bestCategory == null) {
            return null;
        }
        double bestVotePercentage = votes.getFraction(bestCategory);
        return DefaultWeightedValueDiscriminant.create(bestCategory, bestVotePercentage);
    }

    /**
     * Has every member evaluate the input. The category each member outputs
     * receives that member's weight in the returned distribution. Members
     * that return null cast no vote.
     *
     * @param input the input each member evaluates
     * @return a fresh distribution of weighted votes per category
     */
    public MapBasedPointMassDistribution<CategoryType> evaluateAsVotes(InputType input) {
        // Sized by the known category count; it is only an initial-capacity hint.
        MapBasedPointMassDistribution<CategoryType> votes =
            new MapBasedPointMassDistribution<CategoryType>(this.getCategories().size());
        for (WeightedValue<MemberType> member : this.getMembers()) {
            // MemberType's bound (Evaluator<? super InputType, ? extends CategoryType>)
            // makes this call type-safe without the unchecked raw cast.
            CategoryType category = member.getValue().evaluate(input);
            if (category != null) {
                votes.add(category, member.getWeight());
            }
        }
        return votes;
    }

    /**
     * Returns the live member list; mutating it changes the ensemble.
     *
     * @return the weighted members of the ensemble
     */
    @Override
    public List<WeightedValue<MemberType>> getMembers() {
        return this.members;
    }

    /**
     * Replaces the member list. The given list is stored as-is (not copied).
     *
     * @param members the new weighted members of the ensemble
     */
    public void setMembers(List<WeightedValue<MemberType>> members) {
        this.members = members;
    }
}

