package ucd.mlg.clustering.ensemble.integration;

import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import ucd.mlg.clustering.Clustering;
import ucd.mlg.clustering.ClusteringException;
import ucd.mlg.clustering.HardClustering;
import ucd.mlg.clustering.SoftClustering;
import ucd.mlg.clustering.ensemble.IntegrationException;
import ucd.mlg.clustering.util.ClusterUtils;
import ucd.mlg.core.data.Dataset;
import ucd.mlg.core.data.util.DataUtils;
import ucd.mlg.metrics.similarity.SimilarityMetric;
import ucd.mlg.validation.util.HungarianSolver;
import ucd.mlg.validation.util.MatchingException;

/* loaded from: input_file:ucd/mlg/clustering/ensemble/integration/CentroidCorrespondenceIntegrator.class */
/**
 * Ensemble integrator that combines base clusterings by centroid correspondence.
 *
 * <p>Every base clustering added must report the same {@code size()} as this integrator's
 * {@code k}. The centroids of the first clustering seed a set of running centroid sums. For each
 * later clustering, a cost matrix of {@code 1 - similarity(meanCentroid, newCentroid)} is solved
 * with a {@link HungarianSolver} to obtain a one-to-one correspondence. Each new centroid is then
 * added to the running sum it was paired with. {@link #findClusters()} builds a
 * {@link SoftClustering} from the averaged centroids via
 * {@link ClusterUtils#buildCentroidSimilarityMatrix}.
 */
public class CentroidCorrespondenceIntegrator extends AbstractIntegrator {
    /** Solves the assignment problem that pairs new centroids with the accumulated ones. */
    protected HungarianSolver correspondenceFunction;
    /** Similarity metric used both for matching centroids and for building the final memberships. */
    protected SimilarityMetric metric;
    /**
     * Running (un-normalised) sums of matched centroids, one per cluster; {@code null} until the
     * first clustering has been added.
     */
    protected DenseVector[] centroids;
    /** Number of base clusterings whose centroids have been accumulated into {@link #centroids}. */
    protected int clusteringCount;

    /**
     * Creates an integrator with the given similarity metric and number of clusters.
     *
     * @param similarityMetric metric used to compare centroids and to build the final memberships
     * @param k number of clusters, forwarded to {@link AbstractIntegrator} (presumably stored as
     *     {@code k}, which {@link #addClustering} checks each base clustering's size against)
     */
    public CentroidCorrespondenceIntegrator(SimilarityMetric similarityMetric, int k) {
        super(k);
        this.correspondenceFunction = new HungarianSolver();
        this.metric = similarityMetric;
    }

    /**
     * Creates an integrator with the given similarity metric and two clusters.
     *
     * @param similarityMetric metric used to compare centroids and to build the final memberships
     */
    public CentroidCorrespondenceIntegrator(SimilarityMetric similarityMetric) {
        this(similarityMetric, 2);
    }

    /**
     * Creates an integrator using {@link DataUtils#getDefaultSimilarityMetric()}.
     *
     * @param k number of clusters, forwarded to {@link AbstractIntegrator}
     */
    public CentroidCorrespondenceIntegrator(int k) {
        this(DataUtils.getDefaultSimilarityMetric(), k);
    }

    /** Creates an integrator using the default similarity metric and two clusters. */
    public CentroidCorrespondenceIntegrator() {
        this(2);
    }

    /**
     * Prepares the integrator for a new ensemble over {@code dataset}, discarding any previously
     * accumulated centroids.
     *
     * @param dataset the dataset the base clusterings are defined over
     */
    @Override
    public void init(Dataset dataset) {
        super.init(dataset);
        this.centroids = null;
        this.clusteringCount = 0;
    }

