package com.rapidminer.kobra.topicmodels;

import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/SamplersLDAWordRegularize.class */
/**
 * Gibbs sampler for LDA whose topic-word probabilities are periodically re-estimated from the
 * current word-topic counts and a word-by-word weight matrix {@link #S}.
 *
 * <p>{@link #S} is filled from the inherited {@code Phi} lists (used here as column indices) and
 * from {@link #graphWeights} (used as the matching entry values). Both must be populated by the
 * caller before {@link #GibbsSampling()} runs; this class never assigns {@link #graphWeights}.
 *
 * <p>All flat count arrays in this class use the layout {@code word * numTopics + topic} (or
 * {@code doc * numTopics + topic}).
 */
public class SamplersLDAWordRegularize extends SamplersLDAWordFeatures {
    /** Per-word entry values for {@link #S}; aligned index-by-index with the inherited {@code Phi[word]}. */
    TDoubleArrayList[] graphWeights;

    /** Weight of the graph term in the regularisation update (enters as {@code 2 * nu}). */
    public double nu = 1.0d;

    /** Maximum number of fixed-point steps per topic in one regularisation pass. */
    int reg_iter = 100;

    /** Word-by-word weight matrix, rebuilt at the start of every {@link #GibbsSampling()} call. */
    Array2DRowRealMatrix S = new Array2DRowRealMatrix();

    /** Accumulated document-topic estimates, indexed {@code [topic][doc]}; allocated lazily. */
    double[][] phi = null;

    /** Accumulated topic-word estimates, indexed {@code [topic][word]}; allocated lazily. */
    double[][] theta = null;

    /** Number of snapshots accumulated into {@link #phi} and {@link #theta}. */
    int numStats = 0;

    /**
     * Stores the corpus and hyper-parameters, allocates all count arrays, draws the per-word/topic
     * parameters and assigns every token a uniformly random initial topic.
     *
     * @param iArr  document index of every token
     * @param iArr2 word index of every token; its length defines the number of tokens
     * @param i     number of topics
     * @param i2    number of distinct words
     * @param i3    number of documents
     * @param i4    number of Gibbs iterations
     * @param d     value stored in {@code BETA}
     * @param d2    value stored in {@code ALPHA}
     */
    public void init(int[] iArr, int[] iArr2, int i, int i2, int i3, int i4, double d, double d2) {
        this.maxIter = i4;
        this.BETA = d;
        this.ALPHA = d2;
        this.numTokens = iArr2.length;
        this.numTopics = i;
        this.numDocs = i3;
        this.numWords = i2;
        this.words = iArr2;
        this.docs = iArr;

        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[i2 * i];
        this.doctopiccounts = new int[i3 * i];
        this.topiccounts = new int[i];
        this.b = new double[i * i2];
        this.paras = new double[i * i2];
        this.parameters = new double[i][i2];

        // Default the per-word base weight to BETA unless one was supplied beforehand.
        if (this.p_v == null) {
            this.p_v = new double[i2];
            for (int word = 0; word < i2; word++) {
                this.p_v[word] = this.BETA;
            }
        }

        for (int word = 0; word < i2; word++) {
            for (int topic = 0; topic < i; topic++) {
                int flat = (word * i) + topic;
                this.paras[flat] = ((2.0d * Math.random()) * this.LAMBDA) - this.LAMBDA;
                this.parameters[topic][word] = ((2.0d * Math.random()) * this.LAMBDA) - this.LAMBDA;
                // p_v holds one entry per word, so it is indexed by the word, not by the topic.
                this.b[flat] = Math.exp(this.parameters[topic][word]) * this.p_v[word];
            }
        }

        // One generator for the whole pass instead of a fresh instance per token.
        Random rng = new Random();
        for (int token = 0; token < this.numTokens; token++) {
            int word = this.words[token];
            int doc = this.docs[token];
            int topic = rng.nextInt(i);
            this.topics[token] = topic;
            this.wordtopiccounts[(word * i) + topic]++;
            this.doctopiccounts[(doc * i) + topic]++;
            this.topiccounts[topic]++;
        }
    }

    /**
     * Returns the topic of every token as recorded by the last {@link #GibbsSampling()} run.
     *
     * @return the internal array (not a copy), indexed by token
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDA
    public int[] getTokenToTopic() {
        return this.tokenToTopic;
    }

    /** Intentionally empty: this sampler performs no parameter training. */
    public void train() {
    }

