package com.rapidminer.kobra.topicmodels;

import gnu.trove.TIntHashSet;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TDoubleIntHashMap;
import java.util.Random;
import org.apache.commons.math3.distribution.BetaDistribution;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.Variance;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/SamplersDTLDA.class */
/**
 * Collapsed Gibbs sampler for an LDA variant in which every topic additionally carries a Beta
 * distribution over document time stamps. When a token is resampled, each topic's usual LDA weight
 * is multiplied by that topic's Beta density at the token's document time. After every sweep, each
 * topic's Beta is refit by the method of moments from the times of the tokens assigned to it.
 */
public class SamplersDTLDA extends SamplersSLDA {
    /** Standard deviation of the Gaussian jitter used to spread degenerate time samples. */
    private static final double JITTER_STD = 0.01d;

    /** Sample variance at or below which a topic's times are replaced by jittered draws. */
    private static final double MIN_VARIANCE = 0.01d;

    /** Number of jittered samples added when a topic has exactly one time observation. */
    private static final int SINGLE_SAMPLE_PADDING = 10;

    /** Maximum redraws when jittering into (0, 1) before falling back to clamping. */
    private static final int MAX_JITTER_ATTEMPTS = 100;

    /** Distance from 0 and 1 used when a value has to be forced into the open unit interval. */
    private static final double UNIT_EPSILON = 1e-6d;

    /**
     * Time stamp of each document, indexed by document id. It is evaluated under a Beta density, so
     * it is presumably scaled into [0, 1] by the caller (TODO confirm).
     */
    double[] times;
    /** Per-topic sample mean of the assigned token times from the most recent fit. */
    double[] meanTimes;
    /** Per-topic sample variance of the assigned token times from the most recent fit. */
    double[] varianceTimes;
    /** Stored as passed to {@link #init}; not read within this class. */
    int[] uniqueIds;
    /** Per-topic Beta shape parameters as {@code {alpha, beta}}. */
    double[][] pi;
    /** Per-topic Beta time distributions used during sampling. */
    BetaDistribution[] pBeta = null;
    /** Per-topic multiset (time -> count) of token times assigned during the most recent sweep. */
    TDoubleIntHashMap[] hsValues = null;

    /**
     * Initialises the sampler state, assigns every token a uniformly random topic and gives every
     * topic a uniform Beta(1, 1) time distribution.
     *
     * @param iArr  document id of each token
     * @param iArr2 word id of each token (its length defines the number of tokens)
     * @param dArr  time stamp of each document, indexed by document id
     * @param iArr3 ids stored in {@link #uniqueIds}
     * @param i     number of topics
     * @param i2    number of distinct words
     * @param i3    number of documents
     * @param i4    maximum number of Gibbs iterations
     * @param d     topic-word smoothing prior (BETA)
     * @param d2    document-topic smoothing prior (ALPHA)
     * @param z     whether to seed the random generator with {@code i5}
     * @param i5    seed used when {@code z} is true
     */
    public void init(int[] iArr, int[] iArr2, double[] dArr, int[] iArr3, int i, int i2, int i3, int i4, double d, double d2, boolean z, int i5) {
        this.maxIter = i4;
        this.BETA = d;
        this.ALPHA = d2;
        this.numTokens = iArr2.length;
        this.numTopics = i;
        this.numDocs = i3;
        this.numWords = i2;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[i2 * i];
        this.doctopiccounts = new int[i3 * i];
        this.topiccounts = new int[i];
        this.words = iArr2;
        this.docs = iArr;
        if (z) {
            this.seed = i5;
            this.rn = new Random(i5);
        } else {
            this.rn = new Random();
        }
        // Random initial assignment; the three count tables are kept consistent with topics[].
        for (int token = 0; token < this.numTokens; token++) {
            int topic = this.rn.nextInt(i);
            this.topics[token] = topic;
            this.wordtopiccounts[this.words[token] * i + topic]++;
            this.doctopiccounts[this.docs[token] * i + topic]++;
            this.topiccounts[topic]++;
        }
        this.times = dArr;
        this.pi = new double[i][2];
        this.pBeta = new BetaDistribution[i];
        this.meanTimes = new double[i];
        this.varianceTimes = new double[i];
        for (int topic = 0; topic < i; topic++) {
            resetToUniformBeta(topic);
        }
        this.uniqueIds = iArr3;
    }

