/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.clustering;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.algorithm.clustering.BatchClusterer;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.CentroidCluster;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;

/**
 * Affinity propagation clustering, following the message-passing scheme of
 * Frey and Dueck cited in the {@code PublicationReference} annotation below.
 *
 * <p>Every example is a candidate exemplar (cluster center). The similarity
 * between two examples is the negated divergence between them, and the
 * similarity of an example to itself is the negated self-divergence. Each
 * iteration updates two square message matrices, the "responsibilities" and
 * the "availabilities", and then assigns every example to the exemplar that
 * maximizes the sum of the two. Iteration continues for as long as at least
 * one assignment changed during the step.
 *
 * <p>Memory use is three {@code n x n} {@code double} matrices for
 * {@code n} examples.
 *
 * @param <DataType> The type of the examples being clustered.
 */
@CodeReview(reviewer={"Kevin R. Dixon"}, date="2008-07-22", changesNeeded=false, comments={"Removed transient declaration on members.", "Fixed a few typos in javadoc.", "Added PublicationReference annotation.", "Added comment about use of direct-member access.", "Code generally looked fine."})
@PublicationReference(author={"Brendan J. Frey", "Delbert Dueck"}, title="Clustering by Passing Messages Between Data Points.", type=PublicationType.Journal, publication="Science", notes={"Volume 315, number 5814"}, pages={972, 976}, year=2007)
public class AffinityPropagation<DataType>
extends AbstractAnytimeBatchLearner<Collection<? extends DataType>, Collection<CentroidCluster<DataType>>>
implements BatchClusterer<DataType, CentroidCluster<DataType>>,
MeasurablePerformanceAlgorithm {
    /** The default maximum number of iterations is {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default self-divergence is {@value}. */
    public static final double DEFAULT_SELF_DIVERGENCE = 0.0;

    /** The default damping factor is {@value}. */
    public static final double DEFAULT_DAMPING_FACTOR = 0.5;

    /** The divergence function; its negation is used as the similarity. */
    protected DivergenceFunction<? super DataType, ? super DataType> divergence;

    /** The divergence of an example to itself; negated onto the similarity diagonal. */
    private double selfDivergence;

    /** The weight, in [0, 1], given to the previous value of a message on update. */
    protected double dampingFactor;

    /** Cached value of {@code 1.0 - dampingFactor}; the weight of the new message value. */
    protected double oneMinusDampingFactor;

    /** Cached number of examples; kept in sync by {@link #setExamples(ArrayList)}. */
    protected transient int exampleCount;

    /** The examples being clustered, indexed consistently with the matrices. */
    protected ArrayList<DataType> examples;

    /** The similarity matrix: {@code similarities[i][j]} is the negated divergence from i to j. */
    protected double[][] similarities;

    /** The responsibility messages, indexed as {@code [example][candidate exemplar]}. */
    protected double[][] responsibilities;

    /** The availability messages, indexed as {@code [example][candidate exemplar]}. */
    protected double[][] availabilities;

    /** The exemplar index assigned to each example; -1 before the first assignment. */
    protected int[] assignments;

    /** The number of examples whose assignment changed in the most recent step. */
    protected int changedCount;

    /** The current clusters, keyed by the index of their exemplar. */
    protected HashMap<Integer, CentroidCluster<DataType>> clusters;

    /**
     * Creates a new instance with no divergence function and the default
     * self-divergence, damping factor and maximum number of iterations. A
     * divergence function must be set before the algorithm can be run.
     */
    public AffinityPropagation() {
        this(null, DEFAULT_SELF_DIVERGENCE);
    }

    /**
     * Creates a new instance with the default damping factor and maximum
     * number of iterations.
     *
     * @param divergence The divergence function to use between examples.
     * @param selfDivergence The divergence of an example to itself.
     */
    public AffinityPropagation(DivergenceFunction<? super DataType, ? super DataType> divergence, double selfDivergence) {
        this(divergence, selfDivergence, DEFAULT_DAMPING_FACTOR);
    }

    /**
     * Creates a new instance with the default maximum number of iterations.
     *
     * @param divergence The divergence function to use between examples.
     * @param selfDivergence The divergence of an example to itself.
     * @param dampingFactor The damping factor, between 0.0 and 1.0.
     * @throws IllegalArgumentException If the damping factor is out of range.
     */
    public AffinityPropagation(DivergenceFunction<? super DataType, ? super DataType> divergence, double selfDivergence, double dampingFactor) {
        this(divergence, selfDivergence, dampingFactor, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new instance.
     *
     * @param divergence The divergence function to use between examples.
     * @param selfDivergence The divergence of an example to itself.
     * @param dampingFactor The damping factor, between 0.0 and 1.0.
     * @param maxIterations The maximum number of iterations, passed to the superclass.
     * @throws IllegalArgumentException If the damping factor is out of range.
     */
    public AffinityPropagation(DivergenceFunction<? super DataType, ? super DataType> divergence, double selfDivergence, double dampingFactor, int maxIterations) {
        super(maxIterations);
        this.setDivergence(divergence);
        this.setSelfDivergence(selfDivergence);
        this.setDampingFactor(dampingFactor);
    }

    /**
     * Creates a copy of this learner. The divergence function is cloned;
     * all per-run working state (examples, matrices, assignments, clusters)
     * is reset on the copy rather than shared.
     *
     * @return A clone of this object with its working state cleared.
     */
    @Override
    public AffinityPropagation<DataType> clone() {
        @SuppressWarnings("unchecked")
        final AffinityPropagation<DataType> result = (AffinityPropagation<DataType>) super.clone();
        result.divergence = ObjectUtil.cloneSafe(this.divergence);

        // Working state belongs to a single run and must not be shared.
        result.exampleCount = 0;
        result.examples = null;
        result.similarities = null;
        result.responsibilities = null;
        result.availabilities = null;
        result.assignments = null;
        result.changedCount = 0;
        result.clusters = null;
        return result;
    }

    /**
     * Prepares a run: copies the data, allocates the message matrices,
     * fills in the similarity matrix and marks every example as unassigned.
     *
     * @return False if there is no data (null or empty); true otherwise.
     */
    @Override
    protected boolean initializeAlgorithm() {
        final Collection<? extends DataType> data = this.getData();
        if (data == null || data.size() <= 0) {
            // Nothing to cluster.
            return false;
        }

        // setExamples also refreshes exampleCount, which sizes everything below.
        this.setExamples(new ArrayList<DataType>(data));
        this.setSimilarities(new double[this.exampleCount][this.exampleCount]);
        this.setResponsibilities(new double[this.exampleCount][this.exampleCount]);
        this.setAvailabilities(new double[this.exampleCount][this.exampleCount]);

        for (int i = 0; i < this.exampleCount; ++i) {
            final DataType exampleI = this.examples.get(i);
            final double[] row = this.similarities[i];
            for (int j = 0; j < this.exampleCount; ++j) {
                // Similarity is the negated divergence.
                row[j] = -this.divergence.evaluate(exampleI, this.examples.get(j));
            }
            // The diagonal does not come from the divergence function: it is
            // the (negated) self-divergence configured on this object.
            row[i] = -this.selfDivergence;
        }

        this.setAssignments(new int[this.exampleCount]);
        // Start with everything counted as changed.
        this.setChangedCount(this.exampleCount);
        this.setClusters(new HashMap<Integer, CentroidCluster<DataType>>());
        for (int i = 0; i < this.exampleCount; ++i) {
            // -1 marks "not yet assigned to any exemplar".
            this.assignments[i] = -1;
        }
        return true;
    }

    /**
     * Performs one iteration: updates the responsibilities, then the
     * availabilities, then recomputes the assignments.
     *
     * @return True if at least one assignment changed during this step.
     */
    @Override
    protected boolean step() {
        this.updateResponsibilities();
        this.updateAvailabilities();

        // The counter is incremented by assignCluster during updateAssignments.
        this.setChangedCount(0);
        this.updateAssignments();
        return this.getChangedCount() > 0;
    }

    /**
     * Updates the responsibility messages. The undamped new value for
     * (i, k) is {@code s(i,k) - max over c != k of (a(i,c) + s(i,c))}; the
     * stored value is the damped blend of the old and new values.
     */
    protected void updateResponsibilities() {
        for (int i = 0; i < this.exampleCount; ++i) {
            final double[] availabilityRow = this.availabilities[i];
            final double[] similarityRow = this.similarities[i];
            final double[] responsibilityRow = this.responsibilities[i];

            // The maximum over c != k only differs from the row maximum when
            // k is itself the arg-max, in which case it is the second-largest
            // value. Tracking both gives the same result as rescanning the
            // row for every k, in O(n) per row instead of O(n^2).
            double best = Double.NEGATIVE_INFINITY;
            double secondBest = Double.NEGATIVE_INFINITY;
            int bestIndex = -1;
            for (int c = 0; c < this.exampleCount; ++c) {
                final double value = availabilityRow[c] + similarityRow[c];
                if (value > best) {
                    secondBest = best;
                    best = value;
                    bestIndex = c;
                } else if (value > secondBest) {
                    secondBest = value;
                }
            }

            for (int k = 0; k < this.exampleCount; ++k) {
                final double max = (k == bestIndex) ? secondBest : best;
                final double responsibility = similarityRow[k] - max;
                responsibilityRow[k] = this.dampingFactor * responsibilityRow[k]
                    + this.oneMinusDampingFactor * responsibility;
            }
        }
    }

    /**
     * Updates the availability messages. For i != k the undamped new value
     * is {@code min(0, r(k,k) + sum of positive r(j,k) for j not in {i,k})};
     * for i == k it is the sum of positive {@code r(j,k)} for j != k. The
     * stored value is the damped blend of the old and new values.
     */
    protected void updateAvailabilities() {
        for (int i = 0; i < this.exampleCount; ++i) {
            for (int k = 0; k < this.exampleCount; ++k) {
                // Sum the positive responsibilities sent to candidate k by
                // everyone other than i and k themselves. The direct sum is
                // kept (rather than a column total minus one term) so the
                // floating-point result is unchanged.
                double availability = 0.0;
                for (int j = 0; j < this.exampleCount; ++j) {
                    if (j == i || j == k) {
                        continue;
                    }
                    final double responsibility = this.responsibilities[j][k];
                    if (responsibility > 0.0) {
                        availability += responsibility;
                    }
                }

                if (i != k) {
                    // Off-diagonal entries add the self-responsibility of k
                    // and are capped at zero.
                    availability += this.responsibilities[k][k];
                    availability = Math.min(0.0, availability);
                }

                final double oldAvailability = this.availabilities[i][k];
                this.availabilities[i][k] = this.dampingFactor * oldAvailability
                    + this.oneMinusDampingFactor * availability;
            }
        }
    }

    /**
     * Rebuilds the clusters from scratch by assigning each example to the
     * candidate exemplar k that maximizes {@code a(i,k) + r(i,k)}. Ties keep
     * the lowest index.
     */
    protected void updateAssignments() {
        this.setClusters(new HashMap<Integer, CentroidCluster<DataType>>());
        for (int i = 0; i < this.exampleCount; ++i) {
            int assignment = -1;
            double maximum = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < this.exampleCount; ++k) {
                final double value = this.availabilities[i][k] + this.responsibilities[i][k];
                // The first candidate is always taken so that an assignment
                // exists even when no value compares greater than the
                // initial maximum (for example, when the values are NaN).
                if (assignment < 0 || value > maximum) {
                    assignment = k;
                    maximum = value;
                }
            }
            this.assignCluster(i, assignment);
        }
    }

    /**
     * Records the assignment of example {@code i} to the exemplar with index
     * {@code newAssignment}, creating that exemplar's cluster if it does not
     * exist yet, and counts the assignment as changed if it differs from the
     * previous one.
     *
     * @param i The index of the example being assigned.
     * @param newAssignment The index of the exemplar it is assigned to.
     */
    protected void assignCluster(int i, int newAssignment) {
        final int oldAssignment = this.assignments[i];
        if (newAssignment != oldAssignment) {
            ++this.changedCount;
        }
        this.assignments[i] = newAssignment;

        final DataType example = this.examples.get(i);
        CentroidCluster<DataType> newCluster = this.clusters.get(newAssignment);
        if (newCluster == null) {
            // First member of this cluster: the exemplar is its centroid.
            final DataType exemplar = this.examples.get(newAssignment);
            newCluster = new CentroidCluster<DataType>(exemplar);
            newCluster.setIndex(newAssignment);
            this.clusters.put(newAssignment, newCluster);
        }

        // NOTE(review): relies on getMembers() exposing the cluster's live,
        // mutable ArrayList - confirm against CentroidCluster.
        final ArrayList<DataType> members = (ArrayList<DataType>) newCluster.getMembers();
        members.add(example);
    }

    /**
     * Releases the per-run working data (examples and the three matrices).
     * The assignments and clusters are kept so the result stays available.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.setExamples(null);
        this.setSimilarities(null);
        this.setResponsibilities(null);
        this.setAvailabilities(null);
    }

    /**
     * Gets the current clusters as a new list.
     *
     * @return A new list of the current clusters, or null if no clusters
     *      have been created yet.
     */
    @Override
    public ArrayList<CentroidCluster<DataType>> getResult() {
        if (this.getClusters() == null) {
            return null;
        }
        return new ArrayList<CentroidCluster<DataType>>(this.getClusters().values());
    }

    /**
     * Gets the divergence function used by the algorithm.
     *
     * @return The divergence function.
     */
    public DivergenceFunction<? super DataType, ? super DataType> getDivergence() {
        return this.divergence;
    }

    /**
     * Sets the divergence function used by the algorithm.
     *
     * @param divergence The divergence function.
     */
    public void setDivergence(DivergenceFunction<? super DataType, ? super DataType> divergence) {
        this.divergence = divergence;
    }

    /**
     * Gets the divergence of an example to itself.
     *
     * @return The self-divergence.
     */
    public double getSelfDivergence() {
        return this.selfDivergence;
    }

    /**
     * Sets the divergence of an example to itself. Its negation is placed
     * on the diagonal of the similarity matrix when a run is initialized.
     *
     * @param selfDivergence The self-divergence.
     */
    public void setSelfDivergence(double selfDivergence) {
        this.selfDivergence = selfDivergence;
    }

    /**
     * Gets the damping factor.
     *
     * @return The damping factor, between 0.0 and 1.0.
     */
    public double getDampingFactor() {
        return this.dampingFactor;
    }

    /**
     * Sets the damping factor: the weight given to the old value of a
     * message when it is updated. Also refreshes the cached complement.
     *
     * @param dampingFactor The damping factor, between 0.0 and 1.0 inclusive.
     * @throws IllegalArgumentException If the damping factor is outside
     *      [0.0, 1.0] or is NaN.
     */
    public void setDampingFactor(double dampingFactor) {
        // Written as a positive range test so that NaN, for which every
        // ordered comparison is false, is rejected as well.
        if (!(dampingFactor >= 0.0 && dampingFactor <= 1.0)) {
            throw new IllegalArgumentException("The damping factor must be between 0.0 and 1.0.");
        }
        this.dampingFactor = dampingFactor;
        this.oneMinusDampingFactor = 1.0 - this.dampingFactor;
    }

    /**
     * Gets the examples being clustered.
     *
     * @return The examples, or null when no run is in progress.
     */
    protected ArrayList<DataType> getExamples() {
        return this.examples;
    }

    /**
     * Sets the examples being clustered and updates the cached example
     * count (0 when {@code examples} is null).
     *
     * @param examples The examples; may be null.
     */
    protected void setExamples(ArrayList<DataType> examples) {
        this.examples = examples;
        this.exampleCount = examples == null ? 0 : examples.size();
    }

    /**
     * Gets the similarity matrix.
     *
     * @return The similarity matrix.
     */
    protected double[][] getSimilarities() {
        return this.similarities;
    }

    /**
     * Sets the similarity matrix.
     *
     * @param similarities The similarity matrix.
     */
    protected void setSimilarities(double[][] similarities) {
        this.similarities = similarities;
    }

    /**
     * Gets the responsibility messages.
     *
     * @return The responsibility matrix.
     */
    protected double[][] getResponsibilities() {
        return this.responsibilities;
    }

    /**
     * Sets the responsibility messages.
     *
     * @param responsibilities The responsibility matrix.
     */
    protected void setResponsibilities(double[][] responsibilities) {
        this.responsibilities = responsibilities;
    }

    /**
     * Gets the availability messages.
     *
     * @return The availability matrix.
     */
    protected double[][] getAvailabilities() {
        return this.availabilities;
    }

    /**
     * Sets the availability messages.
     *
     * @param availabilities The availability matrix.
     */
    protected void setAvailabilities(double[][] availabilities) {
        this.availabilities = availabilities;
    }

    /**
     * Gets the exemplar index assigned to each example.
     *
     * @return The assignments array.
     */
    protected int[] getAssignments() {
        return this.assignments;
    }

    /**
     * Sets the exemplar index assigned to each example.
     *
     * @param assignments The assignments array.
     */
    protected void setAssignments(int[] assignments) {
        this.assignments = assignments;
    }

    /**
     * Gets the number of assignments that changed in the most recent step.
     *
     * @return The number of changed assignments.
     */
    public int getChangedCount() {
        return this.changedCount;
    }

    /**
     * Sets the number of changed assignments.
     *
     * @param changedCount The number of changed assignments.
     */
    protected void setChangedCount(int changedCount) {
        this.changedCount = changedCount;
    }

    /**
     * Gets the current clusters, keyed by exemplar index.
     *
     * @return The map of exemplar index to cluster.
     */
    protected HashMap<Integer, CentroidCluster<DataType>> getClusters() {
        return this.clusters;
    }

    /**
     * Sets the current clusters, keyed by exemplar index.
     *
     * @param clusters The map of exemplar index to cluster.
     */
    protected void setClusters(HashMap<Integer, CentroidCluster<DataType>> clusters) {
        this.clusters = clusters;
    }

    /**
     * Gets the performance measure of the algorithm: the number of
     * assignments that changed in the most recent step.
     *
     * @return The changed count, named "number changed".
     */
    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue<Integer>("number changed", this.getChangedCount());
    }
}

