package ucd.mlg.clustering.kernel;

import java.util.Arrays;
import ucd.mlg.clustering.ClusteringException;
import ucd.mlg.clustering.HardClustering;
import ucd.mlg.clustering.capability.KernelClusterer;
import ucd.mlg.clustering.capability.PredictiveSamplingClusterer;
import ucd.mlg.clustering.capability.WeightedClusterer;
import ucd.mlg.clustering.util.AbstractIterativeClusterer;
import ucd.mlg.clustering.util.ClusterUtils;
import ucd.mlg.clustering.util.PredictiveSamplingDelegate;
import ucd.mlg.core.data.Dataset;
import ucd.mlg.matrix.DensePairwiseMatrix;
import ucd.mlg.matrix.MatrixUtils;
import ucd.mlg.matrix.PairwiseMatrix;
import ucd.mlg.metrics.kernel.KernelFunction;

/* loaded from: input_file:ucd/mlg/clustering/kernel/KernelKMeans.class */
/**
 * Weighted kernel k-means.
 *
 * <p>Each iteration computes, for every point being reassigned and every non-empty cluster
 * {@code c}, the squared feature-space distance to the weighted cluster centroid using only
 * kernel values:
 *
 * <pre>
 *   d(i, c) = K(i,i) - 2 * sum_{j in c} w_j K(i,j) / W_c + sum_{j,l in c} w_j w_l K(j,l) / W_c^2
 * </pre>
 *
 * where {@code W_c} is the total weight of cluster {@code c}. Points move to their nearest
 * cluster until one of the following happens:
 * <ul>
 *   <li>no assignment changes;</li>
 *   <li>the iteration limit is reached;</li>
 *   <li>the assignment oscillates for more than {@link #maxPingPongIterations} consecutive
 *       iterations. Here "oscillating" means every point that changed cluster moved back to the
 *       cluster it had two iterations earlier.</li>
 * </ul>
 *
 * <p>When a sampling mask is set, only points whose mask entry is {@code true} are reassigned.
 * If prediction is enabled, the remaining points are assigned afterwards by
 * {@link #predictMembership(int[], double[], double[])}.
 */
public class KernelKMeans extends AbstractIterativeClusterer implements KernelClusterer, PredictiveSamplingClusterer, WeightedClusterer {
    /** Default number of consecutive oscillating iterations tolerated before stopping. */
    static final int DEFAULT_MAX_PING_PONG_ITERATIONS = 5;
    /** Holds the sampling mask and the predicted memberships. */
    protected PredictiveSamplingDelegate dSampling;
    /** Kernel function that supplies the pairwise kernel matrix. */
    protected KernelFunction kernel;
    /** Number of consecutive oscillating iterations tolerated before the loop stops. */
    protected int maxPingPongIterations;
    /** Cached dense kernel matrix; {@code null} until {@link #initKernel()} runs. */
    protected DensePairwiseMatrix K;
    /** Per-point weights; replaced by unit weights when missing or of the wrong length. */
    protected double[] weights;

    /**
     * Creates a kernel k-means clusterer.
     *
     * @param kernelFunction kernel providing the pairwise similarity matrix
     * @param numClusters number of clusters; must be at least 2
     * @throws IllegalArgumentException if {@code numClusters < 2}
     */
    public KernelKMeans(KernelFunction kernelFunction, int numClusters) {
        this.dSampling = new PredictiveSamplingDelegate();
        setK(numClusters);
        setKernel(kernelFunction);
        this.maxPingPongIterations = DEFAULT_MAX_PING_PONG_ITERATIONS;
    }

    /**
     * Creates a kernel k-means clusterer with {@code k = 2}.
     *
     * @param kernelFunction kernel providing the pairwise similarity matrix
     */
    public KernelKMeans(KernelFunction kernelFunction) {
        this(kernelFunction, 2);
    }

    /**
     * Clusters the dataset.
     *
     * <p>The initial clustering comes from the configured initialisation strategy. It is then
     * refined by {@link #iterativeAssignment(HardClustering)}.
     *
     * @param dataset data to cluster
     * @return the final hard clustering
     * @throws ClusteringException if {@code k} or the sampling mask is invalid for the dataset
     */
    @Override
    public HardClustering findClusters(Dataset dataset) throws ClusteringException {
        checkK(dataset);
        this.dSampling.checkMask(dataset);
        HardClustering clustering = ClusterUtils.toHardClustering(initalizeClustering(dataset, getK()));
        iterativeAssignment(clustering);
        return clustering;
    }

    /**
     * Caches the kernel matrix as a {@link DensePairwiseMatrix} and makes sure there is one
     * weight per row. The matrix is copied only if it is not already dense.
     */
    public void initKernel() {
        PairwiseMatrix matrix = this.kernel.getMatrix();
        if (matrix instanceof DensePairwiseMatrix) {
            this.K = (DensePairwiseMatrix) matrix;
        } else {
            this.K = new DensePairwiseMatrix(matrix);
        }
        ensureWeights(this.K.numRows());
    }

