package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import org.apache.commons.math3.util.FastMath;
import org.apache.lucene.analysis.ar.ArabicNormalizer;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/SamplersSLDA.class */
/**
 * Collapsed Gibbs sampler for a supervised (labelled) LDA variant.
 *
 * <p>Documents with index {@code < start_test} are training documents. While their tokens are
 * resampled, each candidate topic is additionally weighted by a Gaussian likelihood (standard
 * deviation 0.1) of the document label given regression weights {@code eta} applied to the
 * document's topic counts. {@code eta} is re-fitted periodically with
 * {@link MyOrthantWiseLimitedMemoryBFGS} on a {@link MyEtaOptimizable}.
 *
 * <p>Documents with index {@code >= start_test} are test documents. They are resampled afterwards
 * with the plain LDA conditional. {@link #getPredictions()} then exposes, per document, the sign
 * (+1 / -1) of {@code eta . z_bar[doc]}.
 */
public class SamplersSLDA extends SamplersLDA {
    /** Standard deviation of the Gaussian label likelihood used while sampling training docs. */
    private static final double RESPONSE_SIGMA = 0.1d;

    /**
     * Default value of {@link #start_test}. The decompiled source expressed this as Lucene's
     * {@code ArabicNormalizer.TATWEEL} (U+0640), which is numerically 1600 and otherwise unrelated.
     */
    private static final int DEFAULT_START_TEST = 1600;

    /** Fraction of {@code maxIter} after which eta stops being re-fitted and distributions are updated. */
    private static final double BURN_IN_FRACTION = 0.9d;

    /** Per-document topic proportions of the current sweep; each sampled token adds 1/doc length. */
    double[][] z_bar;
    /** Length of each document, used to normalise {@link #z_bar}. */
    int[] doc_lengths;
    /** Label of every document (presumably +1/-1, since predictions are thresholded at 0 — confirm). */
    double[] labels;
    /** Copy of the first {@link #start_test} labels, passed to the eta optimiser. */
    double[] labels_train;
    /** +1 / -1 prediction per document, filled at the end of {@link #GibbsSampling()}. */
    double[] predictions;
    /** Index of the first test document; documents {@code [0, start_test)} are training documents. */
    int start_test = DEFAULT_START_TEST;
    /** Random source for initialisation, shuffling and all topic draws. */
    public Random rn = null;

    /**
     * Stores the corpus and hyper-parameters and assigns every token a uniformly random topic.
     *
     * @param iArr document index of each token (stored as {@code docs})
     * @param iArr2 word id of each token (stored as {@code words}); its length is the token count
     * @param i number of topics
     * @param i2 vocabulary size
     * @param i3 number of documents
     * @param i4 number of Gibbs iterations ({@code maxIter})
     * @param d topic-word smoothing {@code BETA}
     * @param d2 document-topic smoothing {@code ALPHA}
     * @param dArr label of each document; the first {@code start_test} entries are used for training
     * @param iArr3 length of each document
     * @param z if {@code true}, seed the random generator with {@code i5} for reproducible runs
     * @param i5 random seed, only used when {@code z} is {@code true}
     * @throws IllegalArgumentException if {@code dArr} has fewer than {@code start_test} entries
     */
    public void init(int[] iArr, int[] iArr2, int i, int i2, int i3, int i4, double d, double d2, double[] dArr, int[] iArr3, boolean z, int i5) {
        if (dArr.length < this.start_test) {
            throw new IllegalArgumentException(
                    "Expected at least " + this.start_test + " labels but got " + dArr.length);
        }
        this.labels = dArr;
        this.labels_train = new double[this.start_test];
        System.arraycopy(dArr, 0, this.labels_train, 0, this.start_test);
        this.doc_lengths = iArr3;
        this.maxIter = i4;
        this.BETA = d;
        this.ALPHA = d2;
        this.numTokens = iArr2.length;
        this.numTopics = i;
        this.numDocs = i3;
        this.numWords = i2;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[i2 * i];
        this.doctopiccounts = new int[i3 * i];
        this.topiccounts = new int[i];
        this.words = iArr2;
        this.docs = iArr;
        if (z) {
            this.seed = i5;
            this.rn = new Random(i5);
        } else {
            this.rn = new Random();
        }
        for (int token = 0; token < iArr2.length; token++) {
            int topic = this.rn.nextInt(i);
            assignTokenTopic(token, topic, this.words[token] * i, this.docs[token] * i);
        }
    }

