package ucd.mlg.clustering.nmf.factor;

import no.uib.cipr.matrix.AbstractTriangPackMatrix;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Matrix;
import ucd.mlg.math.Functions;
import ucd.mlg.matrix.MatrixStats;
import ucd.mlg.matrix.MatrixUtils;

/**
 * Non-negative matrix factorisation {@code X ~ U * V^T} refined by multiplicative updates that
 * reduce the generalised Kullback-Leibler divergence between {@code X} and {@code U * V^T}
 * (the objective evaluated by {@link #calcError(Matrix, DenseMatrix[])}).
 *
 * <p>With {@code X} of size {@code n x m}, factor 0 ({@code U}) is {@code n x k} and factor 1
 * ({@code V}) is {@code m x k}. {@code U} is column-wise L1-normalised after every iteration.
 */
public class DivergenceFactorization implements FactorizationAlgorithm {
    /** Convergence of the second factor is tested only once every this many iterations. */
    protected static final int CONVERGENCE_CHECK_COUNT = 2;

    /**
     * Refines the two factors in place for at most {@code maxIterations} iterations, stopping
     * early when the second factor is reported as converged.
     *
     * @param matrix         the data matrix {@code X} ({@code n x m})
     * @param denseMatrixArr the initial factors {@code [U (n x k), V (m x k)]}; both are
     *                       overwritten with the refined factors
     * @param maxIterations  upper bound on the number of update iterations
     * @return the number of iterations actually performed (0 if {@code maxIterations <= 0})
     * @throws FactorizationException if fewer than two factors are supplied, a factor is
     *                                {@code null}, or the factor shapes do not match {@code X}
     */
    @Override // ucd.mlg.clustering.nmf.factor.FactorizationAlgorithm
    public int factor(Matrix matrix, DenseMatrix[] denseMatrixArr, int maxIterations)
            throws FactorizationException {
        if (denseMatrixArr == null || denseMatrixArr.length < 2) {
            throw new FactorizationException("Incorrect number of initial factors specified.");
        }
        DenseMatrix u = denseMatrixArr[0];
        DenseMatrix v = denseMatrixArr[1];
        if (u == null || v == null) {
            throw new FactorizationException("Initial factors must not be null.");
        }
        if (u.numColumns() != v.numColumns()) {
            throw new FactorizationException(String.format("Both factors should contain the same number of basis vectors: %d!=%d", Integer.valueOf(u.numColumns()), Integer.valueOf(v.numColumns())));
        }
        if (u.numRows() != matrix.numRows()) {
            throw new FactorizationException(String.format("Incorrect number of rows in first factor: %d!=%d", Integer.valueOf(u.numRows()), Integer.valueOf(matrix.numRows())));
        }
        if (v.numRows() != matrix.numColumns()) {
            throw new FactorizationException(String.format("Incorrect number of rows in second factor: %d!=%d", Integer.valueOf(v.numRows()), Integer.valueOf(matrix.numColumns())));
        }
        int k = u.numColumns();
        int n = u.numRows();
        int m = v.numRows();

        // Work buffers allocated once and reused on every iteration.
        DenseMatrix ratio = new DenseMatrix(n, m);   // element-wise X / (U V^T)
        Matrix vMultiplier = new DenseMatrix(m, k);  // ratio^T U
        Matrix uMultiplier = new DenseMatrix(n, k);  // ratio V

        FactorConvergence convergence = new FactorConvergence();
        convergence.initConvergence(v);

        int completed = 0;
        for (int iteration = 1; iteration <= maxIterations; iteration++) {
            // V <- V .* (ratio^T U). The usual per-column denominator (column sums of U) is
            // omitted because U is L1-normalised at the end of every iteration.
            // NOTE(review): on the first iteration U is whatever the caller supplied, so this
            // step assumes the initial U already has unit-L1 columns. Confirm with the seeding code.
            // NOTE(review): Functions.safeInvDiv is presumed to compute X(i,j) / ratio(i,j)
            // with a guard against division by zero. TODO confirm.
            u.transBmult(v, ratio);
            MatrixUtils.assign(ratio, matrix, Functions.safeInvDiv);
            ratio.transAmult(u, vMultiplier);
            MatrixUtils.assign(v, vMultiplier, Functions.mult);

            // U <- U .* (ratio V), with ratio recomputed from the updated V. The per-column
            // denominator (column sums of V) is dropped. Assuming normalizeColumnL1 rescales each
            // column to unit L1 norm, that normalisation cancels any per-column scaling anyway.
            u.transBmult(v, ratio);
            MatrixUtils.assign(ratio, matrix, Functions.safeInvDiv);
            ratio.mult(v, uMultiplier);
            MatrixUtils.assign(u, uMultiplier, Functions.mult);
            MatrixUtils.normalizeColumnL1(u);

            // This iteration has fully run; count it before any early exit.
            completed = iteration;
            if (iteration % CONVERGENCE_CHECK_COUNT == 0 && convergence.isConverged(v)) {
                break;
            }
        }
        return completed;
    }

