package cc.mallet.cluster.neighbor_evaluator;

import cc.mallet.classify.Classifier;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.types.MatrixOps;
import jregex.WildcardPattern;

/* loaded from: input_file:cc/mallet/cluster/neighbor_evaluator/MedoidEvaluator.class */
/**
 * A {@link ClassifyingNeighborEvaluator} that scores a proposed merge of two clusters using each
 * cluster's medoid.
 *
 * <p>For each of the two old clusters, the medoid is the member with the highest mean pairwise
 * score against the other members. With {@code singleLink}, the merge score is the pairwise score
 * of the two medoids. Otherwise it is a weighted mean of pairwise scores over all cross-cluster
 * pairs (plus, when {@code mergeFirst} is set, all intra-cluster pairs of both clusters). Each
 * instance is weighted by its pairwise score with its own cluster's medoid; the medoid itself has
 * weight 1.
 *
 * <p>A pairwise score is the classifier's label-vector value for {@code scoringLabel} on an
 * {@link AgglomerativeNeighbor} that merges the two instances. Scores are memoized in
 * {@link #scoreCache} until {@link #reset()} is called.
 */
public class MedoidEvaluator extends ClassifyingNeighborEvaluator {
    private static final long serialVersionUID = 1L;

    /** When true, a merge is scored solely by the pairwise score of the two clusters' medoids. */
    boolean singleLink = false;

    /**
     * Not read by any code in this class. NOTE(review): possibly used elsewhere in the package;
     * confirm before removing.
     */
    CombiningStrategy combiningStrategy;

    /**
     * When true (the default), intra-cluster pairs of both clusters are included in the weighted
     * mean in addition to the cross-cluster pairs.
     */
    boolean mergeFirst = true;

    /**
     * Lazily created cache of pairwise scores, sized by the instance count of the first
     * {@link Clustering} scored; cleared by {@link #reset()}. NOTE(review): presumably
     * {@code reset()} must be called before scoring a Clustering with more instances.
     */
    PairwiseMatrix scoreCache;

    /** Mean of the scores. */
    public static class Average implements CombiningStrategy {
        @Override
        public double combine(double[] dArr) {
            return MatrixOps.mean(dArr);
        }
    }

    /** Reduces an array of scores to a single value. */
    public interface CombiningStrategy {
        double combine(double[] dArr);
    }

    /** Largest of the scores. */
    public static class Maximum implements CombiningStrategy {
        @Override
        public double combine(double[] dArr) {
            return MatrixOps.max(dArr);
        }
    }

    /** Smallest of the scores. */
    public static class Minimum implements CombiningStrategy {
        @Override
        public double combine(double[] dArr) {
            return MatrixOps.min(dArr);
        }
    }

    /**
     * Creates an evaluator with {@code singleLink = false} and {@code mergeFirst = true}.
     *
     * @param classifier classifier used to score candidate pairwise merges
     * @param str label whose label-vector value is used as the score (presumably stored by the
     *     superclass as {@code scoringLabel})
     */
    public MedoidEvaluator(Classifier classifier, String str) {
        super(classifier, str);
        System.out.println("Using Medoid Evaluator");
    }

    /**
     * Creates an evaluator with explicit linkage options.
     *
     * @param classifier classifier used to score candidate pairwise merges
     * @param str label whose label-vector value is used as the score
     * @param z if true, score merges only by the medoid-to-medoid score ({@code singleLink})
     * @param z2 if true, also include intra-cluster pairs in the weighted mean ({@code mergeFirst})
     */
    public MedoidEvaluator(Classifier classifier, String str, boolean z, boolean z2) {
        super(classifier, str);
        this.singleLink = z;
        this.mergeFirst = z2;
        System.out.println("Using Medoid Evaluator. Single link=" + z + ".");
    }

    /**
     * Scores each neighbor independently.
     *
     * @param neighborArr neighbors to score; each must be an {@link AgglomerativeNeighbor}
     * @return scores in the same order as {@code neighborArr}
     */
    @Override
    public double[] evaluate(Neighbor[] neighborArr) {
        double[] scores = new double[neighborArr.length];
        for (int i = 0; i < neighborArr.length; i++) {
            scores[i] = evaluate(neighborArr[i]);
        }
        return scores;
    }

    /**
     * Scores the merge of the two old clusters described by {@code neighbor}.
     *
     * @param neighbor the proposed merge; must be an {@link AgglomerativeNeighbor}
     * @return the medoid-to-medoid score when {@code singleLink} is set, otherwise the
     *     medoid-weighted mean of the pairwise scores
     * @throws IllegalArgumentException if {@code neighbor} is not an {@link AgglomerativeNeighbor}
     */
    @Override
    public double evaluate(Neighbor neighbor) {
        if (!(neighbor instanceof AgglomerativeNeighbor)) {
            String actual = (neighbor == null) ? "null" : neighbor.getClass().getName();
            throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + actual);
        }
        int[][] oldClusters = ((AgglomerativeNeighbor) neighbor).getOldClusters();
        Clustering original = neighbor.getOriginal();
        int[] first = oldClusters[0];
        int[] second = oldClusters[1];
        int firstMedoid = getCentroid(first, original);
        int secondMedoid = getCentroid(second, original);

