package ucd.mlg.metrics.similarity;

import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.sparse.CompRowMatrix;
import ucd.mlg.core.data.Dataset;
import ucd.mlg.core.data.SparseColumnDataset;
import ucd.mlg.matrix.DensePairwiseMatrix;

/* loaded from: input_file:ucd/mlg/metrics/similarity/CosineSimilarity.class */
/**
 * Cosine similarity metric: {@code sim(a, b) = a.b / (|a| |b|)}, with {@code distance = 1 - sim}.
 *
 * <p>Datasets are read through their feature-object matrix, whose rows are features and whose
 * columns are objects. Whenever either operand has zero Euclidean norm the similarity is reported
 * as {@code 0} (distance {@code 1}); the only exception is the diagonal of the pairwise matrices
 * built by {@link #buildMatrix}, which is fixed at similarity {@code 1} / distance {@code 0}.
 */
public class CosineSimilarity extends AbstractSimilarityMetric {

    /**
     * Computes the cosine similarity of two vectors.
     *
     * @param vector  first operand; its {@code size()} determines how many components are compared
     * @param vector2 second operand; read at indices {@code 0 .. vector.size() - 1}
     * @return the cosine of the angle between the vectors, or {@code 0} if either has zero norm
     */
    @Override
    public double similarity(Vector vector, Vector vector2) {
        final int size = vector.size();
        double dot = 0.0d;
        double squaredNorm1 = 0.0d;
        double squaredNorm2 = 0.0d;
        for (int i = 0; i < size; i++) {
            final double x = vector.get(i);
            final double y = vector2.get(i);
            dot += x * y;
            squaredNorm1 += x * x;
            squaredNorm2 += y * y;
        }
        return cosine(dot, Math.sqrt(squaredNorm1), Math.sqrt(squaredNorm2));
    }

    /**
     * Computes the cosine distance {@code 1 - similarity(vector, vector2)}.
     *
     * @param vector  first operand
     * @param vector2 second operand
     * @return the cosine distance; {@code 1} if either vector has zero norm
     */
    @Override
    public double distance(Vector vector, Vector vector2) {
        return 1.0d - similarity(vector, vector2);
    }

    /**
     * Computes the cosine distance between {@code vector} and every object of {@code dataset}.
     *
     * @param dataset objects to compare against
     * @param vector  query vector, indexed by feature
     * @return one distance per object, in object (column) order
     */
    @Override
    public double[] distance(Dataset dataset, Vector vector) {
        final double[] result = similarity(dataset, vector);
        for (int i = 0; i < result.length; i++) {
            result[i] = 1.0d - result[i];
        }
        return result;
    }

    /**
     * Computes the cosine similarity between {@code vector} and every object of {@code dataset}.
     * Sparse-column datasets are delegated to the specialised
     * {@link #similarity(SparseColumnDataset, Vector)} overload; any other dataset is processed by
     * iterating the entries of its feature-object matrix.
     *
     * @param dataset objects to compare against
     * @param vector  query vector; read at indices {@code 0 .. numRows - 1} of the feature matrix
     * @return one similarity per object, in object (column) order
     */
    @Override
    public double[] similarity(Dataset dataset, Vector vector) {
        if (dataset instanceof SparseColumnDataset) {
            return similarity((SparseColumnDataset) dataset, vector);
        }
        final Matrix featureObjectMatrix = dataset.getFeatureObjectMatrix();
        final int numFeatures = featureObjectMatrix.numRows();
        final int numObjects = featureObjectMatrix.numColumns();
        final double queryNorm = Math.sqrt(squaredNorm(vector, numFeatures));

        final double[] dots = new double[numObjects];
        final double[] objectSquaredNorms = new double[numObjects];
        for (MatrixEntry entry : featureObjectMatrix) {
            final int object = entry.column();
            final double value = entry.get();
            dots[object] += value * vector.get(entry.row());
            objectSquaredNorms[object] += value * value;
        }

        // Normalise in place: dots[] becomes the similarity array that is returned.
        for (int object = 0; object < numObjects; object++) {
            dots[object] = cosine(dots[object], Math.sqrt(objectSquaredNorms[object]), queryNorm);
        }
        return dots;
    }

