/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.method;

import gov.sandia.cognition.algorithm.AbstractParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerDirectionSetPowell;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerNelderMead;
import gov.sandia.cognition.learning.function.cost.ParallelNegativeLogLikelihood;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.ClosedFormDiscreteUnivariateDistribution;
import gov.sandia.cognition.statistics.ClosedFormDistribution;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.EstimableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.SmoothUnivariateDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.statistics.method.DistributionParameterEstimator;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;

/**
 * Chooses, among a set of candidate distributions, the one that best explains
 * a data set in the maximum-likelihood sense.
 * <p>
 * Each candidate is fit independently by a {@link DistributionEstimationTask}.
 * The tasks run through the thread pool inherited from
 * {@link AbstractParallelAlgorithm}. Every fitted candidate is scored with a
 * {@link ParallelNegativeLogLikelihood} cost, and the one with the lowest cost
 * is returned.
 *
 * @param <DataType> type of the data the candidate distributions are defined over
 */
public class MaximumLikelihoodDistributionEstimator<DataType>
extends AbstractParallelAlgorithm
implements BatchLearner<Collection<? extends DataType>, ClosedFormComputableDistribution<DataType>> {
    /**
     * Candidate distributions to fit. {@link #learn} fits clones of these,
     * so the candidates themselves are not modified by learning.
     */
    private Collection<? extends ClosedFormComputableDistribution<DataType>> distributions;

    /**
     * Creates an estimator with no candidate distributions. Call
     * {@link #setDistributions} before {@link #learn}.
     */
    public MaximumLikelihoodDistributionEstimator() {
        this((Collection<ClosedFormComputableDistribution<DataType>>) null);
    }

    /**
     * Creates an estimator over the given candidate distributions.
     *
     * @param distributions candidate distributions to choose among
     */
    public MaximumLikelihoodDistributionEstimator(Collection<? extends ClosedFormComputableDistribution<DataType>> distributions) {
        this.setDistributions(distributions);
    }

    /**
     * Clones this estimator. The candidate distributions are cloned element by
     * element into a new list, so the clone does not share them with this instance.
     */
    @Override
    @SuppressWarnings("unchecked") // super.clone() returns an instance of this exact class
    public MaximumLikelihoodDistributionEstimator<DataType> clone() {
        final MaximumLikelihoodDistributionEstimator<DataType> clone =
            (MaximumLikelihoodDistributionEstimator<DataType>) super.clone();
        clone.setDistributions(ObjectUtil.cloneSmartElementsAsArrayList(this.getDistributions()));
        return clone;
    }

    /**
     * Gets the candidate distributions.
     *
     * @return the candidate distributions (the stored collection, not a copy); may be null
     */
    public Collection<? extends ClosedFormComputableDistribution<DataType>> getDistributions() {
        return this.distributions;
    }

    /**
     * Sets the candidate distributions. The collection is stored as-is.
     *
     * @param distributions candidate distributions to choose among
     */
    public void setDistributions(Collection<? extends ClosedFormComputableDistribution<DataType>> distributions) {
        this.distributions = distributions;
    }

    /**
     * Fits every candidate distribution to the data in parallel and returns the
     * fitted candidate with the lowest negative log-likelihood.
     *
     * @param data the data to fit
     * @return the best-fitting distribution. Returns null when there are no
     *         candidates or when no candidate achieved a cost below +infinity
     *         (NaN costs are never selected).
     * @throws IllegalStateException if no candidate distributions have been set,
     *         or if the parallel execution itself fails
     */
    @Override
    public ClosedFormComputableDistribution<DataType> learn(Collection<? extends DataType> data) {
        final Collection<? extends ClosedFormComputableDistribution<DataType>> candidates = this.getDistributions();
        if (candidates == null) {
            throw new IllegalStateException("No candidate distributions have been set.");
        }

        final ArrayList<DistributionEstimationTask<DataType>> tasks =
            new ArrayList<DistributionEstimationTask<DataType>>(candidates.size());
        for (ClosedFormComputableDistribution<DataType> distribution : candidates) {
            // Each task works on its own clone so the configured candidates stay untouched.
            tasks.add(new DistributionEstimationTask<DataType>(ObjectUtil.cloneSafe(distribution), data));
        }

        final ArrayList<Pair<Double, ClosedFormComputableDistribution<DataType>>> results;
        try {
            results = ParallelUtil.executeInParallel(tasks, this.getThreadPool());
        }
        catch (Exception e) {
            // Previously the failure was printed and results set to null, which
            // then crashed with an uninformative NullPointerException below.
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Parallel distribution estimation failed", e);
        }

        double minCost = Double.POSITIVE_INFINITY;
        ClosedFormComputableDistribution<DataType> minDistribution = null;
        for (Pair<Double, ClosedFormComputableDistribution<DataType>> result : results) {
            final double cost = result.getFirst();
            // Strict comparison: NaN and +infinity costs never win.
            if (cost < minCost) {
                minCost = cost;
                minDistribution = result.getSecond();
            }
        }
        return minDistribution;
    }

    /**
     * Fits every {@link SmoothUnivariateDistribution} found in the distribution
     * package to the data and returns the best one.
     *
     * @param data the continuous data to fit
     * @return the best-fitting smooth univariate distribution, or null if none could be fit
     * @throws Exception if the candidate classes cannot be discovered
     */
    public static SmoothUnivariateDistribution estimateContinuousDistribution(Collection<Double> data) throws Exception {
        LinkedList<SmoothUnivariateDistribution> distributions = MaximumLikelihoodDistributionEstimator.getDistributionClasses(SmoothUnivariateDistribution.class);
        MaximumLikelihoodDistributionEstimator<Double> estimator = new MaximumLikelihoodDistributionEstimator<Double>(distributions);
        return (SmoothUnivariateDistribution) estimator.learn(data);
    }

    /**
     * Fits every {@link ClosedFormDiscreteUnivariateDistribution} found in the
     * distribution package to the data and returns the best one.
     *
     * @param data the discrete data to fit
     * @return the best-fitting discrete univariate distribution, or null if none could be fit
     * @throws Exception if the candidate classes cannot be discovered
     */
    public static ClosedFormDiscreteUnivariateDistribution estimateDiscreteDistribution(Collection<? extends Number> data) throws Exception {
        LinkedList<ClosedFormDiscreteUnivariateDistribution> distributions = MaximumLikelihoodDistributionEstimator.getDistributionClasses(ClosedFormDiscreteUnivariateDistribution.class);
        // A wildcard-typed estimator ("? extends Number") could not accept the data in learn().
        MaximumLikelihoodDistributionEstimator<Number> estimator = new MaximumLikelihoodDistributionEstimator<Number>(distributions);
        return (ClosedFormDiscreteUnivariateDistribution) estimator.learn(data);
    }

    /**
     * Instantiates every class in the package of {@link UnivariateGaussian} that
     * is both a subtype of {@code baseDistribution} and a {@link ProbabilityFunction}.
     * Classes that cannot be instantiated with a no-argument constructor are
     * reported on standard output and skipped.
     *
     * @param <DistributionType> the requested distribution type
     * @param baseDistribution class object of the requested distribution type
     * @return default-constructed instances of every matching class
     * @throws ClassNotFoundException if a discovered class file cannot be loaded
     * @throws IOException if the package resources cannot be enumerated
     * @throws InstantiationException declared for compatibility; instantiation failures are caught
     * @throws IllegalAccessException declared for compatibility; instantiation failures are caught
     */
    protected static <DistributionType extends ClosedFormComputableDistribution<?>> LinkedList<DistributionType> getDistributionClasses(Class<? extends DistributionType> baseDistribution) throws ClassNotFoundException, IOException, InstantiationException, IllegalAccessException {
        final Package distributionPackage = UnivariateGaussian.class.getPackage();
        final LinkedList<Class<?>> candidates = MaximumLikelihoodDistributionEstimator.getClasses(distributionPackage.getName());
        final LinkedList<DistributionType> instances = new LinkedList<DistributionType>();
        for (Class<?> clazz : candidates) {
            if (!baseDistribution.isAssignableFrom(clazz) || !ProbabilityFunction.class.isAssignableFrom(clazz)) {
                continue;
            }
            try {
                instances.add(baseDistribution.cast(clazz.newInstance()));
            }
            catch (Exception e) {
                // Interfaces, abstract classes and classes without a public
                // no-argument constructor end up here; they are simply skipped.
                System.out.println("Couldn't instantiate: " + clazz.getCanonicalName());
            }
        }
        return instances;
    }

    /**
     * Loads every top-level class file (and nested-class file) found directly in
     * the directories that the context class loader maps to {@code packageName}.
     * Only {@code file:} resources are scanned. Package entries inside JARs or
     * other non-directory resources are skipped.
     *
     * @param packageName fully qualified package name
     * @return the loaded classes
     * @throws ClassNotFoundException if a class file cannot be loaded
     * @throws IOException if the package resources cannot be enumerated
     */
    private static LinkedList<Class<?>> getClasses(String packageName) throws ClassNotFoundException, IOException {
        final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        assert (classLoader != null);
        final String path = packageName.replace('.', '/');
        final Enumeration<URL> resources = classLoader.getResources(path);
        final LinkedList<Class<?>> classes = new LinkedList<Class<?>>();
        while (resources.hasMoreElements()) {
            final URL resource = resources.nextElement();
            if (!"file".equals(resource.getProtocol())) {
                // e.g. "jar:" URLs: these cannot be listed as a directory.
                continue;
            }
            classes.addAll(MaximumLikelihoodDistributionEstimator.findClasses(toFile(resource), packageName));
        }
        return classes;
    }

    /**
     * Converts a {@code file:} URL to a {@link File}, decoding percent-escapes
     * such as {@code %20}. Falls back to the raw URL path when the URL is not a
     * valid URI or is not convertible to a {@link File}.
     *
     * @param resource a URL with protocol {@code file}
     * @return the corresponding file
     */
    private static File toFile(URL resource) {
        try {
            return new File(resource.toURI());
        }
        catch (URISyntaxException e) {
            return new File(resource.getPath());
        }
        catch (IllegalArgumentException e) {
            return new File(resource.getPath());
        }
    }

    /**
     * Loads each {@code .class} file directly inside {@code directory} as a
     * member of {@code packageName}. Subdirectories are not scanned.
     *
     * @param directory directory holding the package's class files
     * @param packageName fully qualified package name of those classes
     * @return the loaded classes; empty if {@code directory} cannot be listed
     * @throws ClassNotFoundException if a class file cannot be loaded
     */
    private static LinkedList<Class<?>> findClasses(File directory, String packageName) throws ClassNotFoundException {
        final LinkedList<Class<?>> classes = new LinkedList<Class<?>>();
        final File[] files = directory.listFiles();
        if (files == null) {
            // Not a directory, or it could not be read.
            return classes;
        }
        final String suffix = ".class";
        for (File file : files) {
            final String fileName = file.getName();
            if (!fileName.endsWith(suffix)) {
                continue;
            }
            final String simpleName = fileName.substring(0, fileName.length() - suffix.length());
            classes.add(Class.forName(packageName + '.' + simpleName));
        }
        return classes;
    }

    /**
     * Fits one candidate distribution to the data and reports its final
     * negative log-likelihood cost.
     *
     * @param <DataType> type of the data the distribution is defined over
     */
    public static class DistributionEstimationTask<DataType>
    extends AbstractCloneableSerializable
    implements Callable<Pair<Double, ClosedFormComputableDistribution<DataType>>> {
        /** Candidate distribution whose parameters are the starting point of the fit. */
        ClosedFormComputableDistribution<DataType> distribution;
        /** Data to fit. */
        Collection<? extends DataType> data;

        /**
         * Creates a task.
         *
         * @param distribution candidate distribution (the fit works on clones of it)
         * @param data the data to fit
         */
        public DistributionEstimationTask(ClosedFormComputableDistribution<DataType> distribution, Collection<? extends DataType> data) {
            this.distribution = distribution;
            this.data = data;
        }

        /**
         * Fits the candidate and returns its cost.
         * <p>
         * If the candidate's initial parameters give a non-finite cost, a
         * feasible starting point is searched for first, in this order:
         * <ol>
         * <li>the family's own estimator (when it is an {@link EstimableDistribution}),
         *     on all data and then on a sub-sample;</li>
         * <li>a short Nelder-Mead search;</li>
         * <li>a longer Nelder-Mead search on the sub-sample.</li>
         * </ol>
         * The result is then refined with Powell's direction-set method.
         *
         * @return (final cost, fitted distribution). If any exception escapes the
         *         fit, the result is (+infinity, clone of the unfitted candidate),
         *         so this candidate cannot beat any finite-cost one.
         */
        @Override
        public Pair<Double, ClosedFormComputableDistribution<DataType>> call() throws Exception {
            try {
                ParallelNegativeLogLikelihood<DataType> costFunction = new ParallelNegativeLogLikelihood<DataType>(this.data);
                ClosedFormComputableDistribution<DataType> result1 = ObjectUtil.cloneSafe(this.distribution);
                double cost1 = costFunction.evaluate(result1);
                System.out.println("Initial Cost: " + cost1 + ", Class: " + result1.getClass().getCanonicalName() + ", Parameters: " + result1.convertToVector());
                if (Double.isInfinite(cost1) || Double.isNaN(cost1)) {
                    // The starting parameters are infeasible for this data; look for a feasible start.
                    ClosedFormComputableDistribution result2 = ObjectUtil.cloneSafe(this.distribution);
                    boolean bruteForce = true;
                    // Sub-sample of at most 1000 points. The former "size() / 1000"
                    // produced an empty sub-sample for any data set under 1000 points.
                    int Nsub = Math.min(1000, this.data.size());
                    List<? extends DataType> subList = CollectionUtil.asArrayList(this.data).subList(0, Nsub);
                    if (this.distribution instanceof EstimableDistribution) {
                        double cost2;
                        DistributionEstimator solver = ((EstimableDistribution) ((Object) this.distribution)).getEstimator();
                        try {
                            result2 = (ClosedFormComputableDistribution) solver.learn(this.data);
                            cost2 = costFunction.evaluate(result2);
                            System.out.println("Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector());
                            bruteForce = Double.isInfinite(cost2) || Double.isNaN(cost2);
                        }
                        catch (Exception e) {
                            System.out.println("Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e);
                            bruteForce = true;
                            result2 = ObjectUtil.cloneSafe(this.distribution);
                        }
                        if (bruteForce) {
                            // Retry the family's estimator on the sub-sample only.
                            try {
                                result2 = (ClosedFormComputableDistribution) solver.learn(subList);
                                cost2 = costFunction.evaluate(result2);
                                System.out.println("Sub-Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector());
                                bruteForce = Double.isInfinite(cost2) || Double.isNaN(cost2);
                            }
                            catch (Exception e) {
                                System.out.println("Sub-Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e);
                                result2 = ObjectUtil.cloneSafe(this.distribution);
                            }
                        }
                    }
                    if (bruteForce) {
                        // A short, coarse-tolerance Nelder-Mead search over all data.
                        FunctionMinimizerNelderMead minimizer1 = new FunctionMinimizerNelderMead();
                        minimizer1.setMaxIterations(10);
                        minimizer1.setTolerance(1.0);
                        DistributionParameterEstimator<Object, ClosedFormDistribution> estimator2 = new DistributionParameterEstimator<Object, ClosedFormDistribution>(ObjectUtil.cloneSafe(result2), costFunction, minimizer1);
                        result2 = (ClosedFormComputableDistribution) estimator2.learn(this.data);
                        double cost2 = costFunction.evaluate(result2);
                        System.out.println("Brute Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector());
                        if (Double.isInfinite(cost2) || Double.isNaN(cost2)) {
                            // Longer search, scored on the sub-sample only. The cost
                            // function is pointed back at the full data afterwards.
                            minimizer1.setMaxIterations(1000);
                            costFunction.setCostParameters(subList);
                            estimator2 = new DistributionParameterEstimator(ObjectUtil.cloneSafe(result2), costFunction, minimizer1);
                            result2 = (ClosedFormComputableDistribution) estimator2.learn(subList);
                            costFunction.setCostParameters(this.data);
                            double cost3 = costFunction.evaluate(result2);
                            System.out.println("Subsample Cost: " + cost3 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector());
                        }
                    }
                    result1 = result2;
                }
                // Final refinement from the (hopefully feasible) starting point.
                FunctionMinimizerDirectionSetPowell minimizer3 = new FunctionMinimizerDirectionSetPowell();
                DistributionParameterEstimator<? extends DataType, ClosedFormDistribution> estimator3 = new DistributionParameterEstimator<DataType, ClosedFormDistribution>(ObjectUtil.cloneSafe(result1), costFunction, minimizer3);
                ClosedFormComputableDistribution result3 = (ClosedFormComputableDistribution) estimator3.learn(this.data);
                double cost3 = costFunction.evaluate(result3);
                System.out.println("Final Cost: " + cost3 + ", Class: " + result3.getClass().getCanonicalName() + ", Parameters: " + result3.convertToVector());
                return DefaultPair.create(cost3, result3);
            }
            catch (Exception e) {
                System.out.println(this.distribution.getClass().getCanonicalName() + " barfed: " + e);
                return DefaultPair.create(Double.POSITIVE_INFINITY, (ClosedFormComputableDistribution) this.distribution.clone());
            }
        }
    }
}

