package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.DistributionWeightedEstimator;
import gov.sandia.cognition.statistics.PointMassDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:gov/sandia/cognition/statistics/distribution/MapBasedPointMassDistribution.class */
/**
 * A point mass distribution over arbitrary values, backed by an insertion-ordered
 * {@link LinkedHashMap} from each value to an {@link Entry} holding its mass. The sum of all
 * masses is cached in {@code totalMass} and updated incrementally by every mutator.
 *
 * <p>{@link #add(Object, double)}, {@link #setMass(Object, double)} and
 * {@link #remove(Object, double)} only keep a value in the map while its mass is positive.
 * {@link #scaleEquals(double)} with a factor of zero is the one operation that can leave
 * zero-mass entries behind.
 *
 * <p>NOTE(review): no synchronization is performed; treat instances as not thread-safe.
 *
 * @param <DataType> the type of the values that mass is placed on
 */
public class MapBasedPointMassDistribution<DataType> extends AbstractDistribution<DataType>
        implements PointMassDistribution<DataType> {

    /** Cached sum of the masses of all entries in {@link #dataMap}. */
    private double totalMass;

    /** Map from each value in the domain to the mutable holder of its mass. */
    private Map<DataType, Entry> dataMap;

    /**
     * Mutable holder for the mass of a single value, so the mass can be updated in place without
     * re-inserting into the map. (Decompiler note: originally declared {@code protected}.)
     */
    public static class Entry extends AbstractCloneableSerializable {

        /** The mass assigned to the value. */
        protected double mass;

        /** Creates an entry with zero mass. */
        protected Entry() {
            this(0.0);
        }

        /**
         * Creates an entry with the given mass.
         *
         * @param mass the initial mass
         */
        protected Entry(final double mass) {
            this.mass = mass;
        }

        /**
         * Copy constructor.
         *
         * @param other the entry whose mass is copied
         */
        protected Entry(final Entry other) {
            this(other.mass);
        }

        /**
         * Gets the mass.
         *
         * @return the mass held by this entry
         */
        protected double getMass() {
            return this.mass;
        }

        /**
         * Sets the mass.
         *
         * @param mass the new mass
         */
        protected void setMass(final double mass) {
            this.mass = mass;
        }
    }

    /**
     * Estimates a {@link PMF} by adding each weighted value's weight as mass on that value.
     *
     * @param <DataType> the type of the values being estimated over
     */
    public static class Learner<DataType> extends AbstractCloneableSerializable
            implements DistributionWeightedEstimator<DataType, PMF<DataType>> {

        /**
         * Builds a PMF from weighted data.
         *
         * @param data the weighted values; each weight is added as mass on its value
         * @return a new PMF holding the accumulated masses
         * @throws IllegalArgumentException if any weight is negative or NaN
         */
        @Override
        public PMF<DataType> learn(
                final Collection<? extends WeightedValue<? extends DataType>> data) {
            final PMF<DataType> result = new PMF<>();
            for (WeightedValue<? extends DataType> weighted : data) {
                result.add(weighted.getValue(), weighted.getWeight());
            }
            return result;
        }
    }

    /**
     * A map-based point mass distribution that is also its own probability mass function. The
     * probability of a value is its mass divided by the total mass.
     *
     * @param <DataType> the type of the values that mass is placed on
     */
    public static class PMF<DataType> extends MapBasedPointMassDistribution<DataType>
            implements PointMassDistribution.PMF<DataType> {

        /** Creates an empty PMF. */
        public PMF() {
            super();
        }

        /**
         * Creates an empty PMF.
         *
         * @param initialCapacity initial capacity of the backing map
         */
        public PMF(final int initialCapacity) {
            super(initialCapacity);
        }

        /**
         * Creates a PMF holding the same masses as another distribution.
         *
         * @param other the distribution to copy
         */
        public PMF(final MapBasedPointMassDistribution<DataType> other) {
            super(other);
        }

        /* renamed from: clone */
        @Override
        public PMF<DataType> mo539clone() {
            return (PMF<DataType>) super.mo539clone();
        }

        /**
         * Computes the entropy of this PMF.
         *
         * @return the value computed by {@code ProbabilityMassFunctionUtil.getEntropy(this)}
         */
        @Override
        public double getEntropy() {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }

        /**
         * Gets the probability of a value.
         *
         * @param input the value to evaluate
         * @return the value's mass divided by the total mass, or 0 when the total mass is not
         *     positive
         */
        @Override
        public Double evaluate(final DataType input) {
            final double total = getTotalMass();
            return total > 0.0 ? getMass(input) / total : 0.0;
        }

        /**
         * Gets the natural log of the probability of a value.
         *
         * @param input the value to evaluate
         * @return {@code Math.log(evaluate(input))}; negative infinity for zero-probability values
         */
        @Override
        public double logEvaluate(final DataType input) {
            // Evaluate the value itself. The decompiled code wrongly cast the value to PMF.
            return Math.log(evaluate(input).doubleValue());
        }

        /**
         * Returns this object, since it already is a probability function.
         *
         * @return this
         */
        @Override
        public PMF<DataType> getProbabilityFunction() {
            return this;
        }

        // The synthetic bridge method evaluate(Object) emitted by the decompiler is intentionally
        // omitted. It clashes with evaluate(DataType) under erasure, and javac generates the
        // bridge automatically.
    }

    /** Creates an empty distribution. */
    public MapBasedPointMassDistribution() {
        setDataMap(new LinkedHashMap<>());
        setTotalMass(0.0);
    }

    /**
     * Creates an empty distribution.
     *
     * @param initialCapacity initial capacity of the backing map
     */
    public MapBasedPointMassDistribution(final int initialCapacity) {
        setDataMap(new LinkedHashMap<>(initialCapacity));
        setTotalMass(0.0);
    }

    /**
     * Creates a distribution that places a mass of 1 on each element of the collection.
     * Duplicate elements accumulate mass.
     *
     * @param data the values to add
     */
    public MapBasedPointMassDistribution(final Collection<? extends DataType> data) {
        this(CollectionUtil.size((Collection<?>) data));
        for (DataType value : data) {
            add(value);
        }
    }

    /**
     * Copy constructor. Zero-mass entries in {@code other} are not copied, because
     * {@link #add(Object, double)} ignores zero mass.
     *
     * @param other the distribution to copy
     */
    public MapBasedPointMassDistribution(final MapBasedPointMassDistribution<DataType> other) {
        this(other.getDataMap().size());
        for (DataType value : other.getDomain()) {
            add(value, other.getMass(value));
        }
    }

    /**
     * Clones this distribution. The parent clone is relied on for the field copy (presumably
     * shallow; verify in AbstractCloneableSerializable). The entry map is then rebuilt with fresh
     * {@link Entry} objects, so the clone's masses are independent of the original's.
     *
     * @return the clone
     */
    /* renamed from: clone */
    @Override
    @SuppressWarnings("unchecked")
    public MapBasedPointMassDistribution<DataType> mo539clone() {
        final MapBasedPointMassDistribution<DataType> copy =
                (MapBasedPointMassDistribution<DataType>) super.mo539clone();
        copy.dataMap = new LinkedHashMap<>(this.dataMap.size());
        for (Map.Entry<DataType, Entry> mapping : this.dataMap.entrySet()) {
            copy.dataMap.put(mapping.getKey(), new Entry(mapping.getValue()));
        }
        return copy;
    }

    /**
     * Computes the mass-weighted mean. The first value in the domain determines the approach:
     * {@link Number} values use {@code ScalarDataDistribution.getMean(this)} and {@link Ring}
     * values use a mass-weighted ring average.
     *
     * @return the mean, cast to {@code DataType}
     * @throws UnsupportedOperationException if the first value is neither a Number nor a Ring
     *     (this includes an empty domain)
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public DataType getMean() {
        final Object first = CollectionUtil.getFirst(getDomain());
        final Object mean;
        if (first instanceof Number) {
            mean = Double.valueOf(ScalarDataDistribution.getMean(this));
        } else if (first instanceof Ring) {
            final RingAccumulator accumulator = new RingAccumulator();
            for (DataType value : getDomain()) {
                // Accumulate the scaled Ring directly. The decompiled code wrongly cast it to
                // RingAccumulator, which is not a Ring.
                accumulator.accumulate(((Ring) value).scale(getMass(value)));
            }
            mean = accumulator.scaleSum(1.0 / getTotalMass());
        } else {
            throw new UnsupportedOperationException("mean not supported");
        }
        return (DataType) mean;
    }

    /**
     * Draws samples according to the normalized masses.
     *
     * @param random the random number source
     * @param sampleCount the number of samples to draw
     * @return the samples, as produced by {@code ProbabilityMassFunctionUtil.sample}
     */
    @Override
    public ArrayList<DataType> sample(final Random random, final int sampleCount) {
        return ProbabilityMassFunctionUtil.sample(getProbabilityFunction(), random, sampleCount);
    }

    /**
     * Adds a mass of 1 to the given value.
     *
     * @param value the value
     */
    @Override
    public void add(final DataType value) {
        add(value, 1.0);
    }

    /**
     * Adds mass to a value, creating its entry if needed. Adding zero mass is a no-op.
     *
     * @param value the value
     * @param mass the mass to add
     * @throws IllegalArgumentException if {@code mass} is negative or NaN
     */
    @Override
    public void add(final DataType value, final double mass) {
        ensureValidMass(mass);
        if (mass == 0.0) {
            return;
        }
        final Entry entry = this.dataMap.get(value);
        if (entry == null) {
            this.dataMap.put(value, new Entry(mass));
        } else {
            entry.mass += mass;
        }
        this.totalMass += mass;
    }

    /**
     * Sets the mass of a value, replacing any existing mass. Setting a mass of zero removes the
     * value from the domain.
     *
     * @param value the value
     * @param mass the new mass
     * @throws IllegalArgumentException if {@code mass} is negative or NaN
     */
    @Override
    public void setMass(final DataType value, final double mass) {
        ensureValidMass(mass);
        final Entry entry = this.dataMap.get(value);
        if (entry != null) {
            this.totalMass += mass - entry.mass;
            entry.mass = mass;
            if (mass <= 0.0) {
                this.dataMap.remove(value);
            }
        } else if (mass > 0.0) {
            this.dataMap.put(value, new Entry(mass));
            this.totalMass += mass;
        }
    }

    /** Removes every value and resets the total mass to zero. */
    @Override
    public void clear() {
        getDataMap().clear();
        setTotalMass(0.0);
    }

    /**
     * Removes a value and all of its mass. Values that are absent are ignored.
     *
     * @param value the value to remove
     */
    @Override
    public void remove(final DataType value) {
        final Entry entry = this.dataMap.get(value);
        if (entry == null) {
            return;
        }
        if (entry.mass > 0.0) {
            remove(value, entry.mass);
        } else {
            // A zero-mass entry (left behind by scaleEquals(0)) would be ignored by
            // remove(value, 0.0), so drop it directly. It contributes nothing to totalMass.
            this.dataMap.remove(value);
        }
    }

    /**
     * Removes up to {@code mass} from a value. If that leaves no positive mass, the value is
     * removed entirely and only its actual mass is subtracted from the total.
     *
     * @param value the value
     * @param mass the mass to remove
     * @throws IllegalArgumentException if {@code mass} is negative or NaN
     */
    @Override
    public void remove(final DataType value, final double mass) {
        ensureValidMass(mass);
        if (mass == 0.0) {
            return;
        }
        final Entry entry = this.dataMap.get(value);
        if (entry == null) {
            return;
        }
        final double remaining = entry.mass - mass;
        if (remaining <= 0.0) {
            this.totalMass -= entry.mass;
            this.dataMap.remove(value);
        } else {
            entry.mass = remaining;
            this.totalMass -= mass;
        }
    }

    /**
     * Gets the mass of a value.
     *
     * @param value the value
     * @return its mass, or 0 if it is not in the domain
     */
    @Override
    public double getMass(final DataType value) {
        final Entry entry = this.dataMap.get(value);
        return entry == null ? 0.0 : entry.mass;
    }

    /**
     * Adds all mass from another distribution into this one.
     *
     * @param other the distribution whose masses are added
     */
    public void plusEquals(final PointMassDistribution<DataType> other) {
        for (DataType value : other.getDomain()) {
            add(value, other.getMass(value));
        }
    }

    /**
     * Multiplies every mass, and the total, by a non-negative factor. A factor of zero keeps the
     * values in the domain with zero mass.
     *
     * @param scaleFactor the factor
     * @throws IllegalArgumentException if {@code scaleFactor} is negative or NaN
     */
    public void scaleEquals(final double scaleFactor) {
        if (scaleFactor < 0.0) {
            throw new IllegalArgumentException("scaleFactor cannot be negative.");
        }
        if (Double.isNaN(scaleFactor)) {
            throw new IllegalArgumentException("scaleFactor cannot be NaN.");
        }
        for (Entry entry : this.dataMap.values()) {
            entry.mass *= scaleFactor;
        }
        this.totalMass *= scaleFactor;
    }

    /**
     * Gets the fraction of the total mass placed on a value.
     *
     * @param value the value
     * @return mass / totalMass, or 0 if the total mass is exactly zero
     */
    public double getFraction(final DataType value) {
        if (this.totalMass == 0.0) {
            return 0.0;
        }
        return getMass(value) / this.totalMass;
    }

    /**
     * Gets the largest single mass.
     *
     * @return the maximum mass, or 0 if the domain is empty
     */
    public double getMaximumMass() {
        double best = 0.0;
        for (Entry entry : this.dataMap.values()) {
            if (entry.mass > best) {
                best = entry.mass;
            }
        }
        return best;
    }

    /**
     * Gets the value with the largest mass. On ties, the first such value in insertion order
     * wins.
     *
     * @return that value, or {@code null} if no value has positive mass
     */
    public DataType getMaximumValue() {
        DataType bestValue = null;
        double bestMass = 0.0;
        for (Map.Entry<DataType, Entry> mapping : this.dataMap.entrySet()) {
            final double mass = mapping.getValue().mass;
            if (mass > bestMass) {
                bestValue = mapping.getKey();
                bestMass = mass;
            }
        }
        return bestValue;
    }

    /**
     * Gets all values whose mass equals the maximum mass.
     *
     * @return the values, in insertion order
     */
    public LinkedList<DataType> getMaximumValues() {
        return getMaximumValues(0.0);
    }

    /**
     * Gets all values whose mass is at least {@code getMaximumMass() - tolerance}.
     *
     * @param tolerance how far below the maximum mass a value may be and still be included
     * @return the values, in insertion order
     */
    public LinkedList<DataType> getMaximumValues(final double tolerance) {
        final double threshold = getMaximumMass() - tolerance;
        final LinkedList<DataType> result = new LinkedList<>();
        for (Map.Entry<DataType, Entry> mapping : this.dataMap.entrySet()) {
            if (mapping.getValue().mass >= threshold) {
                result.add(mapping.getKey());
            }
        }
        return result;
    }

    /**
     * Gets the values in the domain.
     *
     * <p>NOTE(review): this is the live key set of the backing map. Removing through it would
     * desynchronize the cached total mass.
     *
     * @return the domain
     */
    @Override
    public Collection<? extends DataType> getDomain() {
        return this.dataMap.keySet();
    }

    /**
     * Gets the number of values in the domain.
     *
     * @return the domain size
     */
    @Override
    public int getDomainSize() {
        return getDomain().size();
    }

    /**
     * Creates a PMF snapshot of this distribution.
     *
     * @return a new PMF copied from this distribution
     */
    @Override
    public PMF<DataType> getProbabilityFunction() {
        return new PMF<>(this);
    }

    /**
     * Gets the backing map.
     *
     * @return the value-to-entry map
     */
    protected Map<DataType, Entry> getDataMap() {
        return this.dataMap;
    }

    /**
     * Sets the backing map. The cached total mass is not recomputed.
     *
     * @param dataMap the new value-to-entry map
     */
    protected void setDataMap(final Map<DataType, Entry> dataMap) {
        this.dataMap = dataMap;
    }

    /**
     * Sets the cached total mass.
     *
     * @param totalMass the new total mass
     */
    protected void setTotalMass(final double totalMass) {
        this.totalMass = totalMass;
    }

    /**
     * Gets the cached total mass.
     *
     * @return the sum of all masses
     */
    @Override
    public double getTotalMass() {
        return this.totalMass;
    }

    /**
     * Renders a header line, then one line per value of the form {@code value: mass (fraction)}.
     *
     * @return the string form
     */
    @Override
    public String toString() {
        final int size = getDomain().size();
        final StringBuilder sb = new StringBuilder(size * 100);
        sb.append("Point mass distribution has " + size + " particles and " + getTotalMass()
                + " total mass:\n");
        for (DataType value : getDomain()) {
            // append(Object) prints "null" instead of throwing for a null key.
            sb.append(value);
            sb.append(": ");
            sb.append(getMass(value));
            sb.append(" (");
            sb.append(getFraction(value));
            sb.append(")");
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Validates a mass argument. NaN is rejected explicitly because it would pass a plain
     * negativity check and then permanently poison the cached total mass.
     *
     * @param mass the mass to check
     * @throws IllegalArgumentException if {@code mass} is negative or NaN
     */
    private static void ensureValidMass(final double mass) {
        if (mass < 0.0) {
            throw new IllegalArgumentException("mass cannot be negative");
        }
        if (Double.isNaN(mass)) {
            throw new IllegalArgumentException("mass cannot be NaN");
        }
    }
}