    /**
     * Stores the flag in the inherited {@code reg} field.
     *
     * @param z new value of {@code reg}
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures
    public void setReg(boolean z) {
        this.reg = z;
    }

    /**
     * Adds up all elements of an array.
     *
     * @param dArr values to add
     * @return the sum, or {@code 0.0} for an empty array
     */
    public double sum(double[] dArr) {
        double total = 0.0d;
        for (double value : dArr) {
            total += value;
        }
        return total;
    }

    /**
     * Runs {@code maxIter} sweeps over all tokens in a fixed pseudo-random order.
     *
     * <p>Sampling weights are {@code topicWord[word][topic] * (docTopicCount + ALPHA)}. The
     * topic-word matrix starts as a constant {@code BETA}; on every tenth iteration strictly between
     * {@code maxIter / 5} and 90% of {@code maxIter} it is replaced by the regularised estimate.
     * From 90% onwards each iteration has a 50% chance of accumulating a distribution snapshot via
     * {@link #updateDistributions()}.
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public void GibbsSampling() {
        int regularizationEnd = (int) (this.maxIter * 0.9d);
        this.WBETA = this.numWords * this.BETA;

        // Per-topic normaliser: sum of b over all words.
        this.WBETAs = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            double total = 0.0d;
            for (int word = 0; word < this.numWords; word++) {
                total += this.b[(word * this.numTopics) + topic];
            }
            this.WBETAs[topic] = total;
        }

        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];

        // Visit tokens in a shuffled order; the fixed seed makes the order reproducible.
        TIntArrayList order = new TIntArrayList(this.numTokens);
        for (int token = 0; token < this.numTokens; token++) {
            order.add(token);
        }
        order.shuffle(new Random(2000L));

        this.S = buildGraphMatrix();

        // Until the first regularisation pass every (word, topic) entry is simply BETA.
        RealMatrix topicWord = new Array2DRowRealMatrix(this.numWords, this.numTopics).scalarAdd(this.BETA);

        for (int iter = 0; iter < this.maxIter; iter++) {
            if (iter > this.maxIter / 5 && iter % 10 == 0 && iter < regularizationEnd) {
                topicWord = regularizedTopicWordMatrix();
            }

            for (int position = 0; position < this.numTokens; position++) {
                int token = order.get(position);
                int word = this.words[token];
                int doc = this.docs[token];
                int oldTopic = this.topics[token];
                int wordBase = word * this.numTopics;
                int docBase = doc * this.numTopics;

                // Take the token out of the counts before resampling it.
                this.topiccounts[oldTopic]--;
                this.wordtopiccounts[wordBase + oldTopic]--;
                this.doctopiccounts[docBase + oldTopic]--;

                double total = 0.0d;
                for (int topic = 0; topic < this.numTopics; topic++) {
                    this.probs[topic] = topicWord.getEntry(word, topic) * (this.doctopiccounts[docBase + topic] + this.ALPHA);
                    total += this.probs[topic];
                }

                int newTopic = sampleTopic(total);
                this.topics[token] = newTopic;
                this.wordtopiccounts[wordBase + newTopic]++;
                this.doctopiccounts[docBase + newTopic]++;
                this.topiccounts[newTopic]++;
                // Indexed by the token itself, not by its position in the shuffled visiting order.
                this.tokenToTopic[token] = newTopic;
            }

            clampNegativeCounts(this.wordtopiccounts, this.numWords * this.numTopics);
            clampNegativeCounts(this.doctopiccounts, this.numDocs * this.numTopics);

            if (iter >= regularizationEnd && Math.random() < 0.5d) {
                updateDistributions();
            }
        }
    }

    /** Builds the numWords x numWords matrix with entry (word, Phi[word][k]) = graphWeights[word][k]. */
    private Array2DRowRealMatrix buildGraphMatrix() {
        Array2DRowRealMatrix graph = new Array2DRowRealMatrix(this.numWords, this.numWords);
        for (int word = 0; word < this.numWords; word++) {
            for (int k = 0; k < this.Phi[word].size(); k++) {
                graph.setEntry(word, this.Phi[word].get(k), this.graphWeights[word].get(k));
            }
        }
        return graph;
    }