    /** Replaces {@link #weights} with unit weights when it is null or its length is not {@code n}. */
    private void ensureWeights(int n) {
        if (this.weights == null || this.weights.length != n) {
            this.weights = new double[n];
            Arrays.fill(this.weights, 1.0d);
        }
    }

    /**
     * Refines {@code hardClustering} in place by repeated nearest-centroid reassignment in kernel
     * feature space.
     *
     * <p>Every point whose current membership is not {@code -1} contributes to the cluster
     * statistics, regardless of the mask. Only points allowed by the mask are reassigned.
     *
     * @param hardClustering initial clustering; overwritten with the final assignment
     * @throws ClusteringException declared for subclasses; not thrown by this implementation
     */
    protected void iterativeAssignment(HardClustering hardClustering) throws ClusteringException {
        final int maxIterations = getMaxIterations();
        final boolean[] mask = getMask();
        final int numClusters = hardClustering.size();
        if (this.K == null) {
            initKernel();
        }
        final int n = this.K.numRows();
        // setWeights() may have been called after the kernel matrix was cached.
        ensureWeights(n);
        final double[] w = this.weights;

        // pointSums[i + c * n] = sum over j in cluster c of w_j * K(i, j)
        final double[] pointSums = new double[n * numClusters];
        // clusterWeights[c] = W_c, the total weight of the members of c
        final double[] clusterWeights = new double[numClusters];
        // compactness[c] = sum over j, l in c of w_j * w_l * K(j, l), divided by W_c^2
        final double[] compactness = new double[numClusters];
        final int[] previous = new int[n]; // assignment from two iterations back
        final int[] current = new int[n];
        final int[] next = new int[n];
        System.arraycopy(hardClustering.getMembership(), 0, current, 0, n);
        System.arraycopy(current, 0, next, 0, n);
        final double[] diag = MatrixUtils.getDiagonalVaues(this.K);

        int pingPongCount = 0;
        double prevObjective = Double.MAX_VALUE;
        int iteration = 1;
        for (; iteration <= maxIterations; iteration++) {
            // 1. Accumulate per-cluster statistics for the current assignment.
            Arrays.fill(clusterWeights, 0.0d);
            Arrays.fill(compactness, 0.0d);
            Arrays.fill(pointSums, 0.0d);
            for (int i = 0; i < n; i++) {
                final int ci = current[i];
                if (ci != -1) {
                    clusterWeights[ci] += w[i];
                    pointSums[i + ci * n] += w[i] * diag[i];
                    compactness[ci] += w[i] * w[i] * diag[i];
                }
                // The kernel matrix is symmetric, so each off-diagonal pair is visited once.
                for (int j = i + 1; j < n; j++) {
                    final int cj = current[j];
                    if (ci == -1 && cj == -1) {
                        continue;
                    }
                    final double kij = this.K.get(i, j);
                    if (cj != -1) {
                        if (cj == ci) {
                            compactness[ci] += 2.0d * w[i] * w[j] * kij;
                        }
                        pointSums[i + cj * n] += w[j] * kij;
                    }
                    if (ci != -1) {
                        pointSums[j + ci * n] += w[i] * kij;
                    }
                }
            }
            for (int c = 0; c < numClusters; c++) {
                if (clusterWeights[c] != 0.0d) {
                    compactness[c] /= clusterWeights[c] * clusterWeights[c];
                }
            }

            // 2. Move every eligible point to its nearest non-empty cluster.
            int changes = 0;
            double objective = 0.0d;
            // "previous" only holds a genuine earlier assignment from the third iteration on.
            boolean pingPong = iteration > 2;
            for (int i = 0; i < n; i++) {
                if (mask != null && !mask[i]) {
                    continue;
                }
                int best = -1;
                double bestDist = Double.MAX_VALUE;
                for (int c = 0; c < numClusters; c++) {
                    if (clusterWeights[c] != 0.0d) {
                        final double dist = diag[i] - 2.0d * pointSums[i + c * n] / clusterWeights[c] + compactness[c];
                        if (dist < bestDist) {
                            bestDist = dist;
                            best = c;
                        }
                    }
                }
                next[i] = best;
                if (best != -1) {
                    objective += bestDist;
                    if (best != current[i]) {
                        changes++;
                        pingPong = pingPong && previous[i] == best;
                    }
                }
            }

            // 3. Decide whether to stop.
            if (changes == 0) {
                break;
            }
            if (pingPong) {
                pingPongCount++;
                if (pingPongCount > this.maxPingPongIterations) {
                    // The assignment is flipping between two states; keep the one with the
                    // lower objective and stop instead of burning the remaining iterations.
                    if (objective < prevObjective) {
                        System.arraycopy(next, 0, current, 0, n);
                    }
                    break;
                }
            } else {
                pingPongCount = 0;
            }
            prevObjective = objective;
            System.arraycopy(current, 0, previous, 0, n);
            System.arraycopy(next, 0, current, 0, n);
        }

        hardClustering.assignMembership(current);
        this.dIterative.setLastIterationCount(iteration - 1);
        if (!getPredict() || mask == null) {
            this.dSampling.reset();
        } else {
            predictMembership(current, clusterWeights, compactness);
        }
    }