    /**
     * Computes the cosine similarity between {@code vector} and every object of a sparse-column
     * dataset, reading the compressed-row arrays of its feature-object matrix directly.
     *
     * @param sparseColumnDataset objects to compare against
     * @param vector              query vector; read at indices {@code 0 .. numFeatures() - 1}
     * @return one similarity per object ({@code sparseColumnDataset.size()} entries)
     */
    public double[] similarity(SparseColumnDataset sparseColumnDataset, Vector vector) {
        final CompRowMatrix featureObjectMatrix = sparseColumnDataset.getFeatureObjectMatrix();
        final int numObjects = sparseColumnDataset.size();
        final int numFeatures = sparseColumnDataset.numFeatures();
        final double[] data = featureObjectMatrix.getData();
        final int[] columnIndices = featureObjectMatrix.getColumnIndices();
        final int[] rowPointers = featureObjectMatrix.getRowPointers();

        final double[] result = new double[numObjects];
        for (int feature = 0; feature < numFeatures; feature++) {
            final double queryValue = vector.get(feature);
            for (int k = rowPointers[feature]; k < rowPointers[feature + 1]; k++) {
                result[columnIndices[k]] += data[k] * queryValue;
            }
        }

        final double queryNorm = Math.sqrt(squaredNorm(vector, numFeatures));
        final double[] objectNorms = columnNorms(featureObjectMatrix, null, numObjects);
        for (int object = 0; object < numObjects; object++) {
            result[object] = cosine(result[object], objectNorms[object], queryNorm);
        }
        return result;
    }

    /**
     * Builds the pairwise cosine-similarity matrix; sparse-column datasets use the specialised
     * {@link #buildMatrix}, others fall back to the superclass implementation.
     */
    @Override
    public DensePairwiseMatrix buildSimilarityMatrix(Dataset dataset) {
        return dataset instanceof SparseColumnDataset
                ? buildMatrix((SparseColumnDataset) dataset, null, false)
                : super.buildSimilarityMatrix(dataset);
    }

    /**
     * Builds the pairwise cosine-distance matrix; sparse-column datasets use the specialised
     * {@link #buildMatrix}, others fall back to the superclass implementation.
     */
    @Override
    public DensePairwiseMatrix buildDistanceMatrix(Dataset dataset) {
        return dataset instanceof SparseColumnDataset
                ? buildMatrix((SparseColumnDataset) dataset, null, true)
                : super.buildDistanceMatrix(dataset);
    }

