/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm;

import gov.sandia.cognition.collection.DefaultMultiCollection;
import gov.sandia.cognition.collection.FiniteCapacityBuffer;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Adapts a supervised batch learner to sequence data by converting each
 * sequence into input-output pairs, where the output is the element that
 * occurs {@code predictionHorizon} positions after the input, and then
 * handing those pairs to the wrapped learner.
 *
 * @param <DataType> the type of the elements in the sequences; used as both
 *        the input and the output type of the generated pairs
 * @param <LearnedType> the type of object produced by the wrapped learner
 */
public class SequencePredictionLearner<DataType, LearnedType>
extends AbstractBatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends DataType, DataType>>, ? extends LearnedType>>
implements BatchLearner<Collection<? extends DataType>, LearnedType> {
    /**
     * The default prediction horizon, {@value}. The misspelled name is part
     * of the public API and is kept for backward compatibility.
     */
    public static final int DEFAULT_PREDICTION_HORIZION = 1;

    /** Number of steps ahead to predict; always positive once set through the setter. */
    protected int predictionHorizon;

    /**
     * Creates a learner with a {@code null} wrapped learner and the default
     * prediction horizon.
     */
    public SequencePredictionLearner() {
        this(null, DEFAULT_PREDICTION_HORIZION);
    }

    /**
     * Creates a learner that wraps the given supervised learner.
     *
     * @param learner the supervised learner that is trained on the generated
     *        input-output pairs
     * @param predictionHorizon the number of steps ahead to predict; must be
     *        positive
     * @throws IllegalArgumentException if {@code predictionHorizon} is not
     *         positive
     */
    public SequencePredictionLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends DataType, DataType>>, ? extends LearnedType> learner, int predictionHorizon) {
        super(learner);
        this.setPredictionHorizon(predictionHorizon);
    }

    /**
     * Learns from the given data by first viewing it as a multi-collection
     * and then delegating to {@link #learn(MultiCollection)}.
     *
     * @param data the sequence data to learn from
     * @return the object produced by the wrapped learner
     */
    @Override
    public LearnedType learn(Collection<? extends DataType> data) {
        return this.learn(DatasetUtil.asMultiCollection(data));
    }

    /**
     * Converts every sub-collection of the given data into input-output
     * pairs using the current prediction horizon and trains the wrapped
     * learner on the result.
     *
     * @param data the sequences to learn from, one per sub-collection
     * @return the object produced by the wrapped learner
     */
    public LearnedType learn(MultiCollection<? extends DataType> data) {
        // The pairs are typed <DataType, DataType> so that they satisfy the
        // wrapped learner's input bound without an unchecked cast.
        MultiCollection<InputOutputPair<DataType, DataType>> supervisedData =
            SequencePredictionLearner.<DataType>createPredictionDataset(data, this.getPredictionHorizion());
        return this.getLearner().learn(supervisedData);
    }

    /**
     * Creates a prediction dataset from the given data by viewing it as a
     * multi-collection and delegating to
     * {@link #createPredictionDataset(MultiCollection, int)}.
     *
     * @param <DataType> the type of the elements in the sequences
     * @param data the sequence data to convert
     * @param predictionHorizon the number of steps between an input and the
     *        output paired with it
     * @return the input-output pairs, grouped by source sequence
     */
    public static <DataType> MultiCollection<InputOutputPair<DataType, DataType>> createPredictionDataset(Collection<? extends DataType> data, int predictionHorizon) {
        return SequencePredictionLearner.createPredictionDataset(DatasetUtil.asMultiCollection(data), predictionHorizon);
    }

    /**
     * Creates a prediction dataset in which each element of a sequence is
     * the output for the element buffered {@code predictionHorizon} steps
     * before it. Sub-collections whose size does not exceed
     * {@code predictionHorizon} are skipped entirely, so the result may
     * contain fewer sub-collections than the input.
     *
     * @param <DataType> the type of the elements in the sequences
     * @param data the sequences to convert, one per sub-collection
     * @param predictionHorizon the number of steps between an input and the
     *        output paired with it; also used as the buffer capacity
     * @return the input-output pairs, grouped by source sequence
     */
    public static <DataType> MultiCollection<InputOutputPair<DataType, DataType>> createPredictionDataset(MultiCollection<? extends DataType> data, int predictionHorizon) {
        ArrayList<ArrayList<InputOutputPair<DataType, DataType>>> sequences =
            new ArrayList<ArrayList<InputOutputPair<DataType, DataType>>>(data.subCollections().size());

        // One buffer is reused for every sequence and cleared before each.
        // NOTE(review): presumably addLast evicts the oldest element once the
        // buffer is at capacity, making this a sliding window - confirm
        // against FiniteCapacityBuffer.
        FiniteCapacityBuffer<DataType> buffer = new FiniteCapacityBuffer<DataType>(predictionHorizon);

        for (Collection<? extends DataType> subData : data.subCollections()) {
            int sequenceLength = subData.size() - predictionHorizon;
            if (sequenceLength <= 0) {
                // Too short to yield even one input-output pair.
                continue;
            }

            ArrayList<InputOutputPair<DataType, DataType>> sequence =
                new ArrayList<InputOutputPair<DataType, DataType>>(sequenceLength);
            buffer.clear();
            for (DataType output : subData) {
                if (buffer.isFull()) {
                    // Pair the first buffered element with the current one.
                    DataType input = buffer.getFirst();
                    sequence.add(new DefaultInputOutputPair<DataType, DataType>(input, output));
                }
                buffer.addLast(output);
            }
            sequences.add(sequence);
        }
        return new DefaultMultiCollection<InputOutputPair<DataType, DataType>>(sequences);
    }

    /**
     * Gets the prediction horizon. The method name is misspelled; it is kept
     * because it is part of the public API.
     *
     * @return the number of steps ahead to predict
     */
    public int getPredictionHorizion() {
        return this.predictionHorizon;
    }

    /**
     * Gets the prediction horizon. Correctly spelled counterpart of
     * {@link #getPredictionHorizion()} that pairs with
     * {@link #setPredictionHorizon(int)}; it delegates to the legacy getter
     * so that subclass overrides of that method are still honored.
     *
     * @return the number of steps ahead to predict
     */
    public int getPredictionHorizon() {
        return this.getPredictionHorizion();
    }

    /**
     * Sets the prediction horizon.
     *
     * @param predictionHorizon the number of steps ahead to predict; must be
     *        positive
     * @throws IllegalArgumentException if {@code predictionHorizon} is zero
     *         or negative
     */
    public void setPredictionHorizon(int predictionHorizon) {
        if (predictionHorizon <= 0) {
            throw new IllegalArgumentException("predictionHorizon must be positive");
        }
        this.predictionHorizon = predictionHorizon;
    }
}