    /**
     * Runs the sampler and fills {@link #predictions}.
     *
     * <ol>
     *   <li>{@code maxIter} sweeps over the training documents using the label-weighted
     *       conditional. During the first 90% of sweeps, eta is re-fitted every 10th sweep. After
     *       that, {@code updateDistributions()} is called on every even sweep.</li>
     *   <li>At least one (and {@code maxIter / 20} when larger) plain-LDA sweep over the test
     *       documents.</li>
     *   <li>Sign predictions from the final eta and the {@code z_bar} of the last test sweep.</li>
     * </ol>
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDA
    public void GibbsSampling() {
        int burnIn = (int) (this.maxIter * BURN_IN_FRACTION);
        this.WBETA = this.numWords * this.BETA;
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];

        // Tokens are visited in one fixed random order for all sweeps.
        TIntArrayList order = new TIntArrayList(this.numTokens);
        for (int token = 0; token < this.numTokens; token++) {
            order.add(token);
        }
        order.shuffle(this.rn);

        double[] eta = new double[this.numTopics];
        for (int k = 0; k < this.numTopics; k++) {
            eta[k] = 1.0d / this.numTopics;
        }
        // log(sigma * sqrt(2 * pi)): normaliser of the Gaussian label density.
        double logNormalizer = FastMath.log(RESPONSE_SIGMA) + 0.5d * FastMath.log(2.0d * Math.PI);

        for (int iter = 0; iter < this.maxIter; iter++) {
            this.z_bar = new double[this.numDocs][this.numTopics];
            for (int pos = 0; pos < this.numTokens; pos++) {
                int token = order.get(pos);
                // Strictly below start_test: labels_train only covers docs [0, start_test).
                if (this.docs[token] < this.start_test) {
                    resampleTrainingToken(pos, token, eta, logNormalizer);
                }
            }
            clampNegativeCounts();
            if (iter >= burnIn && iter % 2 == 0) {
                updateDistributions();
            }
            if (iter < burnIn && iter % 10 == 0) {
                eta = fitEta();
            }
        }

        // At least one sweep, otherwise test documents would keep an all-zero z_bar
        // (or z_bar would still be null when maxIter == 0).
        int testSweeps = Math.max(1, this.maxIter / 20);
        for (int sweep = 0; sweep < testSweeps; sweep++) {
            this.z_bar = new double[this.numDocs][this.numTopics];
            for (int pos = 0; pos < this.numTokens; pos++) {
                int token = order.get(pos);
                if (this.docs[token] >= this.start_test) {
                    resampleTestToken(pos, token);
                }
            }
        }

        computeSignPredictions(eta);
    }

    /** Resamples one training token with the LDA conditional times the Gaussian label likelihood. */
    private void resampleTrainingToken(int pos, int token, double[] eta, double logNormalizer) {
        int doc = this.docs[token];
        int wordBase = this.words[token] * this.numTopics;
        int docBase = doc * this.numTopics;
        unassignTokenTopic(token, wordBase, docBase);

        // eta . counts of the document without this token, and how many tokens remain in it.
        double dot = 0.0d;
        int remaining = 0;
        for (int k = 0; k < this.numTopics; k++) {
            int count = this.doctopiccounts[docBase + k];
            dot += count * eta[k];
            remaining += count;
        }

        double total = 0.0d;
        for (int k = 0; k < this.numTopics; k++) {
            double residual = (this.labels[doc] - (dot + eta[k]) / (remaining + 1)) / RESPONSE_SIGMA;
            this.probs[k] = ldaWeight(wordBase, docBase, k)
                    * FastMath.exp(-0.5d * residual * residual - logNormalizer);
            total += this.probs[k];
        }

        int topic = drawTopicFromProbs(total * this.rn.nextDouble());
        assignTokenTopic(token, topic, wordBase, docBase);
        this.z_bar[doc][topic] += 1.0d / this.doc_lengths[doc];
        // NOTE(review): indexed by shuffled position, not token id; confirm what readers of
        // tokenToTopic expect.
        this.tokenToTopic[pos] = topic;
    }