    /**
     * Runs {@code maxIter} Gibbs sweeps over all tokens, visited in a fixed random order.
     *
     * <p>After each sweep, every topic's Beta time distribution is refit. On every even iteration
     * within the last 10% of iterations, {@code updateDistributions()} is invoked.
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersSLDA, com.rapidminer.kobra.topicmodels.SamplersLDA
    public void GibbsSampling() {
        int updateStart = (int) (this.maxIter * 0.9d);
        this.WBETA = this.numWords * this.BETA;
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        // The visiting order is shuffled once and reused for every sweep.
        TIntArrayList order = new TIntArrayList(this.numTokens);
        for (int token = 0; token < this.numTokens; token++) {
            order.add(token);
        }
        order.shuffle(this.rn);
        this.hsValues = new TDoubleIntHashMap[this.numTopics];
        for (int iter = 0; iter < this.maxIter; iter++) {
            System.out.println("Current Gibbs Sampler iteration: " + iter);
            Mean[] means = new Mean[this.numTopics];
            Variance[] variances = new Variance[this.numTopics];
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.hsValues[topic] = new TDoubleIntHashMap();
                means[topic] = new Mean();
                variances[topic] = new Variance();
            }
            for (int pos = 0; pos < this.numTokens; pos++) {
                int token = order.get(pos);
                int newTopic = resampleTokenTopic(token);
                double time = this.times[this.docs[token]];
                means[newTopic].increment(time);
                variances[newTopic].increment(time);
                this.hsValues[newTopic].adjustOrPutValue(time, 1, 1);
                // Index by token id, not shuffle position, so tokenToTopic[t] agrees with topics[t].
                this.tokenToTopic[token] = newTopic;
            }
            clampNegativeCounts();
            for (int topic = 0; topic < this.numTopics; topic++) {
                refitTopicTimeBeta(topic, means[topic], variances[topic]);
            }
            if (iter >= updateStart && iter % 2 == 0) {
                updateDistributions();
            }
        }
    }

    /**
     * Removes {@code token} from the count tables and draws a new topic for it. Each topic's weight
     * is the collapsed LDA conditional times the topic's Beta density at the document's time. The
     * token is then added back under the new topic.
     *
     * @param token token index
     * @return the newly assigned topic
     */
    private int resampleTokenTopic(int token) {
        int doc = this.docs[token];
        int oldTopic = this.topics[token];
        int wordBase = this.words[token] * this.numTopics;
        int docBase = doc * this.numTopics;
        this.topiccounts[oldTopic]--;
        this.wordtopiccounts[wordBase + oldTopic]--;
        this.doctopiccounts[docBase + oldTopic]--;

        // Beta densities are 0 outside (0, 1) and may throw exactly at 0 or 1 when a shape
        // parameter is < 1, so boundary/out-of-range times are nudged into the open interval.
        double densityAt = toOpenUnitInterval(this.times[doc]);
        double total = 0.0d;
        for (int k = 0; k < this.numTopics; k++) {
            this.probs[k] = ((this.wordtopiccounts[wordBase + k] + this.BETA) / (this.topiccounts[k] + this.WBETA))
                    * (this.doctopiccounts[docBase + k] + this.ALPHA)
                    * this.pBeta[k].density(densityAt);
            total += this.probs[k];
        }
        double u = total * this.rn.nextDouble();
        int newTopic = 0;
        double cumulative = this.probs[0];
        // Defensive bound: never walk past the last topic, whatever values probs holds.
        while (u > cumulative && newTopic < this.numTopics - 1) {
            newTopic++;
            cumulative += this.probs[newTopic];
        }
        this.topics[token] = newTopic;
        this.wordtopiccounts[wordBase + newTopic]++;
        this.doctopiccounts[docBase + newTopic]++;
        this.topiccounts[newTopic]++;
        return newTopic;
    }

    /** Defensively resets any negative word-topic or document-topic count to zero. */
    private void clampNegativeCounts() {
        int wordCells = this.numWords * this.numTopics;
        for (int idx = 0; idx < wordCells; idx++) {
            if (this.wordtopiccounts[idx] < 0) {
                this.wordtopiccounts[idx] = 0;
            }
        }
        int docCells = this.numDocs * this.numTopics;
        for (int idx = 0; idx < docCells; idx++) {
            if (this.doctopiccounts[idx] < 0) {
                this.doctopiccounts[idx] = 0;
            }
        }
    }

