package ucd.mlg.validation.hierarchical;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import no.uib.cipr.matrix.Vector;
import ucd.mlg.clustering.OverlappingClustering;
import ucd.mlg.clustering.hierarchical.util.SoftClusterNode;
import ucd.mlg.math.BinomialPValueCalculator;
import ucd.mlg.math.PValueCalculator;
import ucd.mlg.matrix.ColumnVectorMatrix;
import ucd.mlg.metrics.similarity.PearsonCorrelation;
import ucd.mlg.util.Pair;
import ucd.mlg.validation.hierarchical.NodeValidation;

/* loaded from: input_file:ucd/mlg/validation/hierarchical/NodeValidator.class */
/**
 * Validates soft cluster nodes against a set of reference (possibly overlapping) classes.
 *
 * <p>For a node and a class index, {@link #validate(SoftClusterNode, int)} builds a
 * {@link NodeValidation} from five values, stored in this order:
 * <ol>
 *   <li>the correlation (via {@link PearsonCorrelation}) between the node's weight vector and the
 *       class's column of the partition matrix;</li>
 *   <li>the value returned by the configured {@link PValueCalculator};</li>
 *   <li>overlap / node size (fraction of the node's members that belong to the class);</li>
 *   <li>overlap / class size (fraction of the class's members that belong to the node);</li>
 *   <li>the raw overlap count.</li>
 * </ol>
 * An object counts as a member of a node when the node's weight for it is at least
 * {@link #getWeightThreshold()}.
 *
 * <p>Results are cached per (node, class index) pair until {@link #clearCache()},
 * {@link #init(OverlappingClustering)}, {@link #reset()} or {@link #setWeightThreshold(double)}
 * is called. Instances hold mutable state and perform no synchronisation.
 */
public class NodeValidator {
    /** Default minimum weight for an object to count as a member of a node. */
    public static final double DEFAULT_WEIGHT_THRESHOLD = 1.0E-14d;

    /** The reference classes, or {@code null} when none have been set. */
    protected OverlappingClustering classes;
    protected PearsonCorrelation metric = new PearsonCorrelation(false);
    protected PValueCalculator pValue = new BinomialPValueCalculator();
    /** Column view of the classes' partition matrix; {@code null} when no classes are set. */
    protected ColumnVectorMatrix P;
    /** Memoised validations keyed by (node, class index). */
    protected HashMap<Pair<SoftClusterNode, Integer>, NodeValidation> cache = new HashMap<>();
    /** Minimum node weight for an object to be counted as a member of the node. */
    protected double weightThreshold = DEFAULT_WEIGHT_THRESHOLD;

    /**
     * Sets the reference classes, rebuilds the partition matrix and clears the cache.
     *
     * @param overlappingClustering the reference classes; {@code null} clears them
     */
    public void init(OverlappingClustering overlappingClustering) {
        this.classes = overlappingClustering;
        if (overlappingClustering == null) {
            this.P = null;
        } else {
            this.P = new ColumnVectorMatrix(overlappingClustering.buildPartitionMatrix());
        }
        clearCache();
    }

    /** Discards all memoised validation results. */
    public void clearCache() {
        this.cache = new HashMap<>();
    }

    /**
     * Reports whether reference classes are set.
     *
     * @return {@code true} if reference classes have been set
     */
    public boolean hasClasses() {
        return this.classes != null;
    }

    /** Removes the reference classes and partition matrix, and clears the cache. */
    public void reset() {
        this.classes = null;
        this.P = null;
        clearCache();
    }

    /**
     * Returns the minimum weight for an object to count as a member of a node.
     *
     * @return the current membership threshold
     */
    public double getWeightThreshold() {
        return this.weightThreshold;
    }

    /**
     * Sets the minimum weight for an object to count as a member of a node.
     * The cache is cleared because cached results depend on this threshold.
     *
     * @param weightThreshold the new threshold
     * @throws IllegalArgumentException if {@code weightThreshold} is NaN
     */
    public void setWeightThreshold(double weightThreshold) {
        if (Double.isNaN(weightThreshold)) {
            throw new IllegalArgumentException("weightThreshold must not be NaN");
        }
        this.weightThreshold = weightThreshold;
        clearCache();
    }

    /**
     * Validates a node against a single reference class, using the cache when possible.
     *
     * @param softClusterNode the node to validate
     * @param i the index of the reference class
     * @return the validation result, or {@code null} if no reference classes are set
     */
    public NodeValidation validate(SoftClusterNode softClusterNode, int i) {
        if (this.classes == null) {
            return null;
        }
        Pair<SoftClusterNode, Integer> key = new Pair<>(softClusterNode, Integer.valueOf(i));
        // null is never stored as a value, so a null lookup means "not cached".
        NodeValidation cached = this.cache.get(key);
        if (cached != null) {
            return cached;
        }
        Vector weights = softClusterNode.getWeights();
        double correlation = this.metric.correlation(weights, this.P.getColumn(i));
        int numObjects = weights.size();
        int classSize = this.classes.getClusterSize(i);
        int nodeSize = countNodeMembers(weights);
        int overlap = countClassMembersInNode(weights, i);

        double[] values = new double[5];
        values[0] = correlation;
        values[1] = this.pValue.calculate(numObjects, classSize, nodeSize, overlap);
        // Floating-point division: plain int division truncated these fractions to 0 or 1
        // and threw ArithmeticException on an empty node or class.
        values[2] = ratio(overlap, nodeSize);
        values[3] = ratio(overlap, classSize);
        values[4] = overlap;

        NodeValidation nodeValidation = new NodeValidation(values);
        this.cache.put(key, nodeValidation);
        return nodeValidation;
    }

