package ucd.mlg.clustering.ensemble.generation;

import java.util.Random;
import ucd.mlg.clustering.ClusteringException;
import ucd.mlg.clustering.HardClustering;
import ucd.mlg.clustering.capability.PredictiveSamplingClusterer;
import ucd.mlg.core.data.Dataset;

/* loaded from: input_file:ucd/mlg/clustering/ensemble/generation/FoldGenerator.class */
/**
 * Ensemble generator that, for each ensemble member, randomly splits the dataset into a training
 * fold and a complementary testing fold, and runs the base {@link PredictiveSamplingClusterer}
 * with the corresponding mask.
 *
 * <p>By default half of the objects (rounded down) go into each member's training fold; a
 * different fraction can be supplied via
 * {@link #FoldGenerator(PredictiveSamplingClusterer, int, double)}.
 */
public class FoldGenerator extends IterativeGenerator implements FixedKGenerator {
    /** Training fraction used by the constructors that do not take one: a half split. */
    public static final double DEFAULT_TRAINING_FRACTION = 0.5;

    /**
     * Random source shared by all instances for fold sampling. It is created by the first
     * constructor call that finds it {@code null}, seeded with the current wall-clock time.
     */
    protected static Random random;

    /** Per-member training masks indexed {@code [member][object]}; built by {@link #init(Dataset)}. */
    protected boolean[][] trainingMask;

    /** Per-member testing masks; each is the element-wise complement of the training mask. */
    protected boolean[][] testingMask;

    /** Fraction of objects placed in each member's training fold, in {@code [0, 1]}. */
    private final double trainingFraction;

    /**
     * Creates a generator with a configurable training-fold fraction.
     *
     * @param predictiveSamplingClusterer base clusterer; its predict flag is switched on
     * @param i value forwarded to the {@link IterativeGenerator} constructor (presumably the number
     *     of ensemble members -- NOTE(review): confirm against IterativeGenerator)
     * @param trainingFraction fraction of objects assigned to each training fold; the fold size is
     *     {@code (int) (size * trainingFraction)}, i.e. truncated toward zero
     * @throws IllegalArgumentException if {@code trainingFraction} is NaN or outside {@code [0, 1]}
     */
    public FoldGenerator(
            PredictiveSamplingClusterer predictiveSamplingClusterer, int i, double trainingFraction) {
        super(predictiveSamplingClusterer, i);
        this.trainingFraction = requireValidFraction(trainingFraction);
        predictiveSamplingClusterer.setPredict(true);
        if (random == null) {
            random = new Random(System.currentTimeMillis());
        }
    }

    /**
     * Creates a generator that places half of the objects (rounded down) in each training fold.
     *
     * @param predictiveSamplingClusterer base clusterer; its predict flag is switched on
     * @param i value forwarded to the {@link IterativeGenerator} constructor
     */
    public FoldGenerator(PredictiveSamplingClusterer predictiveSamplingClusterer, int i) {
        this(predictiveSamplingClusterer, i, DEFAULT_TRAINING_FRACTION);
    }

    /**
     * Creates a generator with the default member setting (100) and a half split.
     *
     * @param predictiveSamplingClusterer base clusterer; its predict flag is switched on
     */
    public FoldGenerator(PredictiveSamplingClusterer predictiveSamplingClusterer) {
        this(predictiveSamplingClusterer, 100);
    }

    /** Returns the fraction of objects placed in each member's training fold. */
    public double getTrainingFraction() {
        return trainingFraction;
    }

    /**
     * Initializes the superclass state, then draws a fresh random training/testing split for
     * every ensemble member.
     *
     * @param dataset the data to be clustered
     */
    @Override
    public void init(Dataset dataset) {
        super.init(dataset);
        final int size = dataset.size();
        // fraction <= 1 guarantees trainingCount <= size, so sampling always terminates.
        final int trainingCount = (int) (size * this.trainingFraction);
        this.trainingMask = new boolean[this.members][];
        this.testingMask = new boolean[this.members][size];
        for (int member = 0; member < this.members; member++) {
            final boolean[] training = sampleTrainingMask(size, trainingCount);
            final boolean[] testing = this.testingMask[member];
            for (int obj = 0; obj < size; obj++) {
                testing[obj] = !training[obj];
            }
            this.trainingMask[member] = training;
        }
    }

    /**
     * Runs the base clusterer with the current member's training mask. Unlike
     * {@link #nextClustering()}, this does not advance to the next member.
     *
     * @return the clustering produced by the base clusterer
     * @throws ClusteringException if the base clusterer fails
     */
    public HardClustering nextTrainingClustering() throws ClusteringException {
        PredictiveSamplingClusterer baseClusterer = getBaseClusterer();
        baseClusterer.setMask(this.trainingMask[this.currentMemberIndex]);
        return baseClusterer.findClusters(this.dataset);
    }

    /**
     * Runs the base clusterer with the current member's testing mask, then advances to the next
     * member. The index is not advanced if clustering throws.
     *
     * @return the clustering produced by the base clusterer
     * @throws ClusteringException if the base clusterer fails
     */
    @Override
    public HardClustering nextClustering() throws ClusteringException {
        PredictiveSamplingClusterer baseClusterer = getBaseClusterer();
        baseClusterer.setMask(this.testingMask[this.currentMemberIndex]);
        HardClustering findClusters = baseClusterer.findClusters(this.dataset);
        this.currentMemberIndex++;
        return findClusters;
    }

    /** Returns the number of clusters configured on the base clusterer. */
    @Override
    public int getK() {
        return getBaseClusterer().getK();
    }

    /**
     * Sets the number of clusters on the base clusterer and then calls {@code reset()}.
     *
     * @param i the new number of clusters
     */
    @Override
    public void setK(int i) {
        getBaseClusterer().setK(i);
        reset();
    }

    /** Returns the base clusterer, narrowed to {@link PredictiveSamplingClusterer}. */
    @Override
    public PredictiveSamplingClusterer getBaseClusterer() {
        return (PredictiveSamplingClusterer) this.baseClusterer;
    }

    /**
     * Draws a mask of length {@code size} with exactly {@code count} entries set. Positions are
     * chosen by rejection sampling from the shared {@link #random}.
     *
     * @param size mask length
     * @param count number of entries to set; must satisfy {@code 0 <= count <= size}
     * @return the new mask
     */
    private static boolean[] sampleTrainingMask(int size, int count) {
        final boolean[] mask = new boolean[size];
        int chosen = 0;
        while (chosen < count) {
            final int candidate = random.nextInt(size);
            if (!mask[candidate]) {
                mask[candidate] = true;
                chosen++;
            }
        }
        return mask;
    }

    /** Validates a training fraction; the negated range test also rejects NaN. */
    private static double requireValidFraction(double fraction) {
        if (!(fraction >= 0.0 && fraction <= 1.0)) {
            throw new IllegalArgumentException(
                    "trainingFraction must be in [0, 1], got " + fraction);
        }
        return fraction;
    }
}