        if (this.singleLink) {
            return pairScore(original, first[firstMedoid], second[secondMedoid]);
        }

        double[] firstWeights = getMedWeights(firstMedoid, first, original);
        double[] secondWeights = getMedWeights(secondMedoid, second, original);
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (int i = 0; i < first.length; i++) {
            // Every cross-cluster pair.
            for (int j = 0; j < second.length; j++) {
                weightedSum += pairScore(original, first[i], second[j]) * firstWeights[i] * secondWeights[j];
                totalWeight += firstWeights[i] * secondWeights[j];
            }
            if (this.mergeFirst) {
                // Each unordered pair within the first cluster, counted once.
                for (int k = i + 1; k < first.length; k++) {
                    weightedSum += pairScore(original, first[i], first[k]) * firstWeights[i] * firstWeights[k];
                    totalWeight += firstWeights[i] * firstWeights[k];
                }
            }
        }
        if (this.mergeFirst) {
            // Each unordered pair within the second cluster, counted once.
            for (int i = 0; i < second.length; i++) {
                for (int k = i + 1; k < second.length; k++) {
                    weightedSum += pairScore(original, second[i], second[k]) * secondWeights[i] * secondWeights[k];
                    totalWeight += secondWeights[i] * secondWeights[k];
                }
            }
        }
        return weightedSum / totalWeight;
    }

    /**
     * Computes each member's weight as its pairwise score with the cluster's medoid.
     *
     * @param i position of the medoid within {@code iArr}; the medoid gets weight 1.0
     * @param iArr instance indices of the cluster members
     * @param clustering clustering the instances belong to
     * @return one weight per member, in {@code iArr} order
     */
    private double[] getMedWeights(int i, int[] iArr, Clustering clustering) {
        double[] weights = new double[iArr.length];
        for (int m = 0; m < weights.length; m++) {
            weights[m] = (m == i) ? 1.0d : pairScore(clustering, iArr[i], iArr[m]);
        }
        return weights;
    }

    /**
     * Finds the medoid: the member with the highest mean pairwise score against all other members.
     *
     * @param iArr instance indices of the cluster members
     * @param clustering clustering the instances belong to
     * @return position of the medoid within {@code iArr}; 0 for clusters with fewer than two
     *     members. Ties go to the earliest member.
     */
    private int getCentroid(int[] iArr, Clustering clustering) {
        if (iArr.length < 2) {
            return 0;
        }
        // Start at 0 (not -1) so a non-comparable score such as NaN can never yield an invalid index.
        int best = 0;
        double bestMean = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < iArr.length; i++) {
            double sum = 0.0d;
            for (int j = 0; j < iArr.length; j++) {
                // Skip the self-pair but keep iterating; the previous loop condition stopped here.
                if (i != j) {
                    sum += pairScore(clustering, iArr[i], iArr[j]);
                }
            }
            double mean = sum / (iArr.length - 1);
            if (mean > bestMean) {
                bestMean = mean;
                best = i;
            }
        }
        return best;
    }

    /** Discards all cached pairwise scores. */
    @Override
    public void reset() {
        this.scoreCache = null;
    }

    @Override
    public String toString() {
        return "class=" + getClass().getName() + " classifier=" + this.classifier.getClass().getName();
    }

    /** Scores the hypothetical merge of instances {@code a} and {@code b} within {@code clustering}. */
    private double pairScore(Clustering clustering, int a, int b) {
        return getScore(new AgglomerativeNeighbor(clustering, clustering, a, b));
    }

    /**
     * Returns the classifier's {@code scoringLabel} value for a pairwise merge, memoized in
     * {@link #scoreCache}.
     *
     * <p>A cached value of exactly 0.0 is indistinguishable from "not yet computed", so pairs that
     * genuinely score 0.0 are re-classified on every call.
     */
    private double getScore(AgglomerativeNeighbor agglomerativeNeighbor) {
        if (this.scoreCache == null) {
            this.scoreCache = new PairwiseMatrix(agglomerativeNeighbor.getOriginal().getNumInstances());
        }
        // Presumably the two instance indices passed to the AgglomerativeNeighbor constructor.
        int[] pair = agglomerativeNeighbor.getNewCluster();
        if (this.scoreCache.get(pair[0], pair[1]) == 0.0d) {
            double score = this.classifier.classify(agglomerativeNeighbor).getLabelVector().value(this.scoringLabel);
            this.scoreCache.set(pair[0], pair[1], score);
        }
        return this.scoreCache.get(pair[0], pair[1]);
    }
}
