package ucd.mlg.clustering.spectral;

import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Matrix;
import ucd.mlg.clustering.ClusteringException;
import ucd.mlg.clustering.HardClustering;
import ucd.mlg.clustering.SoftBiclustering;
import ucd.mlg.clustering.init.InitOrthogonal;
import ucd.mlg.clustering.partitional.KMeans;
import ucd.mlg.clustering.spectral.extraction.BipartiteDecomposition;
import ucd.mlg.clustering.util.ClusterUtils;
import ucd.mlg.core.data.Dataset;
import ucd.mlg.core.data.prep.extraction.EmbeddingException;
import ucd.mlg.matrix.MatrixUtils;
import ucd.mlg.metrics.similarity.ScaledCosineSimilarity;
import ucd.mlg.metrics.similarity.SimilarityMetric;

/* loaded from: input_file:ucd/mlg/clustering/spectral/SSCClusterer.class */
/**
 * Spectral clusterer that produces a soft biclustering of both the objects and the features of a
 * {@link Dataset}.
 *
 * <p>The dataset is first embedded with a {@link BipartiteDecomposition}. The embedded points are
 * hard-clustered with {@link KMeans} (scaled cosine similarity, orthogonal seeding with a median
 * first seed). The hard clustering is then mapped back onto the original feature-object matrix to
 * obtain soft membership weights for every object and every feature.
 */
public class SSCClusterer extends SpectralClusterer {

    /**
     * Creates a clusterer that looks for {@code k} clusters.
     *
     * @param k the number of clusters passed to the underlying {@link KMeans}
     */
    public SSCClusterer(int k) {
        super(new KMeans(new ScaledCosineSimilarity(), k));
        getClusterer().setInitStrategy(new InitOrthogonal(InitOrthogonal.FirstSeedType.MEDIAN));
    }

    /** Creates a clusterer that looks for two clusters. */
    public SSCClusterer() {
        this(2);
    }

    /**
     * Computes a soft biclustering of {@code dataset}.
     *
     * @param dataset the dataset to cluster
     * @return a {@link SoftBiclustering} built from the object memberships ({@code n x k}) and the
     *     feature memberships ({@code numFeatures x k})
     * @throws ClusteringException if {@code k} exceeds the number of instances, if the dataset has
     *     fewer features than instances, or if decomposition or clustering fails
     */
    @Override
    public SoftBiclustering findClusters(Dataset dataset) throws ClusteringException {
        final int size = dataset.size();
        final int numFeatures = dataset.numFeatures();
        final int k = getK();
        if (k > size) {
            throw new ClusteringException("Unable to cluster. Number of clusters k=" + k + " is greater than number of instances n=" + size);
        }
        if (numFeatures < size) {
            // The previous message described the opposite condition; it now matches the actual check.
            throw new ClusteringException("Cannot apply " + getClass().getSimpleName()
                    + " to dataset where number of features m=" + numFeatures
                    + " is less than the number of elements n=" + size);
        }
        DenseMatrix featureMemberships = new DenseMatrix(numFeatures, k);
        DenseMatrix objectMemberships = new DenseMatrix(size, k);
        try {
            factorize(dataset, featureMemberships, objectMemberships);
            return new SoftBiclustering(dataset, objectMemberships, featureMemberships);
        } catch (ClusteringException e) {
            // NOTE(review): factorize() already converts decomposition failures into a
            // ClusteringException, so those also end up under this "clustering" prefix.
            throw wrap("Unable to clustering embedded data: " + e.getMessage(), e);
        } catch (EmbeddingException e) {
            throw wrap("Unable to perform spectral decomposition: " + e.getMessage(), e);
        }
    }

    /**
     * Returns a new {@link BipartiteDecomposition} configured for {@code getK()} dimensions.
     *
     * @return a fresh decomposition instance
     */
    @Override
    public BipartiteDecomposition getDecomposition() {
        return new BipartiteDecomposition(getK());
    }

    /**
     * Returns the underlying clusterer, which the constructor sets to a {@link KMeans} instance.
     *
     * @return the k-means clusterer applied to the embedded data
     */
    @Override
    public KMeans getClusterer() {
        return (KMeans) this.clusterer;
    }

