/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorUtil;
import gov.sandia.cognition.math.matrix.decomposition.EigenvectorPowerIteration;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;

/**
 * A discrete-state Markov chain described by an initial state distribution
 * and a square transition matrix.
 *
 * <p>The transition matrix is kept column-normalized: every column is rescaled
 * so that its entries sum to one, and the distribution after one step is
 * computed as {@code transitionProbability.times(current)}.
 *
 * <p>Note that the setters normalize the objects they are given <em>in
 * place</em> and store the references without copying.
 */
@PublicationReference(author={"Wikipedia"}, title="Markov chain", type=PublicationType.WebPage, year=2010, url="http://en.wikipedia.org/wiki/Markov_chain")
public class MarkovChain
extends AbstractCloneableSerializable {
    /** Number of states used by the default constructor. */
    public static final int DEFAULT_NUM_STATES = 3;

    /** Distribution over the states at the start of the chain. */
    protected Vector initialProbability;

    /** Square, column-normalized state transition matrix. */
    protected Matrix transitionProbability;

    /**
     * Creates a chain with {@link #DEFAULT_NUM_STATES} states and uniform
     * initial and transition probabilities.
     */
    public MarkovChain() {
        this(DEFAULT_NUM_STATES);
    }

    /**
     * Creates a chain with uniform initial and transition probabilities.
     *
     * @param numStates number of states in the chain
     */
    public MarkovChain(int numStates) {
        this(MarkovChain.createUniformInitialProbability(numStates), MarkovChain.createUniformTransitionProbability(numStates));
    }

    /**
     * Creates a chain from the given distributions. Both arguments are
     * normalized in place and stored by reference.
     *
     * @param initialProbability initial distribution over the states; its
     *        dimensionality must equal the number of rows of
     *        {@code transitionProbability}
     * @param transitionProbability square transition matrix
     * @throws IllegalArgumentException if {@code transitionProbability} is not
     *         square or either argument contains a negative entry
     */
    public MarkovChain(Vector initialProbability, Matrix transitionProbability) {
        if (!transitionProbability.isSquare()) {
            throw new IllegalArgumentException("transitionProbability must be square!");
        }
        int k = transitionProbability.getNumRows();
        initialProbability.assertDimensionalityEquals(k);
        this.setTransitionProbability(transitionProbability);
        this.setInitialProbability(initialProbability);
    }

    /**
     * Creates a copy of this chain whose initial distribution and transition
     * matrix are cloned via {@code ObjectUtil.cloneSafe}.
     *
     * @return the cloned chain
     */
    @Override
    public MarkovChain clone() {
        MarkovChain clone = (MarkovChain)super.clone();
        clone.setInitialProbability(ObjectUtil.cloneSafe(this.getInitialProbability()));
        clone.setTransitionProbability(ObjectUtil.cloneSafe(this.getTransitionProbability()));
        return clone;
    }

    /**
     * Creates a vector in which every state has probability
     * {@code 1 / numStates}.
     *
     * @param numStates number of states
     * @return the uniform distribution
     */
    protected static Vector createUniformInitialProbability(int numStates) {
        return VectorFactory.getDefault().createVector(numStates, 1.0 / (double)numStates);
    }

    /**
     * Creates a {@code numStates x numStates} matrix in which every entry is
     * {@code 1 / numStates}.
     *
     * @param numStates number of states
     * @return the uniform transition matrix
     */
    protected static Matrix createUniformTransitionProbability(int numStates) {
        Matrix A = MatrixFactory.getDefault().createMatrix(numStates, numStates);
        double p = 1.0 / (double)numStates;
        for (int i = 0; i < numStates; ++i) {
            for (int j = 0; j < numStates; ++j) {
                A.setElement(i, j, p);
            }
        }
        return A;
    }

    /**
     * Gets the initial state distribution.
     *
     * @return the stored vector itself, not a copy
     */
    public Vector getInitialProbability() {
        return this.initialProbability;
    }

    /**
     * Sets the initial state distribution. The given vector is rescaled in
     * place so its entries sum to one, then stored by reference. A vector
     * whose entries sum to zero is stored unscaled, mirroring how zero-sum
     * columns of the transition matrix are handled.
     *
     * @param initialProbability non-negative weights over the states
     * @throws IllegalArgumentException if any entry is negative
     */
    public void setInitialProbability(Vector initialProbability) {
        int k = initialProbability.getDimensionality();
        double sum = 0.0;
        for (int i = 0; i < k; ++i) {
            double value = initialProbability.getElement(i);
            if (value < 0.0) {
                throw new IllegalArgumentException("Initial Probabilities must be >= 0.0");
            }
            sum += value;
        }
        // Guard against a zero sum: scaling by 1/0 would turn every entry
        // into NaN (0 * Infinity).
        if (sum > 0.0 && sum != 1.0) {
            initialProbability.scaleEquals(1.0 / sum);
        }
        this.initialProbability = initialProbability;
    }

    /**
     * Gets the state transition matrix.
     *
     * @return the stored matrix itself, not a copy
     */
    public Matrix getTransitionProbability() {
        return this.transitionProbability;
    }

    /**
     * Sets the state transition matrix. The given matrix is column-normalized
     * in place and stored by reference.
     *
     * @param transitionProbability square matrix of non-negative weights
     * @throws IllegalArgumentException if the matrix is not square or contains
     *         a negative entry
     */
    public void setTransitionProbability(Matrix transitionProbability) {
        if (!transitionProbability.isSquare()) {
            throw new IllegalArgumentException("Transition Probability must be square");
        }
        this.normalizeTransitionMatrix(transitionProbability);
        this.transitionProbability = transitionProbability;
    }

    /**
     * Re-normalizes the stored initial distribution and transition matrix in
     * place.
     */
    public void normalize() {
        VectorUtil.divideByNorm1Equals(this.initialProbability);
        this.normalizeTransitionMatrix(this.transitionProbability);
    }

    /**
     * Rescales column {@code j} of {@code A} in place so that its entries sum
     * to one. A column whose entries sum to zero is left unchanged.
     *
     * @param A matrix to modify
     * @param j index of the column to normalize
     * @throws IllegalArgumentException if the column contains a negative entry
     */
    protected static void normalizeTransitionMatrix(Matrix A, int j) {
        int k = A.getNumRows();
        double sum = 0.0;
        for (int i = 0; i < k; ++i) {
            double value = A.getElement(i, j);
            if (value < 0.0) {
                throw new IllegalArgumentException("Transition Probabilities must be >= 0.0");
            }
            sum += value;
        }
        // Treat an empty (all-zero) column as already normalized so that we
        // never divide by zero.
        if (sum <= 0.0) {
            sum = 1.0;
        }
        if (sum != 1.0) {
            for (int i = 0; i < k; ++i) {
                A.setElement(i, j, A.getElement(i, j) / sum);
            }
        }
    }

    /**
     * Normalizes every column of {@code A} in place.
     *
     * @param A matrix to modify
     * @throws IllegalArgumentException if the matrix contains a negative entry
     */
    protected void normalizeTransitionMatrix(Matrix A) {
        int k = A.getNumColumns();
        for (int j = 0; j < k; ++j) {
            MarkovChain.normalizeTransitionMatrix(A, j);
        }
    }

    /**
     * Gets the number of states, taken from the dimensionality of the initial
     * distribution.
     *
     * @return number of states in the chain
     */
    public int getNumStates() {
        return this.initialProbability.getDimensionality();
    }

    /**
     * Builds a multi-line description containing the number of states, the
     * initial distribution and the transition matrix.
     *
     * @return human-readable description of this chain
     */
    @Override
    public String toString() {
        StringBuilder retval = new StringBuilder(100 * this.getNumStates());
        retval.append("Markov Chain has ").append(this.getNumStates()).append(" states:\n");
        retval.append("Initial: ").append(this.getInitialProbability()).append("\n");
        retval.append("Transition:\n").append(this.getTransitionProbability());
        return retval.toString();
    }

    /**
     * Estimates the steady-state distribution of the chain by running
     * {@code EigenvectorPowerIteration.estimateEigenvector} on the transition
     * matrix, starting from the initial distribution, and rescaling the
     * result so that its entries sum to one.
     *
     * @return the estimated steady-state distribution
     */
    public Vector getSteadyStateDistribution() {
        final double tolerance = 1.0E-5;
        final int maxIterations = 100;
        Vector p = EigenvectorPowerIteration.estimateEigenvector(this.initialProbability, this.transitionProbability, tolerance, maxIterations);
        // Divide by the signed sum of the entries rather than by norm1.
        // NOTE(review): presumably this also flips the sign if the estimate
        // comes back as the negative of the eigenvector - confirm.
        double sum = 0.0;
        for (int i = 0; i < p.getDimensionality(); ++i) {
            sum += p.getElement(i);
        }
        p.scaleEquals(1.0 / sum);
        return p;
    }

    /**
     * Propagates a state distribution forward through the chain.
     *
     * @param current distribution to start from; it is not modified
     * @param numSteps number of transitions to apply; values of zero or less
     *        apply no transitions
     * @return a new vector holding the distribution after {@code numSteps}
     *         transitions, scaled to unit 1-norm; if the propagated vector has
     *         zero 1-norm an unscaled copy is returned instead
     */
    public Vector getFutureStateDistribution(Vector current, int numSteps) {
        Vector predicted = current;
        for (int n = 0; n < numSteps; ++n) {
            predicted = this.transitionProbability.times(predicted);
        }
        double norm1 = predicted.norm1();
        // Guard against a zero norm: scaling by 1/0 would yield NaN entries.
        if (norm1 <= 0.0) {
            return ObjectUtil.cloneSafe(predicted);
        }
        return predicted.scale(1.0 / norm1);
    }
}

