/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.data.feature;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
import gov.sandia.cognition.math.matrix.mtj.decomposition.CholeskyDecompositionMTJ;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Decorrelates vectors using a multivariate Gaussian.
 *
 * <p>The Gaussian's mean is subtracted from each input, and the result is
 * multiplied by a matrix obtained from the Cholesky decomposition of the
 * Gaussian's inverse covariance (see {@link #getCovarianceInverseSquareRoot()}).
 *
 * <p>A decorrelator built with the no-argument constructor has no Gaussian.
 * {@link #setGaussian(MultivariateGaussian)} must be called before
 * {@link #evaluate(Vectorizable)}, {@link #getMean()} or
 * {@link #getCovariance()} are used; otherwise they throw a
 * {@link NullPointerException}.
 */
public class MultivariateDecorrelator
extends AbstractCloneableSerializable
implements Evaluator<Vectorizable, Vector> {
    /** The Gaussian whose mean and covariance define the transform; may be null. */
    protected MultivariateGaussian gaussian;

    /**
     * Cholesky factor of the inverse covariance of {@link #gaussian}. It is
     * computed by {@link #setGaussian(MultivariateGaussian)} and is null when
     * there is no Gaussian.
     */
    private Matrix covarianceInverseSquareRoot;

    /**
     * Creates a decorrelator with no Gaussian.
     *
     * <p>Call {@link #setGaussian(MultivariateGaussian)} before evaluating.
     */
    public MultivariateDecorrelator() {
        // The cast selects the MultivariateGaussian overload; a bare null would
        // be ambiguous with the copy constructor.
        this((MultivariateGaussian)null);
    }

    /**
     * Creates a decorrelator from a mean and a covariance.
     *
     * @param mean the mean to subtract from inputs
     * @param covariance the covariance whose inverse is factored
     */
    public MultivariateDecorrelator(Vector mean, Matrix covariance) {
        this(new MultivariateGaussian(mean, covariance));
    }

    /**
     * Creates a decorrelator from a Gaussian.
     *
     * @param gaussian the Gaussian to use. A clone is stored. May be null, which
     *        leaves the decorrelator without a Gaussian.
     */
    public MultivariateDecorrelator(MultivariateGaussian gaussian) {
        this.setGaussian(gaussian);
    }

    /**
     * Copy constructor.
     *
     * @param other the decorrelator to copy; its Gaussian (if any) is cloned
     */
    public MultivariateDecorrelator(MultivariateDecorrelator other) {
        // setGaussian already stores a clone of its argument, so cloning here
        // as well would only create a throw-away copy.
        this(other.getGaussian());
    }

    /**
     * Creates a copy whose Gaussian and cached square-root matrix are clones
     * of this object's (null fields stay null).
     *
     * @return the clone
     */
    @Override
    public MultivariateDecorrelator clone() {
        MultivariateDecorrelator clone = (MultivariateDecorrelator)super.clone();
        clone.gaussian = ObjectUtil.cloneSafe(this.gaussian);
        clone.covarianceInverseSquareRoot = ObjectUtil.cloneSafe(this.covarianceInverseSquareRoot);
        return clone;
    }

    /**
     * Decorrelates the given value.
     *
     * <p>The mean is subtracted, and the difference is multiplied by
     * {@link #getCovarianceInverseSquareRoot()} via {@code Vector.times(Matrix)}.
     *
     * <p>NOTE(review): {@code Vector.times(Matrix)} presumably computes the
     * row-vector product (x - mean)^T R. Whether that gives identity covariance
     * for a non-diagonal covariance depends on the triangular convention of
     * {@code CholeskyDecompositionMTJ.getR()}; confirm before relying on full
     * whitening.
     *
     * @param value the input, expected to have the mean's dimensionality
     * @return the decorrelated vector
     * @throws NullPointerException if no Gaussian has been set
     */
    @Override
    public Vector evaluate(Vectorizable value) {
        Vector input = value.convertToVector();
        return input.minus(this.getMean()).times(this.getCovarianceInverseSquareRoot());
    }

    /**
     * Gets the mean of the underlying Gaussian.
     *
     * @return the mean
     * @throws NullPointerException if no Gaussian has been set
     */
    public Vector getMean() {
        return this.getGaussian().getMean();
    }

    /**
     * Gets the covariance of the underlying Gaussian.
     *
     * @return the covariance
     * @throws NullPointerException if no Gaussian has been set
     */
    public Matrix getCovariance() {
        return this.getGaussian().getCovariance();
    }

    /**
     * Gets the underlying Gaussian.
     *
     * @return the stored Gaussian (the internal clone, not a copy), or null
     */
    public MultivariateGaussian getGaussian() {
        return this.gaussian;
    }

    /**
     * Sets the Gaussian and recomputes the cached square-root matrix.
     *
     * <p>A non-null argument is cloned, and the Cholesky factor of its inverse
     * covariance is computed with {@code CholeskyDecompositionMTJ}. A null
     * argument clears both the Gaussian and the cached matrix. If the
     * decomposition throws, this object is left unchanged.
     *
     * @param gaussian the Gaussian to use, or null
     */
    public void setGaussian(MultivariateGaussian gaussian) {
        if (gaussian == null) {
            this.gaussian = null;
            this.covarianceInverseSquareRoot = null;
            return;
        }
        MultivariateGaussian copy = gaussian.clone();
        // Factor a dense MTJ copy of the inverse covariance before touching any
        // field. A failed decomposition must not pair the new Gaussian with the
        // previous square-root matrix.
        CholeskyDecompositionMTJ cholesky = CholeskyDecompositionMTJ.create(
            DenseMatrixFactoryMTJ.INSTANCE.copyMatrix(copy.getCovarianceInverse()));
        Matrix squareRoot = cholesky.getR();
        this.gaussian = copy;
        this.covarianceInverseSquareRoot = squareRoot;
    }

    /**
     * Gets the cached Cholesky factor ({@code getR()}) of the Gaussian's
     * inverse covariance.
     *
     * @return the factor, or null if no Gaussian has been set
     */
    public Matrix getCovarianceInverseSquareRoot() {
        return this.covarianceInverseSquareRoot;
    }

    /**
     * Checks the input collection shared by the static learning methods.
     *
     * @param values the collection to check
     * @throws IllegalArgumentException if values is null or empty
     */
    private static void requireNonEmpty(Collection<?> values) {
        if (values == null) {
            throw new IllegalArgumentException("values cannot be null.");
        }
        if (values.size() <= 0) {
            throw new IllegalArgumentException("values cannot be empty.");
        }
    }

    /**
     * Learns a decorrelator with a full covariance matrix.
     *
     * <p>The values are converted with {@code DatasetUtil.asVectorCollection},
     * and a Gaussian is fitted with
     * {@code MultivariateGaussian.MaximumLikelihoodEstimator.learn}.
     *
     * @param values the data to learn from
     * @param defaultCovariance passed through to the maximum-likelihood
     *        estimator; presumably a diagonal regularizer (confirm in that
     *        estimator's documentation)
     * @return the learned decorrelator
     * @throws IllegalArgumentException if values is null or empty
     */
    public static MultivariateDecorrelator learnFullCovariance(Collection<? extends Vectorizable> values, double defaultCovariance) {
        requireNonEmpty(values);
        Collection<Vector> vectorValues = DatasetUtil.asVectorCollection(values);
        MultivariateGaussian.PDF pdf = MultivariateGaussian.MaximumLikelihoodEstimator.learn(vectorValues, defaultCovariance);
        return new MultivariateDecorrelator(pdf);
    }

    /**
     * Learns a decorrelator with a diagonal covariance matrix.
     *
     * <p>The diagonal holds each dimension's mean squared deviation from the
     * sample mean (dividing by the number of values), plus
     * {@code defaultCovariance}.
     *
     * @param values the data to learn from
     * @param defaultCovariance the amount added to every diagonal entry
     * @return the learned decorrelator
     * @throws IllegalArgumentException if values is null or empty
     */
    public static MultivariateDecorrelator learnDiagonalCovariance(Collection<? extends Vectorizable> values, double defaultCovariance) {
        requireNonEmpty(values);
        int count = values.size();

        // First pass: sample mean.
        RingAccumulator<Vector> meanAccumulator = new RingAccumulator<Vector>();
        for (Vectorizable vectorizable : values) {
            meanAccumulator.accumulate(vectorizable.convertToVector());
        }
        Vector mean = (Vector)meanAccumulator.getMean();
        int dimensionality = mean.getDimensionality();

        // Second pass: sum of element-wise squared deviations from the mean.
        Vector variance = VectorFactory.getDefault().createVector(dimensionality);
        for (Vectorizable vectorizable : values) {
            Vector difference = vectorizable.convertToVector().minus(mean);
            difference.dotTimesEquals(difference);
            variance.plusEquals(difference);
        }
        variance.scaleEquals(1.0 / (double)count);
        variance.plusEquals(VectorFactory.getDefault().createVector(dimensionality, defaultCovariance));

        Matrix covariance = MatrixFactory.getDefault().createDiagonal(variance);
        return new MultivariateDecorrelator(mean, covariance);
    }

    /**
     * Batch learner that delegates to
     * {@link MultivariateDecorrelator#learnDiagonalCovariance(Collection, double)}.
     */
    public static class DiagonalCovarianceLearner
    extends AbstractCloneableSerializable
    implements BatchLearner<Collection<? extends Vectorizable>, MultivariateDecorrelator> {
        /** Default value for the default covariance: {@value}. */
        public static final double DEFAULT_DEFAULT_COVARIANCE = 1.0E-5;

        /** Amount added to each diagonal covariance entry; non-negative. */
        protected double defaultCovariance;

        /** Creates a learner with {@link #DEFAULT_DEFAULT_COVARIANCE}. */
        public DiagonalCovarianceLearner() {
            this(DEFAULT_DEFAULT_COVARIANCE);
        }

        /**
         * Creates a learner with the given default covariance.
         *
         * @param defaultCovariance the default covariance; must be non-negative
         * @throws IllegalArgumentException if defaultCovariance is negative or NaN
         */
        public DiagonalCovarianceLearner(double defaultCovariance) {
            this.setDefaultCovariance(defaultCovariance);
        }

        /**
         * Learns a diagonal-covariance decorrelator.
         *
         * @param values the data to learn from
         * @return the learned decorrelator
         * @throws IllegalArgumentException if values is null or empty
         */
        @Override
        public MultivariateDecorrelator learn(Collection<? extends Vectorizable> values) {
            return MultivariateDecorrelator.learnDiagonalCovariance(values, this.getDefaultCovariance());
        }

        /**
         * Gets the default covariance.
         *
         * @return the default covariance
         */
        public double getDefaultCovariance() {
            return this.defaultCovariance;
        }

        /**
         * Sets the default covariance.
         *
         * @param defaultCovariance the default covariance; must be non-negative
         * @throws IllegalArgumentException if defaultCovariance is negative or NaN
         */
        public void setDefaultCovariance(double defaultCovariance) {
            // NaN would pass the "< 0.0" test below, so reject it explicitly.
            if (Double.isNaN(defaultCovariance)) {
                throw new IllegalArgumentException("defaultCovariance cannot be NaN.");
            }
            if (defaultCovariance < 0.0) {
                throw new IllegalArgumentException("defaultCovariance cannot be negative.");
            }
            this.defaultCovariance = defaultCovariance;
        }
    }

    /**
     * Batch learner that delegates to
     * {@link MultivariateDecorrelator#learnFullCovariance(Collection, double)}.
     */
    public static class FullCovarianceLearner
    extends AbstractCloneableSerializable
    implements BatchLearner<Collection<? extends Vectorizable>, MultivariateDecorrelator> {
        /** Default value for the default covariance: {@value}. */
        public static final double DEFAULT_DEFAULT_COVARIANCE = 1.0E-5;

        /** Value passed to the maximum-likelihood estimator; non-negative. */
        protected double defaultCovariance;

        /** Creates a learner with {@link #DEFAULT_DEFAULT_COVARIANCE}. */
        public FullCovarianceLearner() {
            this(DEFAULT_DEFAULT_COVARIANCE);
        }

        /**
         * Creates a learner with the given default covariance.
         *
         * @param defaultCovariance the default covariance; must be non-negative
         * @throws IllegalArgumentException if defaultCovariance is negative or NaN
         */
        public FullCovarianceLearner(double defaultCovariance) {
            this.setDefaultCovariance(defaultCovariance);
        }

        /**
         * Learns a full-covariance decorrelator.
         *
         * @param values the data to learn from
         * @return the learned decorrelator
         * @throws IllegalArgumentException if values is null or empty
         */
        @Override
        public MultivariateDecorrelator learn(Collection<? extends Vectorizable> values) {
            return MultivariateDecorrelator.learnFullCovariance(values, this.getDefaultCovariance());
        }

        /**
         * Gets the default covariance.
         *
         * @return the default covariance
         */
        public double getDefaultCovariance() {
            return this.defaultCovariance;
        }

        /**
         * Sets the default covariance.
         *
         * @param defaultCovariance the default covariance; must be non-negative
         * @throws IllegalArgumentException if defaultCovariance is negative or NaN
         */
        public void setDefaultCovariance(double defaultCovariance) {
            // NaN would pass the "< 0.0" test below, so reject it explicitly.
            if (Double.isNaN(defaultCovariance)) {
                throw new IllegalArgumentException("defaultCovariance cannot be NaN.");
            }
            if (defaultCovariance < 0.0) {
                throw new IllegalArgumentException("defaultCovariance cannot be negative.");
            }
            this.defaultCovariance = defaultCovariance;
        }
    }
}

