/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.math.statistics.distribution.metrics;

import Jama.Matrix;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;
import org.openimaj.util.comparator.DistanceComparator;

/**
 * Kullback-Leibler divergence between two multivariate Gaussians, computed
 * with the closed-form expression
 *
 * <pre>
 * D_KL(N0 || N1) = 0.5 * ( tr(S1^-1 S0)
 *                        + (mu1 - mu0)^T S1^-1 (mu1 - mu0)
 *                        - K
 *                        - ln(det S0 / det S1) )
 * </pre>
 *
 * where {@code N0} is the first argument and {@code N1} the second. The
 * measure is not symmetric: {@code compare(a, b)} generally differs from
 * {@code compare(b, a)}.
 */
public class GaussianKLDivergence
implements DistanceComparator<MultivariateGaussian> {
    /**
     * @return always {@code true}; smaller values indicate more similar
     *         distributions.
     */
    @Override
    public boolean isDistance() {
        return true;
    }

    /**
     * Computes the KL divergence from {@code o1} to {@code o2}.
     *
     * @param o1 the first distribution (N0 in the formula above)
     * @param o2 the second distribution (N1 in the formula above); its
     *           covariance is inverted, so it must be non-singular
     * @return the divergence {@code D_KL(o1 || o2)}
     */
    @Override
    public double compare(MultivariateGaussian o1, MultivariateGaussian o2) {
        final Matrix sig0 = o1.getCovariance();
        final Matrix sig1 = o2.getCovariance();
        final Matrix mu0 = o1.getMean();
        final Matrix mu1 = o2.getMean();
        final int dims = o1.numDims();

        final Matrix sig1inv = sig1.inverse();
        final double sigtrace = MatrixUtils.trace(sig1inv.times(sig0));

        // NOTE(review): the orientation of getMean() is not visible here. The
        // quadratic form below needs a K x 1 column vector, so a 1 x K row
        // vector is transposed first; a column vector passes through unchanged.
        Matrix mudiff = mu1.minus(mu0);
        if (mudiff.getColumnDimension() != 1) {
            mudiff = mudiff.transpose();
        }
        final double quadForm = mudiff.transpose().times(sig1inv).times(mudiff).get(0, 0);

        // The closed form uses the ratio of covariance determinants, not a
        // matrix norm.
        final double lnDetRatio = Math.log(sig0.det() / sig1.det());

        return 0.5 * (sigtrace + quadForm - (double)dims - lnDetRatio);
    }
}