    /**
     * Embeds the dataset (if not already embedded) and clusters the embedding. The result is then
     * mapped back to soft memberships.
     *
     * <p>On return, {@code featureMemberships} has been column-L1-normalized and
     * {@code objectMemberships} has been row-L1-normalized (the latter inside
     * {@link #reverseMapping}).
     *
     * @param dataset the dataset being clustered
     * @param featureMemberships output matrix of feature weights, {@code numFeatures x k}
     * @param objectMemberships output matrix of object weights, {@code n x k}
     * @throws ClusteringException if decomposition or k-means clustering fails
     * @throws EmbeddingException declared for subclasses; this implementation converts
     *     decomposition failures into {@link ClusteringException}
     */
    protected void factorize(Dataset dataset, DenseMatrix featureMemberships, DenseMatrix objectMemberships) throws ClusteringException, EmbeddingException {
        Matrix featureObjectMatrix = dataset.getFeatureObjectMatrix();
        // The embedding is cached on this instance and reused on later calls.
        // NOTE(review): a later call with a *different* dataset would reuse the stale embedding.
        if (this.embedding == null) {
            try {
                this.embedding = getDecomposition().m180apply(dataset);
            } catch (EmbeddingException e) {
                throw wrap("Unable to perform decomposition: " + e.getMessage(), e);
            }
        }
        KMeans kmeans = getClusterer();
        HardClustering embeddedClustering = kmeans.findClusters((Dataset) this.embedding);
        reverseMapping(embeddedClustering, kmeans.getMetric(), featureObjectMatrix, featureMemberships, objectMemberships);
        MatrixUtils.normalizeColumnL1(featureMemberships);
        // Explicit GC hint kept from the original implementation (presumably to release the large
        // intermediate matrices early).
        Runtime.getRuntime().gc();
    }

    /**
     * Maps a hard clustering of the embedded points back to soft object and feature memberships.
     *
     * <p>Steps:
     * <ol>
     *   <li>Compute similarities between embedded rows and cluster centroids, then column-L1-normalize
     *       them.</li>
     *   <li>Set {@code objectMemberships = featureObjectMatrix^T * S_f}, where {@code S_f} is the
     *       first {@code numFeatures} rows of that similarity matrix, then row-L1-normalize it.</li>
     *   <li>Build a column-L1-normalized object-to-cluster indicator matrix {@code I} from the hard
     *       assignments.</li>
     *   <li>Set {@code featureMemberships = featureObjectMatrix * I}.</li>
     * </ol>
     *
     * <p>Assumes the embedding lists the {@code numFeatures} feature rows first, followed by the
     * object rows. This is inferred from the index arithmetic below; confirm against
     * {@link BipartiteDecomposition}.
     *
     * @param hardClustering hard clustering of the embedded rows
     * @param similarityMetric metric used to compare embedded rows with centroids
     * @param featureObjectMatrix the dataset's feature-object matrix ({@code numFeatures x n})
     * @param featureMemberships output, {@code numFeatures x k}
     * @param objectMemberships output, {@code n x k}
     */
    protected void reverseMapping(HardClustering hardClustering, SimilarityMetric similarityMetric, Matrix featureObjectMatrix, DenseMatrix featureMemberships, DenseMatrix objectMemberships) {
        final int k = featureMemberships.numColumns();
        final int numFeatures = featureMemberships.numRows();
        final int numObjects = objectMemberships.numRows();
        DenseMatrix centroidSimilarity = ClusterUtils.buildCentroidSimilarityMatrix(this.embedding, similarityMetric, ClusterUtils.buildCentroids(hardClustering));
        MatrixUtils.normalizeColumnL1(centroidSimilarity);
        // objectMemberships = featureObjectMatrix^T * (feature rows of centroidSimilarity)
        featureObjectMatrix.transAmult(MatrixUtils.viewWindow(centroidSimilarity, 0, 0, numFeatures, k), objectMemberships);
        MatrixUtils.normalizeRowL1(objectMemberships);
        DenseMatrix objectIndicator = new DenseMatrix(numObjects, k);
        for (int object = 0; object < numObjects; object++) {
            // Object rows follow the feature rows in the embedding, hence the numFeatures offset.
            objectIndicator.set(object, hardClustering.getClusterIndex(numFeatures + object), 1.0d);
        }
        MatrixUtils.normalizeColumnL1(objectIndicator);
        featureObjectMatrix.mult(objectIndicator, featureMemberships);
    }

    /**
     * Builds a {@link ClusteringException} with the given message and attaches {@code cause}, so the
     * original stack trace is preserved.
     *
     * <p>Assumes {@code ClusteringException(String)} leaves the cause unset; {@code initCause} would
     * throw {@link IllegalStateException} otherwise.
     */
    private static ClusteringException wrap(String message, Throwable cause) {
        ClusteringException wrapped = new ClusteringException(message);
        wrapped.initCause(cause);
        return wrapped;
    }
}