    /** Resamples one test token with the plain LDA conditional, using the seeded generator. */
    private void resampleTestToken(int pos, int token) {
        int doc = this.docs[token];
        int wordBase = this.words[token] * this.numTopics;
        int docBase = doc * this.numTopics;
        unassignTokenTopic(token, wordBase, docBase);

        double total = 0.0d;
        for (int k = 0; k < this.numTopics; k++) {
            this.probs[k] = ldaWeight(wordBase, docBase, k);
            total += this.probs[k];
        }

        // this.rn (not Math.random()) keeps seeded runs reproducible.
        int topic = drawTopicFromProbs(total * this.rn.nextDouble());
        assignTokenTopic(token, topic, wordBase, docBase);
        // NOTE(review): indexed by shuffled position, not token id; see resampleTrainingToken.
        this.tokenToTopic[pos] = topic;
        this.z_bar[doc][topic] += 1.0d / this.doc_lengths[doc];
    }

    /** Unnormalised collapsed-LDA weight of topic {@code k} for the current word/document. */
    private double ldaWeight(int wordBase, int docBase, int k) {
        return (this.wordtopiccounts[wordBase + k] + this.BETA) / (this.topiccounts[k] + this.WBETA)
                * (this.doctopiccounts[docBase + k] + this.ALPHA);
    }

    /**
     * Returns the first topic whose cumulative weight in {@code probs} reaches {@code u}.
     * The index is capped at the last topic so rounding can never run past the array.
     */
    private int drawTopicFromProbs(double u) {
        int topic = 0;
        double cumulative = this.probs[0];
        while (u > cumulative && topic < this.numTopics - 1) {
            topic++;
            cumulative += this.probs[topic];
        }
        return topic;
    }

    /** Removes the current topic assignment of {@code token} from all count tables. */
    private void unassignTokenTopic(int token, int wordBase, int docBase) {
        int old = this.topics[token];
        this.topiccounts[old]--;
        this.wordtopiccounts[wordBase + old]--;
        this.doctopiccounts[docBase + old]--;
    }

    /** Assigns {@code topic} to {@code token} and adds it to all count tables. */
    private void assignTokenTopic(int token, int topic, int wordBase, int docBase) {
        this.topics[token] = topic;
        this.wordtopiccounts[wordBase + topic]++;
        this.doctopiccounts[docBase + topic]++;
        this.topiccounts[topic]++;
    }

    /** Defensive reset of any negative word-topic or doc-topic count to zero. */
    private void clampNegativeCounts() {
        for (int j = 0; j < this.numWords * this.numTopics; j++) {
            if (this.wordtopiccounts[j] < 0) {
                this.wordtopiccounts[j] = 0;
            }
        }
        for (int j = 0; j < this.numDocs * this.numTopics; j++) {
            if (this.doctopiccounts[j] < 0) {
                this.doctopiccounts[j] = 0;
            }
        }
    }

    /** Re-fits eta on the current {@code z_bar} and the training labels. */
    private double[] fitEta() {
        MyEtaOptimizable optimizable = new MyEtaOptimizable(this.labels_train, this.numTopics, this.z_bar);
        try {
            new MyOrthantWiseLimitedMemoryBFGS(optimizable, 0.0d).optimize(1000);
        } catch (OptimizationException e) {
            // Best effort: the parameters are still read from the optimizable after a failure.
            e.printStackTrace();
        }
        double[] eta = new double[this.numTopics];
        optimizable.getParameters(eta);
        return eta;
    }

    /**
     * Fills {@link #predictions} with +1 when {@code eta . z_bar[doc] > 0} and -1 otherwise.
     * Only test documents contribute to {@code z_bar} in the last sweep, so training documents
     * always get -1 here.
     */
    private void computeSignPredictions(double[] eta) {
        this.predictions = new double[this.labels.length];
        int rows = Math.min(this.z_bar.length, this.predictions.length);
        for (int doc = 0; doc < rows; doc++) {
            double score = 0.0d;
            for (int k = 0; k < this.z_bar[doc].length; k++) {
                score += eta[k] * this.z_bar[doc][k];
            }
            this.predictions[doc] = score > 0.0d ? 1.0d : -1.0d;
        }
    }

    /**
     * Returns the per-document predictions. If sampling has not run, this is an all-zero array
     * of length {@code numDocs}. The internal array is returned, not a copy.
     */
    public double[] getPredictions() {
        if (this.predictions == null) {
            this.predictions = new double[this.numDocs];
        }
        return this.predictions;
    }

    public static void main(String[] strArr) {
    }
}