    /** Re-estimates every topic column from the current counts and S; each column is L1-normalised. */
    private RealMatrix regularizedTopicWordMatrix() {
        Array2DRowRealMatrix counts = new Array2DRowRealMatrix(this.numWords, this.numTopics);
        for (int topic = 0; topic < this.numTopics; topic++) {
            for (int word = 0; word < this.numWords; word++) {
                counts.setEntry(word, topic, this.wordtopiccounts[(word * this.numTopics) + topic]);
            }
        }

        RealMatrix result = new Array2DRowRealMatrix(this.numWords, this.numTopics);
        for (int topic = 0; topic < this.numTopics; topic++) {
            // Start from the smoothed, L1-normalised count column. getColumnVector returns a copy,
            // so every in-place vector operation has to be written back with setColumnVector.
            Array2DRowRealMatrix column = new Array2DRowRealMatrix(this.numWords, 1);
            column.setColumnVector(0, counts.getColumnVector(topic).mapAddToSelf(0.001d));
            column.setColumnVector(0, column.getColumnVector(0).mapDivideToSelf(column.getColumnVector(0).getL1Norm()));

            for (int step = 0; step < this.reg_iter; step++) {
                RealMatrix graphProduct = this.S.multiply((RealMatrix) column);
                double denominator = column.transpose().multiply(graphProduct).getEntry(0, 0);
                if (denominator == 0.0d) {
                    // The graph contributes nothing for this column; dividing would only produce
                    // NaN, so keep the current normalised estimate.
                    break;
                }
                column.setColumnVector(0, counts.getColumnVector(topic).add(column.getColumnVector(0).ebeMultiply(graphProduct.getColumnVector(0)).mapMultiplyToSelf(2.0d * this.nu).mapDivideToSelf(denominator)));
                column.setColumnVector(0, column.getColumnVector(0).mapDivideToSelf(column.getColumnVector(0).getL1Norm()));
            }

            result.setColumnVector(topic, column.getColumnVector(0));
            result.setColumnVector(topic, result.getColumnVector(topic).mapDivideToSelf(result.getColumnVector(topic).getL1Norm()));
        }
        return result;
    }

    /** Draws a topic index proportionally to the weights currently held in probs, which sum to total. */
    private int sampleTopic(double total) {
        double threshold = total * Math.random();
        double cumulative = this.probs[0];
        int topic = 0;
        // The upper bound keeps rounding error in the running sum from stepping past the last topic.
        while (threshold > cumulative && topic < this.numTopics - 1) {
            topic++;
            cumulative += this.probs[topic];
        }
        return topic;
    }

    /** Resets negative entries among the first length elements of counts to zero. */
    private void clampNegativeCounts(int[] counts, int length) {
        for (int idx = 0; idx < length; idx++) {
            if (counts[idx] < 0) {
                counts[idx] = 0;
            }
        }
    }

    /**
     * Copies {@code paras} into a {@code [topic][word]} matrix.
     *
     * <p>NOTE(review): this reads {@code paras} as {@code topic * numWords + word}, whereas
     * {@link #init} writes it as {@code word * numTopics + topic}. Confirm which layout the parent
     * class expects before relying on the mapping of individual entries.
     *
     * @return a new {@code numTopics x numWords} matrix
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures
    public double[][] getBetas() {
        double[][] betas = new double[this.numTopics][this.numWords];
        for (int topic = 0; topic < this.numTopics; topic++) {
            for (int word = 0; word < this.numWords; word++) {
                betas[topic][word] = this.paras[(topic * this.numWords) + word];
            }
        }
        return betas;
    }

    /**
     * Finds, for every word, the topic with the strictly largest count.
     *
     * @return one topic per word; {@code 0} when all counts of the word are zero, and the lowest
     *         topic index on ties
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public int[] assignedTopicsToWords() {
        int[] assigned = new int[this.numWords];
        for (int word = 0; word < this.numWords; word++) {
            double best = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                int count = this.wordtopiccounts[(word * this.numTopics) + topic];
                if (best < count) {
                    best = count;
                    assigned[word] = topic;
                }
            }
        }
        return assigned;
    }

    /**
     * Returns, for every word, the largest count it has in any topic.
     *
     * <p>Despite the method name the value is the raw count and is not normalised.
     *
     * @return one value per word; {@code 0.0} when all counts of the word are zero
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public double[] assignedTopicsToWordsProbs() {
        double[] best = new double[this.numWords];
        for (int word = 0; word < this.numWords; word++) {
            double max = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                int count = this.wordtopiccounts[(word * this.numTopics) + topic];
                if (max < count) {
                    max = count;
                    best[word] = count;
                }
            }
        }
        return best;
    }

    /**
     * Finds, for every document, the topic with the strictly largest count.
     *
     * @return one topic per document; {@code 0} when all counts of the document are zero, and the
     *         lowest topic index on ties
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public int[] assignedTopicsToDocs() {
        int[] assigned = new int[this.numDocs];
        for (int doc = 0; doc < this.numDocs; doc++) {
            double best = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                int count = this.doctopiccounts[(doc * this.numTopics) + topic];
                if (best < count) {
                    best = count;
                    assigned[doc] = topic;
                }
            }
        }
        return assigned;
    }

    /** Accumulates one more snapshot of both distributions and increments the snapshot counter. */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public void updateDistributions() {
        this.numStats++;
        updateWordDistribution();
        updateDocumentDistribution();
    }