    /**
     * Validates a node against one reference class and extracts a single value.
     * Throws NullPointerException if no reference classes are set.
     *
     * @param softClusterNode the node to validate
     * @param i the index of the reference class
     * @param nodeValidationType which value to extract
     * @return the requested validation value
     */
    public double validate(SoftClusterNode softClusterNode, int i, NodeValidation.NodeValidationType nodeValidationType) {
        return validate(softClusterNode, i).getValue(nodeValidationType);
    }

    /**
     * Validates a node against every reference class.
     *
     * @param softClusterNode the node to validate
     * @return one result per class index, or {@code null} if no reference classes are set
     */
    public NodeValidation[] validateAllClasses(SoftClusterNode softClusterNode) {
        if (this.classes == null) {
            return null;
        }
        int size = this.classes.size();
        NodeValidation[] nodeValidationArr = new NodeValidation[size];
        for (int i = 0; i < size; i++) {
            nodeValidationArr[i] = validate(softClusterNode, i);
        }
        return nodeValidationArr;
    }

    /**
     * Validates a node against every reference class and extracts a single value from each result.
     *
     * @param softClusterNode the node to validate
     * @param nodeValidationType which value to extract
     * @return one value per class index, or {@code null} if no reference classes are set
     */
    public double[] validateAllClasses(SoftClusterNode softClusterNode, NodeValidation.NodeValidationType nodeValidationType) {
        if (this.classes == null) {
            return null;
        }
        return getAllValues(validateAllClasses(softClusterNode), nodeValidationType);
    }

    /**
     * Validates each node in a list against one reference class.
     *
     * @param list the nodes to validate
     * @param i the index of the reference class
     * @return one result per node, in list order, or {@code null} if no reference classes are set
     */
    public NodeValidation[] validateAllNodes(List<SoftClusterNode> list, int i) {
        if (this.classes == null) {
            return null;
        }
        int size = list.size();
        NodeValidation[] nodeValidationArr = new NodeValidation[size];
        for (int i2 = 0; i2 < size; i2++) {
            nodeValidationArr[i2] = validate(list.get(i2), i);
        }
        return nodeValidationArr;
    }

    /**
     * Validates each node in a list against one reference class and extracts a single value from
     * each result.
     *
     * @param list the nodes to validate
     * @param i the index of the reference class
     * @param nodeValidationType which value to extract
     * @return one value per node, in list order, or {@code null} if no reference classes are set
     */
    public double[] validateAllNodes(List<SoftClusterNode> list, int i, NodeValidation.NodeValidationType nodeValidationType) {
        if (this.classes == null) {
            return null;
        }
        return getAllValues(validateAllNodes(list, i), nodeValidationType);
    }

    @Override
    public String toString() {
        return getClass().getSimpleName();
    }

    /**
     * Extracts the same value from each validation result.
     *
     * @param nodeValidationArr the validation results
     * @param nodeValidationType which value to extract
     * @return the extracted values, in the same order as the input
     */
    protected double[] getAllValues(NodeValidation[] nodeValidationArr, NodeValidation.NodeValidationType nodeValidationType) {
        double[] dArr = new double[nodeValidationArr.length];
        int indexOf = NodeValidation.indexOf(nodeValidationType);
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = nodeValidationArr[i].getValue(indexOf);
        }
        return dArr;
    }

    /** Counts the objects whose weight in the node is at least the membership threshold. */
    private int countNodeMembers(Vector weights) {
        int count = 0;
        int size = weights.size();
        for (int j = 0; j < size; j++) {
            if (weights.get(j) >= this.weightThreshold) {
                count++;
            }
        }
        return count;
    }

    /** Counts the members of class {@code classIndex} that are also members of the node. */
    private int countClassMembersInNode(Vector weights, int classIndex) {
        int count = 0;
        Iterator<Integer> it = this.classes.get(classIndex).iterator();
        while (it.hasNext()) {
            if (weights.get(it.next().intValue()) >= this.weightThreshold) {
                count++;
            }
        }
        return count;
    }

    /** Floating-point ratio that yields 0.0 instead of failing when the denominator is zero. */
    private static double ratio(int numerator, int denominator) {
        return denominator == 0 ? 0.0d : (double) numerator / denominator;
    }
}