    /**
     * Computes the generalised Kullback-Leibler divergence
     * {@code sum_ij X_ij * log(X_ij / Y_ij) - X_ij + Y_ij}, where {@code Y = U * V^T}.
     *
     * @param matrix         the data matrix {@code X}
     * @param denseMatrixArr the factors {@code [U, V]}
     * @return the divergence between {@code X} and {@code U * V^T}
     * @throws IllegalArgumentException if fewer than two factors are supplied or a factor is
     *                                  {@code null}
     */
    @Override // ucd.mlg.clustering.nmf.factor.FactorizationAlgorithm
    public double calcError(Matrix matrix, DenseMatrix[] denseMatrixArr) {
        if (denseMatrixArr == null || denseMatrixArr.length < 2) {
            throw new IllegalArgumentException("Incorrect number of initial factors specified.");
        }
        if (denseMatrixArr[0] == null || denseMatrixArr[1] == null) {
            throw new IllegalArgumentException("Factors must not be null.");
        }
        int rows = matrix.numRows();
        int cols = matrix.numColumns();
        DenseMatrix approx = new DenseMatrix(rows, cols);
        denseMatrixArr[0].transBmult(denseMatrixArr[1], approx);
        double divergence = 0.0d;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                double x = matrix.get(r, c);
                double y = approx.get(r, c);
                // 0 * log(0 / y) is taken as 0.
                // NOTE(review): when y == 0 but x > 0 the true divergence is +infinity. This code
                // treats the log term as 0 instead, so that entry contributes -x.
                double logTerm = (x == 0.0d || y == 0.0d) ? 0.0d : Math.log(x / y);
                divergence += x * logTerm - x + y;
            }
        }
        return divergence;
    }

    /**
     * Returns the number of factors this algorithm works with.
     *
     * @return always 2: {@code U} and {@code V}
     */
    @Override // ucd.mlg.clustering.nmf.factor.FactorizationAlgorithm
    public int getFactorCount() {
        return 2;
    }

    /**
     * Rescales a factor pair in place. Each column of {@code matrix2} is multiplied by the
     * corresponding column norm of {@code matrix}, as reported by {@code MatrixStats.columnL2Norms}.
     * Each column of {@code matrix} with a non-zero norm is then divided by that norm.
     *
     * @param matrix  the factor whose columns are normalised
     * @param matrix2 the partner factor that absorbs the removed scale; it is assumed to have
     *                the same number of columns as {@code matrix}
     */
    protected void normalizeFactors(Matrix matrix, Matrix matrix2) {
        int partnerRows = matrix2.numRows();
        int factorRows = matrix.numRows();
        int k = matrix.numColumns();
        double[] norms = MatrixStats.columnL2Norms(matrix);
        for (int row = 0; row < partnerRows; row++) {
            for (int col = 0; col < k; col++) {
                matrix2.set(row, col, matrix2.get(row, col) * norms[col]);
            }
        }
        for (int col = 0; col < k; col++) {
            double norm = norms[col];
            if (norm == 0.0d) {
                continue; // leave all-zero-norm columns untouched (avoids division by zero)
            }
            for (int row = 0; row < factorRows; row++) {
                matrix.set(row, col, matrix.get(row, col) / norm);
            }
        }
    }
}
