/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.SequentialDataMultiPartitioner;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.function.cost.AbstractSupervisedCostFunction;
import gov.sandia.cognition.learning.function.cost.DifferentiableCostFunction;
import gov.sandia.cognition.learning.function.cost.ParallelizableCostFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A differentiable cost function that parallelizes a wrapped
 * {@link ParallelizableCostFunction}.
 * <p>
 * The cost parameters are split into partitions (one per thread, as reported by
 * {@link #getNumThreads()}). A clone of the wrapped cost function is evaluated on
 * each partition in the thread pool. The partial results are then combined with
 * the wrapped cost function's {@code evaluateAmalgamate} /
 * {@code computeParameterGradientAmalgamate} methods.
 * <p>
 * The per-partition tasks are built lazily on first use. They are discarded
 * whenever the cost function, the cost parameters, or the thread pool changes.
 * NOTE(review): the cached tasks are mutated on every call (their evaluator field
 * is reassigned), so a single instance should not be evaluated from several
 * threads concurrently.
 */
public class ParallelizedCostFunctionContainer
extends AbstractSupervisedCostFunction<Vector, Vector>
implements DifferentiableCostFunction,
ParallelAlgorithm {
    /** The cost function that is cloned once per data partition. */
    private ParallelizableCostFunction costFunction;
    /** Cached per-partition evaluation tasks; {@code null} until built. */
    private transient ArrayList<Callable<Object>> evaluationComponents;
    /** Cached per-partition gradient tasks; {@code null} until built. */
    private transient ArrayList<Callable<Object>> gradientComponents;
    /** Pool the tasks run in; created lazily by {@link #getThreadPool()} when {@code null}. */
    private transient ThreadPoolExecutor threadPool;

    /**
     * Creates a container with no cost function and a default thread pool.
     */
    public ParallelizedCostFunctionContainer() {
        this((ParallelizableCostFunction)null);
    }

    /**
     * Creates a container around the given cost function, using a default
     * thread pool.
     *
     * @param costFunction the cost function to parallelize (may be {@code null}
     *        and set later)
     */
    public ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction) {
        this(costFunction, ParallelUtil.createThreadPool());
    }

    /**
     * Creates a container around the given cost function and thread pool.
     *
     * @param costFunction the cost function to parallelize
     * @param threadPool the pool used to run the per-partition tasks; if
     *        {@code null}, a default pool is created on first use
     */
    public ParallelizedCostFunctionContainer(ParallelizableCostFunction costFunction, ThreadPoolExecutor threadPool) {
        this.setCostFunction(costFunction);
        this.setThreadPool(threadPool);
    }

    /**
     * Clones this container.
     * <p>
     * The wrapped cost function is deep-copied, and the clone gets its own thread
     * pool with the same number of threads.
     *
     * @return a clone of this container
     */
    @Override
    public ParallelizedCostFunctionContainer clone() {
        ParallelizedCostFunctionContainer clone = (ParallelizedCostFunctionContainer)super.clone();
        // setCostFunction also drops the task lists shallow-copied by super.clone(),
        // which would otherwise still point at this instance's sub-cost clones.
        clone.setCostFunction(ObjectUtil.cloneSafe(this.getCostFunction()));
        clone.setThreadPool(ParallelUtil.createThreadPool(this.getNumThreads()));
        return clone;
    }

    /**
     * Gets the wrapped cost function.
     *
     * @return the cost function being parallelized
     */
    public ParallelizableCostFunction getCostFunction() {
        return this.costFunction;
    }

    /**
     * Sets the wrapped cost function and invalidates the cached partitions.
     *
     * @param costFunction the cost function to parallelize
     */
    public void setCostFunction(ParallelizableCostFunction costFunction) {
        this.costFunction = costFunction;
        this.clearPartitions();
    }

    /**
     * Splits the cost parameters into partitions, one per thread.
     * <p>
     * For each partition, clones the wrapped cost function, restricts the clone
     * to that partition, and builds one evaluation task and one gradient task
     * that share the clone.
     */
    protected void createPartitions() {
        int numThreads = this.getNumThreads();
        ArrayList<? extends Collection<? extends InputOutputPair<? extends Vector, Vector>>> partitions =
            SequentialDataMultiPartitioner.create(this.getCostParameters(), numThreads);
        int numPartitions = partitions.size();
        ArrayList<Callable<Object>> evaluations = new ArrayList<Callable<Object>>(numPartitions);
        ArrayList<Callable<Object>> gradients = new ArrayList<Callable<Object>>(numPartitions);
        // Iterate over the partitions actually produced rather than assuming
        // there are exactly numThreads of them.
        for (Collection<? extends InputOutputPair<? extends Vector, Vector>> partition : partitions) {
            ParallelizableCostFunction subcost = (ParallelizableCostFunction)this.getCostFunction().clone();
            subcost.setCostParameters(partition);
            evaluations.add(new SubCostEvaluate(subcost, null));
            gradients.add(new SubCostGradient(subcost, null));
        }
        // Publish only fully-built lists so a failure above leaves no partial state.
        this.evaluationComponents = evaluations;
        this.gradientComponents = gradients;
    }

    /**
     * Sets the cost parameters (the dataset) and invalidates the cached
     * partitions.
     *
     * @param costParameters the labeled data to evaluate the cost over
     */
    @Override
    public void setCostParameters(Collection<? extends InputOutputPair<? extends Vector, Vector>> costParameters) {
        super.setCostParameters(costParameters);
        this.clearPartitions();
    }

    /**
     * Evaluates the cost of the given evaluator over all partitions in parallel.
     *
     * @param evaluator the evaluator whose cost is computed
     * @return the amalgamated cost from the wrapped cost function
     * @throws IllegalStateException if the parallel execution fails with a
     *         checked exception (the cause is attached)
     */
    @Override
    public Double evaluate(Evaluator<? super Vector, ? extends Vector> evaluator) {
        // Resolve the pool first: lazily creating it invalidates the cached tasks.
        ThreadPoolExecutor pool = this.getThreadPool();
        if (this.evaluationComponents == null) {
            this.createPartitions();
        }
        for (Callable<Object> component : this.evaluationComponents) {
            ((SubCostEvaluate)component).evaluator = evaluator;
        }
        ArrayList<Object> partialResults = this.executeAll(this.evaluationComponents, pool);
        return this.getCostFunction().evaluateAmalgamate(partialResults);
    }

    /**
     * Delegates performance evaluation directly to the wrapped cost function
     * (not parallelized).
     *
     * @param data target/estimate pairs to score
     * @return the wrapped cost function's performance value
     */
    @Override
    public Double evaluatePerformance(Collection<? extends TargetEstimatePair<Vector, Vector>> data) {
        return (Double)this.getCostFunction().evaluatePerformance(data);
    }

    /**
     * Computes the parameter gradient of the cost over all partitions in
     * parallel.
     *
     * @param function the function whose parameter gradient is computed
     * @return the amalgamated gradient from the wrapped cost function
     * @throws IllegalStateException if the parallel execution fails with a
     *         checked exception (the cause is attached)
     */
    @Override
    public Vector computeParameterGradient(GradientDescendable function) {
        // Resolve the pool first: lazily creating it invalidates the cached tasks.
        ThreadPoolExecutor pool = this.getThreadPool();
        if (this.gradientComponents == null) {
            this.createPartitions();
        }
        for (Callable<Object> component : this.gradientComponents) {
            ((SubCostGradient)component).evaluator = function;
        }
        ArrayList<Object> partialResults = this.executeAll(this.gradientComponents, pool);
        return this.getCostFunction().computeParameterGradientAmalgamate(partialResults);
    }

    /**
     * Gets the thread pool, creating a default one if none is set.
     *
     * @return the thread pool used for the per-partition tasks
     */
    @Override
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool and invalidates the cached partitions.
     * <p>
     * The partitions are sized by {@link #getNumThreads()}, which is presumably
     * derived from the pool.
     *
     * @param threadPool the pool to run tasks in; {@code null} means a default
     *        pool is created on demand
     */
    @Override
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
        this.clearPartitions();
    }

    /**
     * Gets the number of threads, as reported by
     * {@code ParallelUtil.getNumThreads(this)}.
     *
     * @return the number of threads
     */
    @Override
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Replaces the current thread pool with a newly created default pool.
     */
    protected void createThreadPool() {
        this.setThreadPool(ParallelUtil.createThreadPool());
    }

    /**
     * Drops the cached per-partition tasks so that they are rebuilt on next use.
     */
    private void clearPartitions() {
        this.evaluationComponents = null;
        this.gradientComponents = null;
    }

    /**
     * Runs the tasks in the given pool and returns their results.
     * <p>
     * Checked failures are not swallowed: they are wrapped in an
     * {@link IllegalStateException} with the original cause attached. If the
     * failure was an interruption, the interrupt flag is restored.
     */
    private ArrayList<Object> executeAll(ArrayList<Callable<Object>> components, ThreadPoolExecutor pool) {
        try {
            return ParallelUtil.executeInParallel(components, pool);
        }
        catch (RuntimeException ex) {
            throw ex;
        }
        catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Parallel cost function computation failed", ex);
        }
    }

    /**
     * Task that computes the partial parameter gradient of one partition's cost
     * function.
     */
    protected static class SubCostGradient
    implements Callable<Object> {
        /** Cost function restricted to a single data partition. */
        private ParallelizableCostFunction costFunction;
        /** Function to differentiate; reassigned by the container before each run. */
        private GradientDescendable evaluator;

        /**
         * Creates a gradient task.
         *
         * @param costFunction cost function restricted to one partition
         * @param evaluator function to differentiate (may be set later)
         */
        public SubCostGradient(ParallelizableCostFunction costFunction, GradientDescendable evaluator) {
            this.costFunction = costFunction;
            this.evaluator = evaluator;
        }

        /**
         * Computes this partition's partial gradient.
         *
         * @return the result of {@code computeParameterGradientPartial}
         */
        @Override
        public Object call() {
            return this.costFunction.computeParameterGradientPartial(this.evaluator);
        }
    }

    /**
     * Task that computes the partial cost of one partition's cost function.
     */
    protected static class SubCostEvaluate
    implements Callable<Object> {
        /** Cost function restricted to a single data partition. */
        private ParallelizableCostFunction costFunction;
        /** Evaluator to score; reassigned by the container before each run. */
        private Evaluator<? super Vector, ? extends Vector> evaluator;

        /**
         * Creates an evaluation task.
         *
         * @param costFunction cost function restricted to one partition
         * @param evaluator evaluator to score (may be set later)
         */
        public SubCostEvaluate(ParallelizableCostFunction costFunction, Evaluator<? super Vector, ? extends Vector> evaluator) {
            this.costFunction = costFunction;
            this.evaluator = evaluator;
        }

        /**
         * Computes this partition's partial cost.
         *
         * @return the result of {@code evaluatePartial}
         */
        @Override
        public Object call() {
            return this.costFunction.evaluatePartial(this.evaluator);
        }
    }
}

