/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.experiment;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.learning.experiment.LearnerValidationExperiment;
import gov.sandia.cognition.learning.experiment.ValidationFoldCreator;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Summarizer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A learner validation experiment that runs the trial for each fold as a
 * separate task on a thread pool instead of running the folds one after
 * another.
 *
 * @param <InputDataType> type of data handed to the fold creator
 * @param <FoldDataType> element type of each fold's training and testing sets
 * @param <LearnedType> type of object the learner produces for a fold
 * @param <StatisticType> type of the per-trial result of the performance evaluator
 * @param <SummaryType> type produced by the summarizer over all statistics
 */
public class ParallelLearnerValidationExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType>
extends LearnerValidationExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType>
implements ParallelAlgorithm {
    /** Logger shared by the experiment and its trial tasks. */
    private static final Logger LOGGER = Logger.getLogger(ParallelLearnerValidationExperiment.class.getName());

    /** Pool the trials run on; transient, and created lazily by {@link #getThreadPool()}. */
    private transient ThreadPoolExecutor threadPool;

    /**
     * Creates an experiment with a null fold creator, performance evaluator
     * and summarizer.
     */
    public ParallelLearnerValidationExperiment() {
        this(null, null, null);
    }

    /**
     * Creates an experiment from the given components, which are passed
     * straight to the superclass constructor.
     *
     * @param foldCreator creates the folds from the input data
     * @param performanceEvaluator computes the statistic for one trial
     * @param summarizer summarizes the per-trial statistics
     */
    public ParallelLearnerValidationExperiment(ValidationFoldCreator<InputDataType, FoldDataType> foldCreator, PerformanceEvaluator<? super LearnedType, Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator, Summarizer<? super StatisticType, ? extends SummaryType> summarizer) {
        super(foldCreator, performanceEvaluator, summarizer);
    }

    /**
     * Runs one trial per fold on the thread pool and appends the resulting
     * statistics to this experiment's statistics.
     *
     * @param folds the folds to run; the number of trials is set to their count
     * @throws RuntimeException if the parallel execution itself fails; the
     *         original exception is attached as the cause
     */
    @Override
    protected void runExperiment(Collection<PartitionedDataset<FoldDataType>> folds) {
        this.setNumTrials(folds.size());
        this.fireExperimentStarted();

        // One task per fold, in iteration order of the folds.
        ArrayList<TrialTask> trials = new ArrayList<TrialTask>(folds.size());
        for (PartitionedDataset<FoldDataType> fold : folds) {
            trials.add(new TrialTask(fold));
        }

        Collection<StatisticType> results;
        try {
            results = ParallelUtil.executeInParallel(trials, this.getThreadPool());
        }
        catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                // Keep the interrupt visible to whoever owns this thread.
                Thread.currentThread().interrupt();
            }
            LOGGER.log(Level.SEVERE, "Parallel execution of validation trials failed", ex);
            // Previously execution fell through and addAll(null) raised an
            // unrelated NullPointerException; fail with the real cause instead.
            throw new RuntimeException("Parallel execution of validation trials failed", ex);
        }

        this.getStatistics().addAll(results);
        this.fireExperimentEnded();
    }

    /**
     * Gets the thread pool used to run the trials, creating one through
     * {@code ParallelUtil.createThreadPool()} on first use if none is set.
     *
     * @return the thread pool
     */
    @Override
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.threadPool = ParallelUtil.createThreadPool();
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used to run the trials.
     *
     * @param threadPool the pool to use; if null, a pool is created on the
     *        next call to {@link #getThreadPool()}
     */
    @Override
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Gets the number of threads, as reported by
     * {@code ParallelUtil.getNumThreads(this)}.
     *
     * @return the number of threads
     */
    @Override
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Task that runs the trial for a single fold: learns on a clone of the
     * experiment's learner using the fold's training set, then evaluates the
     * learned object on the fold's testing set.
     */
    private class TrialTask
    implements Callable<StatisticType> {
        /** The fold this task trains and tests on. */
        private final PartitionedDataset<FoldDataType> fold;

        /**
         * Creates a task for the given fold.
         *
         * @param fold the fold to run the trial on
         */
        public TrialTask(PartitionedDataset<FoldDataType> fold) {
            this.fold = fold;
        }

        /**
         * Runs the trial for this task's fold.
         *
         * @return the statistic computed by the performance evaluator, or
         *         null if the trial threw an exception (the failure is logged
         *         and the trial-ended event is not fired in that case)
         */
        @Override
        public StatisticType call() {
            try {
                ParallelLearnerValidationExperiment.this.fireTrialStarted();

                // Each task learns on its own clone of the experiment's learner.
                BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType> learnerClone = ObjectUtil.cloneSmart(ParallelLearnerValidationExperiment.this.getLearner());
                LearnedType learned = learnerClone.learn(this.fold.getTrainingSet());

                Collection<? extends FoldDataType> testingSet = this.fold.getTestingSet();
                StatisticType statistic = ParallelLearnerValidationExperiment.this.getPerformanceEvaluator().evaluatePerformance(learned, testingSet);

                ParallelLearnerValidationExperiment.this.fireTrialEnded();
                return statistic;
            }
            catch (Exception e) {
                // Best effort: a failed trial yields a null statistic rather
                // than aborting the other trials.
                LOGGER.log(Level.SEVERE, "Validation trial failed; returning a null statistic", e);
                return null;
            }
        }
    }
}