    /**
     * Prepares the integrator for a new ensemble and immediately adds {@code hardClustering} as
     * the first base clustering.
     *
     * @param dataset the dataset the base clusterings are defined over
     * @param hardClustering the initial base clustering
     * @throws IllegalArgumentException if the initial clustering is rejected; the original
     *     exception is attached as the cause
     */
    public void init(Dataset dataset, HardClustering hardClustering) {
        init(dataset);
        try {
            addClustering(hardClustering);
        } catch (ClusteringException e) {
            // Chain the cause so the underlying failure's stack trace is not lost.
            throw new IllegalArgumentException("Invalid initial clustering: " + e.getMessage(), e);
        }
    }

    /**
     * Adds a base clustering to the ensemble.
     *
     * <p>The first clustering's centroids become the initial running sums. For every later
     * clustering, a {@code k x k} cost matrix of {@code 1 - similarity(meanCentroid, newCentroid)}
     * is solved for a one-to-one correspondence. Each new centroid is then added to the running
     * sum it was matched with. If matching fails, a warning is printed to {@code System.err} and
     * the clustering is skipped (it is not counted).
     *
     * @param clustering base clustering; its {@code size()} must equal {@code k}
     * @throws IntegrationException if the clustering's size differs from {@code k}
     */
    @Override
    public void addClustering(Clustering clustering) throws IntegrationException {
        int size = clustering.size();
        if (this.k != size) {
            throw new IntegrationException("Cannot add base clustering of different size (" + size + "!=" + this.k + ")");
        }
        DenseVector[] incoming = ClusterUtils.buildCentroids(clustering);
        if (this.centroids == null) {
            this.centroids = incoming;
        } else {
            // Converts a running sum into the current mean centroid; invariant across the loops.
            double meanFactor = 1.0d / this.clusteringCount;
            DenseMatrix cost = new DenseMatrix(size, size);
            for (int row = 0; row < size; row++) {
                DenseVector mean = this.centroids[row].copy();
                mean.scale(meanFactor);
                for (int col = 0; col < size; col++) {
                    cost.set(row, col, 1.0d - this.metric.similarity(mean, incoming[col]));
                }
            }
            try {
                int[] match = this.correspondenceFunction.match(cost);
                // NOTE(review): this treats match[col] as the row (accumulated centroid) paired
                // with new centroid col; confirm that against HungarianSolver.match's contract.
                for (int col = 0; col < size; col++) {
                    this.centroids[match[col]].add(incoming[col]);
                }
            } catch (MatchingException e) {
                // Best effort: skip this clustering rather than abort the whole ensemble.
                System.err.println("Warning: Failed to perform matching procedure");
                return;
            }
        }
        this.clusteringCount++;
    }

    /**
     * Builds a soft clustering from the mean of the accumulated centroids.
     *
     * <p>The running sums are left untouched. The method can therefore be called repeatedly, and
     * further base clusterings can still be added afterwards.
     *
     * @return soft clustering whose memberships come from
     *     {@link ClusterUtils#buildCentroidSimilarityMatrix} over the mean centroids
     * @throws IntegrationException if no base clustering has been added yet
     */
    @Override
    public SoftClustering findClusters() throws IntegrationException {
        if (this.centroids == null) {
            throw new IntegrationException("Cannot perform integration. No base clusterings have been added to the ensemble.");
        }
        // Average into copies: scaling the running sums in place would corrupt later calls.
        double meanFactor = 1.0d / this.clusteringCount;
        DenseVector[] means = new DenseVector[this.centroids.length];
        for (int c = 0; c < means.length; c++) {
            means[c] = this.centroids[c].copy();
            means[c].scale(meanFactor);
        }
        return new SoftClustering(this.dataset, ClusterUtils.buildCentroidSimilarityMatrix(this.dataset, this.metric, means));
    }

    /** Returns the simple class name together with {@code k} and the metric's simple class name. */
    @Override
    public String toString() {
        return getClass().getSimpleName() + " (k=" + this.k + " metric=" + this.metric.getClass().getSimpleName() + ")";
    }
}
