/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.hmm.AbstractBaumWelchAlgorithm;
import gov.sandia.cognition.learning.algorithm.hmm.HiddenMarkovModel;
import gov.sandia.cognition.learning.algorithm.hmm.MarkovChain;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;

/**
 * Batch learner that refines a {@link HiddenMarkovModel} against one or more
 * observation sequences by repeatedly re-estimating its parameters.
 *
 * <p>Each call to {@link #step()} recomputes the transition matrix and the
 * per-state emission functions (and, optionally, the initial-state vector)
 * from the current model, then keeps the re-estimated model only if the summed
 * log-likelihood over all sequences improved.
 *
 * @param <ObservationType> type of a single observation emitted by the model
 */
@PublicationReference(author={"Lawrence R. Rabiner"}, title="A tutorial on hidden Markov models and selected applications in speech recognition", type=PublicationType.Journal, year=1989, publication="Proceedings of the IEEE", pages={257, 286}, url="http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf", notes={"Rabiner's transition matrix is transposed from mine."})
public class BaumWelchAlgorithm<ObservationType>
extends AbstractBaumWelchAlgorithm<ObservationType, Collection<? extends ObservationType>> {
    /**
     * Every observation of every sequence, flattened in iteration order, each
     * paired with a weight that {@link #updateProbabilityFunctions} overwrites
     * once per state before handing the list to the distribution learner.
     */
    private transient ArrayList<DefaultWeightedValue<ObservationType>> weightedData;

    /**
     * One entry per sequence: the value is the sequence's log-likelihood under
     * the most recently evaluated model, the weight is the per-sequence scale
     * factor computed in {@link #updateSequenceLogLikelihoods}.
     */
    private transient ArrayList<DefaultWeightedValue<Double>> sequenceLogLikelihoods;

    /** Total number of observations summed over all sequences. */
    private transient int totalNum;

    /** The training data viewed as a collection of sequences. */
    protected transient MultiCollection<? extends ObservationType> multicollection;

    /**
     * One gamma vector per observation, in the same flattened order as
     * {@link #weightedData}; refreshed on every {@link #step()}.
     */
    protected transient ArrayList<Vector> sequenceGammas;

    /**
     * Creates a learner with no initial guess, no distribution learner, and
     * re-estimation of the initial probabilities switched on.
     */
    public BaumWelchAlgorithm() {
        this(null, null, true);
    }

    /**
     * Creates a learner.
     *
     * @param initialGuess model that learning starts from (it is cloned, not modified)
     * @param distributionLearner learner that fits one emission distribution from weighted observations
     * @param reestimateInitialProbabilities whether each step also re-estimates the initial-state vector
     */
    public BaumWelchAlgorithm(HiddenMarkovModel<ObservationType> initialGuess, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> distributionLearner, boolean reestimateInitialProbabilities) {
        super(initialGuess, distributionLearner, reestimateInitialProbabilities);
    }

    /**
     * @return a copy of this learner, as produced by the superclass clone
     */
    @Override
    public BaumWelchAlgorithm<ObservationType> clone() {
        return (BaumWelchAlgorithm<ObservationType>) super.clone();
    }

    /**
     * Learns from data that is already organised as multiple sequences.
     *
     * @param data the observation sequences to train against
     * @return the model produced by the superclass {@code learn}
     */
    @Override
    public HiddenMarkovModel<ObservationType> learn(MultiCollection<ObservationType> data) {
        return (HiddenMarkovModel<ObservationType>) super.learn(data);
    }

    /**
     * Prepares the working buffers, clones the initial guess into
     * {@code result} and records its log-likelihood as the baseline.
     *
     * @return {@code false} when no initial guess has been supplied (nothing
     *         is allocated in that case); otherwise whether a result model
     *         is in place
     */
    @Override
    protected boolean initializeAlgorithm() {
        // Fail through the boolean contract instead of hitting a
        // NullPointerException further down, after the data reference has
        // already been dropped and the buffers allocated.
        // NOTE(review): the caller is assumed to abort learning on false.
        if (this.getInitialGuess() == null) {
            return false;
        }

        this.multicollection = DatasetUtil.asMultiCollection((Collection)this.data);
        this.data = null;

        // First pass: one log-likelihood slot per sequence, and count the
        // observations so the flattened buffers can be sized exactly.
        final int numSequences = this.multicollection.getSubCollectionsCount();
        this.sequenceLogLikelihoods = new ArrayList<DefaultWeightedValue<Double>>(numSequences);
        this.totalNum = 0;
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            this.sequenceLogLikelihoods.add(new DefaultWeightedValue<Double>());
            this.totalNum += sequence.size();
        }

        // Second pass: flatten the observations. The gamma list is filled
        // with placeholders so step() can overwrite entries by index.
        this.weightedData = new ArrayList<DefaultWeightedValue<ObservationType>>(this.totalNum);
        this.sequenceGammas = new ArrayList<Vector>(this.totalNum);
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            for (ObservationType observation : sequence) {
                this.weightedData.add(new DefaultWeightedValue<ObservationType>(observation));
                this.sequenceGammas.add(null);
            }
        }

        this.result = this.getInitialGuess().clone();
        this.lastLogLikelihood = this.updateSequenceLogLikelihoods(this.result);
        return this.result != null;
    }

    /**
     * Runs one re-estimation pass.
     *
     * <p>When at most one iteration is allowed, the re-estimated parameters
     * are written straight into {@code result}. Otherwise they are written
     * into a clone, which replaces {@code result} only if its summed
     * log-likelihood beats the previous one (or this is the first iteration).
     *
     * @return {@code true} if the re-estimated parameters were adopted
     */
    @Override
    protected boolean step() {
        final int numSequences = this.multicollection.getSubCollectionsCount();
        final boolean updatePi = this.getReestimateInitialProbabilities();

        Pair<ArrayList<ArrayList<Vector>>, ArrayList<Matrix>> sequenceParameters = this.computeSequenceParameters();
        ArrayList<ArrayList<Vector>> allGammas = sequenceParameters.getFirst();
        ArrayList<Matrix> sequenceTransitionMatrices = sequenceParameters.getSecond();

        // Copy the per-sequence gammas into the flattened list, and collect
        // the first gamma of each sequence if the initial vector is updated.
        ArrayList<Vector> firstGammas = updatePi ? new ArrayList<Vector>(numSequences) : null;
        int index = 0;
        for (int i = 0; i < numSequences; ++i) {
            ArrayList<Vector> gammas = allGammas.get(i);
            if (updatePi) {
                firstGammas.add(gammas.get(0));
            }
            for (Vector gamma : gammas) {
                this.sequenceGammas.set(index, gamma);
                ++index;
            }
        }

        Vector pi = this.result.getInitialProbability();
        if (updatePi) {
            pi = this.updateInitialProbabilities(firstGammas);
        }
        Matrix A = this.updateTransitionMatrix(sequenceTransitionMatrices);
        ArrayList<ProbabilityFunction<ObservationType>> fs = this.updateProbabilityFunctions(this.sequenceGammas);

        if (this.getMaxIterations() <= 1) {
            // Single-shot mode: no comparison against the previous model.
            this.result.emissionFunctions = fs;
            this.result.initialProbability = pi;
            this.result.transitionProbability = A;
            return true;
        }

        HiddenMarkovModel<ObservationType> candidate = this.result.clone();
        candidate.emissionFunctions = fs;
        candidate.initialProbability = pi;
        candidate.transitionProbability = A;

        // Evaluating the candidate also refreshes the per-sequence weights.
        double logLikelihood = this.updateSequenceLogLikelihoods(candidate);
        boolean gettingBetter = logLikelihood > this.lastLogLikelihood || this.getIteration() <= 1;
        if (gettingBetter) {
            this.result = candidate;
            this.lastLogLikelihood = logLikelihood;
        }
        return gettingBetter;
    }

    /**
     * Releases every transient working buffer so the learner does not keep
     * the training data or per-observation vectors reachable after learning.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.multicollection = null;
        this.weightedData = null;
        this.sequenceLogLikelihoods = null;
        this.sequenceGammas = null;
        this.totalNum = 0;
    }

    /**
     * Computes, for every sequence, its gamma vectors and its transition
     * matrix under the current {@code result} model.
     *
     * @return pair of (gammas per sequence, transition matrix per sequence),
     *         both in sequence iteration order
     */
    protected Pair<ArrayList<ArrayList<Vector>>, ArrayList<Matrix>> computeSequenceParameters() {
        final int numSequences = this.multicollection.getSubCollectionsCount();
        ArrayList<ArrayList<Vector>> allGammas = new ArrayList<ArrayList<Vector>>(numSequences);
        ArrayList<Matrix> sequenceTransitionMatrices = new ArrayList<Matrix>(numSequences);

        // Flag forwarded to computeForwardProbabilities.
        // NOTE(review): presumably requests normalized alphas - confirm against HiddenMarkovModel.
        final boolean normalize = true;

        int k = 0;
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            final double sequenceWeight = this.sequenceLogLikelihoods.get(k).getWeight();

            ArrayList<Vector> b = this.result.computeObservationLikelihoods(sequence);
            ArrayList<WeightedValue<Vector>> alphas = this.result.computeForwardProbabilities(b, normalize);
            ArrayList<WeightedValue<Vector>> betas = this.result.computeBackwardProbabilities(b, alphas);

            allGammas.add(this.result.computeStateObservationLikelihood(alphas, betas, sequenceWeight));

            Matrix A = this.result.computeTransitions(alphas, betas, b);
            if (sequenceWeight != 1.0) {
                A.scaleEquals(sequenceWeight);
            }
            sequenceTransitionMatrices.add(A);
            ++k;
        }
        return DefaultPair.create(allGammas, sequenceTransitionMatrices);
    }

    /**
     * Fits one emission function per state.
     *
     * <p>For state {@code i}, observation {@code n} is weighted by element
     * {@code i} of {@code sequenceGammas.get(n)}; the weighted observations
     * are then passed to the distribution learner.
     *
     * @param sequenceGammas one gamma vector per observation, in the same
     *        order as the flattened observation buffer
     * @return the learned probability functions, indexed by state
     */
    protected ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(ArrayList<Vector> sequenceGammas) {
        final int numStates = this.result.getNumStates();
        final int numObservations = sequenceGammas.size();
        ArrayList<ProbabilityFunction<ObservationType>> fs = new ArrayList<ProbabilityFunction<ObservationType>>(numStates);
        for (int i = 0; i < numStates; ++i) {
            // The weight buffer is reused across states; only weights change.
            for (int n = 0; n < numObservations; ++n) {
                this.weightedData.get(n).setWeight(sequenceGammas.get(n).getElement(i));
            }
            ProbabilityFunction<ObservationType> f = this.distributionLearner.learn(this.weightedData).getProbabilityFunction();
            fs.add(f);
        }
        return fs;
    }

    /**
     * Sums the per-sequence transition matrices and normalizes the sum via
     * the current model's {@code normalizeTransitionMatrix}.
     *
     * @param sequenceTransitionMatrices one transition matrix per sequence
     * @return the normalized sum
     */
    protected Matrix updateTransitionMatrix(ArrayList<Matrix> sequenceTransitionMatrices) {
        RingAccumulator<Matrix> As = new RingAccumulator<Matrix>(sequenceTransitionMatrices);
        Matrix A = As.getSum();
        this.result.normalizeTransitionMatrix(A);
        return A;
    }

    /**
     * Sums the first gamma vector of every sequence and rescales the sum so
     * that its 1-norm is one.
     *
     * @param firstGammas the first gamma vector of each sequence
     * @return the rescaled sum
     */
    protected Vector updateInitialProbabilities(ArrayList<Vector> firstGammas) {
        RingAccumulator<Vector> pi = new RingAccumulator<Vector>();
        for (Vector firstGamma : firstGammas) {
            pi.accumulate(firstGamma);
        }
        Vector pisum = pi.getSum();
        pisum.scaleEquals(1.0 / pisum.norm1());
        return pisum;
    }

    /**
     * Evaluates every sequence under {@code hmm}, stores each log-likelihood,
     * and derives a per-sequence weight relative to the best-scoring sequence.
     *
     * <p>The weight of a sequence is {@code 1 / exp(logLikelihood - max)}, so
     * the best-scoring sequence gets weight one and the others get larger
     * weights.
     *
     * @param hmm the model to evaluate
     * @return the sum of the log-likelihoods over all sequences
     */
    protected double updateSequenceLogLikelihoods(HiddenMarkovModel<ObservationType> hmm) {
        double maxLogLikelihood = Double.NEGATIVE_INFINITY;
        double totalLogLikelihood = 0.0;
        int k = 0;
        for (Collection<? extends ObservationType> sequence : this.multicollection.subCollections()) {
            double logLikelihood = hmm.computeObservationLogLikelihood(sequence);
            if (maxLogLikelihood < logLikelihood) {
                maxLogLikelihood = logLikelihood;
            }
            this.sequenceLogLikelihoods.get(k).setValue(logLikelihood);
            totalLogLikelihood += logLikelihood;
            ++k;
        }

        // Subtracting the maximum before exponentiating keeps the exponent
        // at or below zero for every sequence.
        final int numSequences = this.multicollection.getSubCollectionsCount();
        for (k = 0; k < numSequences; ++k) {
            DefaultWeightedValue<Double> entry = this.sequenceLogLikelihoods.get(k);
            double logLikelihood = entry.getValue();
            double weight = 1.0 / Math.exp(logLikelihood - maxLogLikelihood);
            entry.setWeight(weight);
        }
        return totalLogLikelihood;
    }
}