    /**
     * Predicts clusters for the points excluded by the sampling mask.
     *
     * <p>Each excluded point goes to the non-empty cluster with the smallest kernel distance.
     * The constant {@code K(i,i)} term is omitted because it does not affect the argmin. Only
     * sampled points with a membership other than {@code -1} contribute to the centroids.
     * Sampled points, and points no cluster could be found for, get {@code -1}.
     *
     * @param membership cluster index per point ({@code -1} = unassigned)
     * @param clusterWeights total member weight per cluster
     * @param compactness normalised within-cluster kernel sum per cluster
     */
    public void predictMembership(int[] membership, double[] clusterWeights, double[] compactness) {
        final boolean[] mask = getMask();
        final int length = mask.length;
        final int numClusters = clusterWeights.length;
        final int[] predicted = new int[length];
        final double[] sums = new double[numClusters];
        for (int i = 0; i < length; i++) {
            predicted[i] = -1;
            if (mask[i]) {
                continue;
            }
            Arrays.fill(sums, 0.0d);
            for (int j = 0; j < length; j++) {
                if (mask[j] && membership[j] != -1) {
                    sums[membership[j]] += this.weights[j] * this.K.get(i, j);
                }
            }
            double bestDist = Double.MAX_VALUE;
            for (int c = 0; c < numClusters; c++) {
                if (clusterWeights[c] != 0.0d) {
                    final double dist = compactness[c] - 2.0d * sums[c] / clusterWeights[c];
                    if (dist < bestDist) {
                        bestDist = dist;
                        predicted[i] = c;
                    }
                }
            }
            if (predicted[i] == -1) {
                System.out.println("Failed to cluster " + i);
            }
        }
        this.dSampling.setPredictedMembership(predicted);
    }

    /** Drops the cached kernel matrix so that it is rebuilt on the next run. */
    protected void reset() {
        this.K = null;
    }

    /** @return the number of clusters */
    @Override
    public int getK() {
        return this.k;
    }

    /**
     * Sets the number of clusters.
     *
     * @param i number of clusters; must be at least 2
     * @throws IllegalArgumentException if {@code i < 2}
     */
    @Override
    public void setK(int i) {
        if (i < 2) {
            throw new IllegalArgumentException("Invalid value for k: " + i + ". Number of clusters should be greater than one.");
        }
        this.k = i;
    }

    /** @return the kernel function in use */
    @Override
    public KernelFunction getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel function and discards any cached kernel matrix.
     *
     * @param kernelFunction the new kernel
     */
    @Override
    public void setKernel(KernelFunction kernelFunction) {
        this.kernel = kernelFunction;
        reset();
    }

    /**
     * Sets per-point weights. The array is stored by reference.
     *
     * <p>If it is {@code null} or its length differs from the number of kernel rows, it is
     * replaced by unit weights before clustering.
     *
     * @param dArr the per-point weights
     */
    @Override
    public void setWeights(double[] dArr) {
        this.weights = dArr;
    }

    /** @return the per-point weights (the internal array, not a copy) */
    @Override
    public double[] getWeights() {
        return this.weights;
    }

    /** @param zArr sampling mask; {@code true} marks points that take part in clustering */
    @Override
    public void setMask(boolean[] zArr) {
        this.dSampling.setMask(zArr);
    }

    /** @return the sampling mask, possibly {@code null} */
    @Override
    public boolean[] getMask() {
        return this.dSampling.getMask();
    }

    /** @return memberships predicted for points excluded by the mask */
    @Override
    public int[] getPredictedMembership() {
        return this.dSampling.getPredictedMembership();
    }

    /** @return whether excluded points get predicted memberships after clustering */
    @Override
    public boolean getPredict() {
        return this.dSampling.getPredict();
    }

    /** @param z whether to predict memberships for points excluded by the mask */
    @Override
    public void setPredict(boolean z) {
        this.dSampling.setPredict(z);
    }

    /** @return a summary of the configuration; tolerates a missing kernel */
    @Override
    public String toString() {
        final Object init = getInitStrategy();
        return String.format("%s (k=%d kernel=%s maxIters=%d init=%s)",
                getClass().getSimpleName(),
                getK(),
                String.valueOf(getKernel()),
                getMaxIterations(),
                init == null ? "default" : init.toString());
    }
}