    /**
     * Builds the pairwise cosine similarity (or distance) matrix of all objects in a sparse-column
     * dataset.
     *
     * <p>When a feature mask is supplied, both the dot products and the norms are restricted to the
     * selected features, so each entry is the cosine of the objects projected onto that subset.
     * The diagonal is fixed at {@code 1} (similarity) or {@code 0} (distance); an off-diagonal pair
     * involving a zero-norm object gets similarity {@code 0} / distance {@code 1}.
     *
     * @param sparseColumnDataset dataset whose objects are compared pairwise
     * @param zArr                optional feature mask indexed by feature (row); {@code null}
     *                            selects every feature
     * @param z                   {@code true} to store distances ({@code 1 - cosine}),
     *                            {@code false} to store similarities
     * @return a new matrix with one row/column per object
     */
    public static DensePairwiseMatrix buildMatrix(SparseColumnDataset sparseColumnDataset, boolean[] zArr, boolean z) {
        final CompRowMatrix featureObjectMatrix = sparseColumnDataset.getFeatureObjectMatrix();
        final int numObjects = featureObjectMatrix.numColumns();
        final int numFeatures = featureObjectMatrix.numRows();
        final double[] data = featureObjectMatrix.getData();
        final int[] columnIndices = featureObjectMatrix.getColumnIndices();
        final int[] rowPointers = featureObjectMatrix.getRowPointers();
        final DensePairwiseMatrix densePairwiseMatrix = new DensePairwiseMatrix(numObjects);
        final double[] packed = densePairwiseMatrix.getData();

        // Accumulate raw off-diagonal dot products. Every selected feature contributes
        // value_p * value_q to each pair of objects that are both non-zero on it. Each unordered
        // pair of non-zeros in the row is visited once; the diagonal is skipped because it is
        // overwritten below.
        for (int feature = 0; feature < numFeatures; feature++) {
            if (zArr != null && !zArr[feature]) {
                continue;
            }
            final int start = rowPointers[feature];
            final int end = rowPointers[feature + 1];
            for (int p = start; p < end; p++) {
                final int objectP = columnIndices[p];
                final double valueP = data[p];
                for (int q = p + 1; q < end; q++) {
                    final int objectQ = columnIndices[q];
                    final int index = objectP <= objectQ
                            ? packedIndex(objectP, objectQ)
                            : packedIndex(objectQ, objectP);
                    packed[index] += valueP * data[q];
                }
            }
        }

        // Norms over the same (possibly masked) features as the dot products.
        final double[] norms = columnNorms(featureObjectMatrix, zArr, numObjects);

        for (int i = 0; i < numObjects; i++) {
            int index = packedIndex(i, i);
            packed[index] = z ? 0.0d : 1.0d;
            for (int j = i + 1; j < numObjects; j++) {
                index += j; // packedIndex(i, j) - packedIndex(i, j - 1) == j
                // Both branches divide by |a||b| (not |a|^2 |b|^2), so distance == 1 - similarity.
                final double similarity = cosine(packed[index], norms[i], norms[j]);
                packed[index] = z ? 1.0d - similarity : similarity;
            }
        }
        return densePairwiseMatrix;
    }

    /**
     * Divides a dot product by the product of two Euclidean norms, returning {@code 0} when either
     * norm is {@code 0}. Callers pass square roots taken separately, which avoids the overflow or
     * underflow that forming {@code |a|^2 * |b|^2} before the square root can cause.
     */
    private static double cosine(double dot, double norm1, double norm2) {
        if (norm1 == 0.0d || norm2 == 0.0d) {
            return 0.0d;
        }
        return dot / (norm1 * norm2);
    }

    /** Returns the sum of squares of {@code vector.get(0) .. vector.get(length - 1)}. */
    private static double squaredNorm(Vector vector, int length) {
        double sum = 0.0d;
        for (int i = 0; i < length; i++) {
            final double x = vector.get(i);
            sum += x * x;
        }
        return sum;
    }

    /**
     * Returns the Euclidean norm of each column of a compressed-row matrix, optionally restricted
     * to the rows whose {@code rowMask} entry is {@code true}.
     *
     * @param matrix     compressed-row matrix whose columns are measured
     * @param rowMask    optional mask indexed by row; {@code null} selects every row
     * @param numColumns length of the returned array
     */
    private static double[] columnNorms(CompRowMatrix matrix, boolean[] rowMask, int numColumns) {
        final double[] data = matrix.getData();
        final int[] columnIndices = matrix.getColumnIndices();
        final int[] rowPointers = matrix.getRowPointers();
        final int numRows = matrix.numRows();
        final double[] norms = new double[numColumns];
        for (int row = 0; row < numRows; row++) {
            if (rowMask != null && !rowMask[row]) {
                continue;
            }
            for (int k = rowPointers[row]; k < rowPointers[row + 1]; k++) {
                norms[columnIndices[k]] += data[k] * data[k];
            }
        }
        for (int column = 0; column < numColumns; column++) {
            norms[column] = Math.sqrt(norms[column]);
        }
        return norms;
    }

    /**
     * Offset of entry {@code (row, col)}, {@code row <= col}, in the packed triangular array
     * returned by {@link DensePairwiseMatrix#getData()}: {@code row + col * (col + 1) / 2}.
     * The product is formed in {@code long} because {@code col * (col + 1)} overflows {@code int}
     * once {@code col} exceeds 46340, even though the resulting offset still fits in an int.
     * NOTE(review): this layout mirrors what this class has always written; confirm it against
     * DensePairwiseMatrix.
     */
    private static int packedIndex(int row, int col) {
        return row + (int) ((long) col * (col + 1) / 2);
    }
}