    /**
     * Adds the current topic-word estimate {@code (count + b) / (topicCount + WBETAs)} to
     * {@link #theta}, allocating it on first use.
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public void updateWordDistribution() {
        if (this.theta == null) {
            this.theta = new double[this.numTopics][this.numWords];
        }
        for (int word = 0; word < this.numWords; word++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                int flat = (word * this.numTopics) + topic;
                this.theta[topic][word] += (this.wordtopiccounts[flat] + this.b[flat]) / (this.topiccounts[topic] + this.WBETAs[topic]);
            }
        }
    }

    /**
     * Adds the current document-topic estimate {@code (count + ALPHA) / (docTotal + numTopics * ALPHA)}
     * to {@link #phi}, allocating it on first use.
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public void updateDocumentDistribution() {
        if (this.phi == null) {
            this.phi = new double[this.numTopics][this.numDocs];
        }
        for (int doc = 0; doc < this.numDocs; doc++) {
            int docBase = doc * this.numTopics;
            int docTotal = 0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                docTotal += this.doctopiccounts[docBase + topic];
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.phi[topic][doc] += (this.doctopiccounts[docBase + topic] + this.ALPHA) / (docTotal + (this.numTopics * this.ALPHA));
            }
        }
    }

    /**
     * Returns the topic-word distribution: the mean of the accumulated snapshots when any exist,
     * otherwise the estimate computed from the current counts.
     *
     * @return a new {@code numTopics x numWords} matrix
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public double[][] wordDistribution() {
        double[][] distribution = new double[this.numTopics][this.numWords];
        boolean averaged = this.numStats > 0;
        for (int word = 0; word < this.numWords; word++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (averaged) {
                    distribution[topic][word] = this.theta[topic][word] / this.numStats;
                } else {
                    int flat = (word * this.numTopics) + topic;
                    distribution[topic][word] = (this.wordtopiccounts[flat] + this.b[flat]) / (this.topiccounts[topic] + this.WBETAs[topic]);
                }
            }
        }
        return distribution;
    }

    /**
     * Returns the document-topic distribution: the mean of the accumulated snapshots when any
     * exist, otherwise the estimate computed from the current counts.
     *
     * @return a new {@code numTopics x numDocs} matrix
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersLDAWordFeatures, com.rapidminer.kobra.topicmodels.SamplersLDA
    public double[][] documentDistribution() {
        double[][] distribution = new double[this.numTopics][this.numDocs];
        if (this.numStats > 0) {
            for (int doc = 0; doc < this.numDocs; doc++) {
                for (int topic = 0; topic < this.numTopics; topic++) {
                    distribution[topic][doc] = this.phi[topic][doc] / this.numStats;
                }
            }
            return distribution;
        }
        for (int doc = 0; doc < this.numDocs; doc++) {
            int docBase = doc * this.numTopics;
            int docTotal = 0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                docTotal += this.doctopiccounts[docBase + topic];
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                distribution[topic][doc] = (this.doctopiccounts[docBase + topic] + this.ALPHA) / (docTotal + (this.numTopics * this.ALPHA));
            }
        }
        return distribution;
    }

    /** Empty entry point; does nothing. */
    public static void main(String[] strArr) {
    }
}