    /**
     * Refits the Beta time distribution of {@code topic} by the method of moments.
     *
     * <p>Degenerate sample sets are first spread with small Gaussian jitter so that the variance is
     * non-zero: a single observation is padded, and near-constant samples are replaced. Topics with
     * no samples, or whose moments give non-positive or infinite shape parameters, fall back to the
     * uniform Beta(1, 1).
     */
    private void refitTopicTimeBeta(int topic, Mean mean, Variance variance) {
        if (mean.getN() == 1) {
            double center = mean.getResult();
            for (int j = 0; j < SINGLE_SAMPLE_PADDING; j++) {
                double sample = jitterIntoUnitInterval(center);
                mean.increment(sample);
                variance.increment(sample);
            }
        } else if (variance.getResult() <= MIN_VARIANCE) {
            double center = mean.getResult();
            long count = mean.getN();
            mean.clear();
            variance.clear();
            for (long j = 0; j < count; j++) {
                double sample = jitterIntoUnitInterval(center);
                mean.increment(sample);
                variance.increment(sample);
            }
        }
        if (mean.getN() < 1) {
            resetToUniformBeta(topic);
            return;
        }
        double m = mean.getResult();
        double v = variance.getResult();
        this.meanTimes[topic] = m;
        this.varianceTimes[topic] = v;
        double common = ((m * (1.0d - m)) / v) - 1.0d;
        double alpha = m * common;
        double beta = (1.0d - m) * common;
        // The (bias-corrected) sample variance can reach or exceed m(1-m), which makes the shapes
        // non-positive. BetaDistribution rejects such shapes, so fall back to uniform instead.
        if (alpha > 0.0d && beta > 0.0d
                && alpha < Double.POSITIVE_INFINITY && beta < Double.POSITIVE_INFINITY) {
            this.pi[topic][0] = alpha;
            this.pi[topic][1] = beta;
            this.pBeta[topic] = new BetaDistribution(alpha, beta);
        } else {
            resetToUniformBeta(topic);
        }
    }

    /**
     * Returns {@code center + JITTER_STD * N(0, 1)}, redrawing until the value lies strictly inside
     * (0, 1). After {@code MAX_JITTER_ATTEMPTS} failures it returns {@code center} forced into the
     * open interval.
     */
    private double jitterIntoUnitInterval(double center) {
        for (int attempt = 0; attempt < MAX_JITTER_ATTEMPTS; attempt++) {
            double candidate = (this.rn.nextGaussian() * JITTER_STD) + center;
            if (candidate > 0.0d && candidate < 1.0d) {
                return candidate;
            }
        }
        return toOpenUnitInterval(center);
    }

    /** Maps values at or beyond 0/1 to {@code UNIT_EPSILON} / {@code 1 - UNIT_EPSILON}; interior values are unchanged. */
    private static double toOpenUnitInterval(double x) {
        if (x <= 0.0d) {
            return UNIT_EPSILON;
        }
        if (x >= 1.0d) {
            return 1.0d - UNIT_EPSILON;
        }
        return x;
    }

    /** Sets {@code topic} back to the uniform Beta(1, 1), keeping {@code pi} and {@code pBeta} in sync. */
    private void resetToUniformBeta(int topic) {
        this.pi[topic][0] = 1.0d;
        this.pi[topic][1] = 1.0d;
        this.pBeta[topic] = new BetaDistribution(1.0d, 1.0d);
    }

    /**
     * Returns, per topic, the times of the tokens assigned to it during the most recent sweep. Each
     * time is repeated once per token. Order within a list follows hash-map iteration order.
     * Before {@link #GibbsSampling()} has run, every list is empty.
     */
    public TDoubleArrayList[] getAssignedTimes() {
        TDoubleArrayList[] result = new TDoubleArrayList[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            TDoubleArrayList assigned = new TDoubleArrayList();
            TDoubleIntHashMap counts = this.hsValues == null ? null : this.hsValues[topic];
            if (counts != null) {
                for (double time : counts.keys()) {
                    int occurrences = counts.get(time);
                    for (int j = 0; j < occurrences; j++) {
                        assigned.add(time);
                    }
                }
            }
            result[topic] = assigned;
        }
        return result;
    }

    /**
     * Returns the per-topic Beta shape parameters as {@code {alpha, beta}} rows. The internal array
     * is returned directly, so callers must not modify it.
     */
    public double[][] getPi() {
        return this.pi;
    }

    /** Intentionally empty entry point; kept for compatibility. */
    public static void main(String[] strArr) {
    }
}
