package com.rapidminer.kobra.topicmodels;

import java.io.PrintWriter;
import java.util.Iterator;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/LDA.class */
/**
 * Latent Dirichlet Allocation fitted by variational EM over a {@link Corpus}.
 *
 * <p>Each EM iteration runs a per-document variational E-step ({@link #docEstep(int)}),
 * which accumulates expected topic/word counts, followed by an M-step
 * ({@link #ldaMLE(boolean)}) that turns those counts into per-topic log word
 * probabilities and optionally re-estimates the symmetric Dirichlet parameter
 * {@code alpha} by Newton iteration.
 *
 * <p>Progress is reported on {@code System.out}. The class keeps mutable scratch
 * state ({@code phi}, the sufficient statistics) and makes no attempt at thread-safety.
 */
public class LDA {
    /** Number of documents folded into the sufficient statistics since the last {@link #zeroInitialize()}. */
    int sNumDocs;
    /** Number of topics. */
    int numTopics;
    /** Vocabulary size as reported by {@code corpus.getNumTerms()}. */
    int numTerms;
    /** Corpus the model is fitted on. */
    Corpus corpus;
    /** Current value of the symmetric Dirichlet parameter. */
    double alpha;
    /** {@code log_prob_w[topic][term]}: log probability of a term under a topic. */
    double[][] log_prob_w;
    /** {@code var_gamma[doc][topic]}: variational Dirichlet parameters per document. */
    double[][] var_gamma;
    /**
     * {@code phi[position][topic]}: variational topic assignment for the word at a given position in the
     * iteration order of {@code corpus.getWordIdsInDoc(doc)}. Scratch space overwritten for every document.
     */
    double[][] phi;
    /** {@code class_word[topic][term]}: expected counts accumulated during the E-step. */
    double[][] class_word;
    /** {@code class_total[topic]}: sum of {@code class_word[topic]} over all terms. */
    double[] class_total;
    /** Sufficient statistic for alpha, accumulated by {@link #docEstep(int)}. */
    double sAlpha = 0.0d;
    /** EM stops once the iteration counter exceeds this value. */
    int EM_MAX_ITER = 1000;
    /** Maximum variational iterations per document; {@code -1} means no limit. */
    int VAR_MAX_ITER = 20;
    /** Relative likelihood change below which EM is considered converged. */
    double EM_CONVERGED = 1.0E-4d;
    /** Whether the M-step re-estimates {@link #alpha}. */
    boolean ESTIMATE_ALPHA = true;
    /** Newton iteration for alpha stops when the absolute gradient drops to this value. */
    private final double NEWTON_THRESH = 1.0E-5d;
    /** Maximum number of Newton iterations for alpha. */
    private final int MAX_ALPHA_ITER = 1000;
    /** Relative likelihood change below which per-document inference is considered converged. */
    double VAR_CONVERGED = 1.0E-6d;

    /**
     * Fits a 4-topic model on the corpus files {@code ap.dat} / {@code vocab.txt} and writes the result
     * via {@link #storeBeta()}.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        LDA lda = new LDA(4, new Corpus("ap.dat", "vocab.txt"), 0.5d);
        lda.runEM();
        lda.storeBeta();
    }

    /**
     * Allocates all model arrays, seeds the expected counts randomly and derives the initial
     * {@link #log_prob_w} from them.
     *
     * @param i      number of topics
     * @param corpus corpus to fit
     * @param d      initial value of alpha
     */
    public LDA(int i, Corpus corpus, double d) {
        System.out.println("topics\t\t: " + i);
        System.out.println("initial alpha\t: " + d);
        this.corpus = corpus;
        this.numTopics = i;
        this.alpha = d;
        this.sNumDocs = 0;

        // Java zero-fills freshly allocated arrays, so no explicit clearing is needed.
        int numDocs = corpus.getNumDocs();
        this.var_gamma = new double[numDocs][this.numTopics];
        int maxCorpusLength = this.corpus.maxCorpusLength();
        this.phi = new double[maxCorpusLength][this.numTopics];
        this.numTerms = this.corpus.getNumTerms();
        this.log_prob_w = new double[this.numTopics][this.numTerms];
        this.class_total = new double[this.numTopics];
        this.class_word = new double[this.numTopics][this.numTerms];

        // Random seeding: every count starts at 1/numTerms plus a uniform random offset.
        for (int topic = 0; topic < this.numTopics; topic++) {
            double[] counts = this.class_word[topic];
            for (int term = 0; term < this.numTerms; term++) {
                counts[term] = (1.0d / this.numTerms) + ldaRand();
                this.class_total[topic] += counts[term];
            }
        }
        // Private helper rather than the public ldaMLE(false): a constructor must not call overridable methods.
        updateLogProbW();
    }

    /** Resets the sufficient statistics (expected counts, document counter and alpha statistic) to zero. */
    void zeroInitialize() {
        for (int topic = 0; topic < this.numTopics; topic++) {
            this.class_total[topic] = 0.0d;
            // Clear the existing row in place instead of allocating a new one on every EM iteration.
            double[] counts = this.class_word[topic];
            for (int term = 0; term < this.numTerms; term++) {
                counts[term] = 0.0d;
            }
        }
        this.sNumDocs = 0;
        this.sAlpha = 0.0d;
    }

    /**
     * Runs EM until the relative likelihood change is within {@code [0, EM_CONVERGED]} after more than two
     * iterations, or until the iteration counter exceeds {@link #EM_MAX_ITER}.
     */
    public void runEM() {
        int iteration = 0;
        double previousLikelihood = 0.0d;
        double delta = 1.0d;
        while (iteration <= this.EM_MAX_ITER) {
            boolean converged = delta >= 0.0d && delta <= this.EM_CONVERGED && iteration > 2;
            if (converged) {
                return;
            }
            iteration++;
            System.out.println("**** em iteration " + iteration + " ****");

            // E-step: accumulate sufficient statistics and the total likelihood bound.
            double likelihood = 0.0d;
            zeroInitialize();
            int numDocs = this.corpus.getNumDocs();
            for (int doc = 0; doc < numDocs; doc++) {
                if (doc % 1000 == 0) {
                    System.out.println("document " + doc);
                }
                likelihood += docEstep(doc);
            }

            // M-step.
            ldaMLE(this.ESTIMATE_ALPHA);

            // previousLikelihood is 0 on the first pass, so delta is then infinite/NaN and never "converged".
            delta = (previousLikelihood - likelihood) / previousLikelihood;
            // A worsening bound means inference stopped too early, so allow more variational iterations.
            // Only double a positive limit that cannot overflow: doubling the -1 "unlimited" sentinel or
            // wrapping around to a non-positive value would make lda_inference skip its loop entirely.
            if (delta < 0.0d && this.VAR_MAX_ITER > 0 && this.VAR_MAX_ITER <= Integer.MAX_VALUE / 2) {
                this.VAR_MAX_ITER *= 2;
            }
            previousLikelihood = likelihood;
            System.out.println("Likelihood = " + likelihood + ", delta = " + delta);
        }
    }

    /**
     * Runs variational inference for one document and folds the result into the sufficient statistics
     * ({@link #sAlpha}, {@link #class_word}, {@link #class_total}, {@link #sNumDocs}).
     *
     * @param i document index
     * @return the likelihood bound returned by {@link #lda_inference(int)} for this document
     */
    double docEstep(int i) {
        double likelihood = lda_inference(i);

        // Alpha statistic: sum_k digamma(gamma_k) - K * digamma(sum_k gamma_k).
        double[] gamma = this.var_gamma[i];
        double gammaSum = 0.0d;
        for (int topic = 0; topic < this.numTopics; topic++) {
            gammaSum += gamma[topic];
            this.sAlpha += digamma(gamma[topic]);
        }
        this.sAlpha -= this.numTopics * digamma(gammaSum);

        // Expected counts: word frequency weighted by the word's topic responsibilities.
        int position = 0;
        Iterator<Integer> it = this.corpus.getWordIdsInDoc(i).iterator();
        while (it.hasNext()) {
            int wordId = it.next().intValue();
            // Looked up once per word instead of twice per topic.
            double count = this.corpus.getWordFreqInDoc(i, wordId);
            double[] responsibilities = this.phi[position];
            for (int topic = 0; topic < this.numTopics; topic++) {
                double weighted = count * responsibilities[topic];
                this.class_word[topic][wordId] += weighted;
                this.class_total[topic] += weighted;
            }
            position++;
        }
        this.sNumDocs++;
        return likelihood;
    }

    /**
     * Newton iteration in log space for the symmetric Dirichlet parameter.
     *
     * @param d  alpha sufficient statistic
     * @param i  number of documents behind the statistic
     * @param i2 number of topics
     * @return the estimated alpha
     */
    private double opt_alpha(double d, int i, int i2) {
        double initAlpha = 100.0d;
        double logAlpha = Math.log(initAlpha);
        int iteration = 0;
        double gradient;
        do {
            iteration++;
            double a = Math.exp(logAlpha);
            // "a == Double.NaN" is always false in Java; Double.isNaN is the only working NaN test.
            if (Double.isNaN(a)) {
                // Restart from a larger initial value when the iteration has diverged.
                initAlpha *= 10.0d;
                System.out.println("warning : alpha is nan; new init = " + initAlpha);
                a = initAlpha;
                logAlpha = Math.log(a);
            }
            double objective = alhood(a, d, i, i2);
            gradient = d_alhood(a, d, i, i2);
            double curvature = d2_alhood(a, i, i2);
            logAlpha -= (1.0d / ((curvature * a) + gradient)) * gradient;
            System.out.println("alpha maximization : " + objective + "   " + gradient);
        } while (Math.abs(gradient) > this.NEWTON_THRESH && iteration < this.MAX_ALPHA_ITER);
        return Math.exp(logAlpha);
    }

    /**
     * M-step: recomputes {@link #log_prob_w} from the expected counts and, if requested, re-estimates
     * {@link #alpha} from the accumulated statistics.
     *
     * @param z whether to re-estimate alpha
     */
    public void ldaMLE(boolean z) {
        updateLogProbW();
        if (z) {
            this.alpha = opt_alpha(this.sAlpha, this.sNumDocs, this.numTopics);
            System.out.println("new alpha = " + this.alpha);
        }
    }

    /** Sets each log_prob_w entry to log(class_word / class_total), or -100 where the count is not positive. */
    private void updateLogProbW() {
        for (int topic = 0; topic < this.numTopics; topic++) {
            double[] counts = this.class_word[topic];
            double[] logProb = this.log_prob_w[topic];
            for (int term = 0; term < this.numTerms; term++) {
                if (counts[term] > 0.0d) {
                    logProb[term] = Math.log(counts[term]) - Math.log(this.class_total[topic]);
                } else {
                    logProb[term] = -100.0d;
                }
            }
        }
    }

    /** Returns a uniform random number in {@code [0, 1)}. */
    private double ldaRand() {
        return Math.random();
    }

    /**
     * Writes {@link #log_prob_w} to the file {@code beta} in the working directory, one topic per line with
     * every value preceded by a single space. Failures are reported on {@code System.out} and otherwise ignored.
     */
    public void storeBeta() {
        try {
            PrintWriter printWriter = new PrintWriter("beta", "UTF-8");
            try {
                for (int topic = 0; topic < this.numTopics; topic++) {
                    for (int term = 0; term < this.numTerms; term++) {
                        printWriter.print(" " + this.log_prob_w[topic][term]);
                    }
                    printWriter.println();
                }
            } finally {
                // Close even when writing fails so the file handle is not leaked.
                printWriter.close();
            }
        } catch (Exception e) {
            System.out.println(e);
        }
    }

    /**
     * Computes {@code log(exp(d) + exp(d2))} while only exponentiating the non-positive difference of the
     * two arguments.
     *
     * @param d  first log-domain value
     * @param d2 second log-domain value
     * @return the log of the sum of the two exponentials
     */
    double log_sum(double d, double d2) {
        if (d < d2) {
            return d2 + Math.log(1.0d + Math.exp(d - d2));
        }
        return d + Math.log(1.0d + Math.exp(d2 - d));
    }

    /**
     * Variational inference for one document: alternately updates {@link #phi} and the document's row of
     * {@link #var_gamma} until the relative change of the likelihood bound is no greater than
     * {@link #VAR_CONVERGED} or {@link #VAR_MAX_ITER} iterations have run.
     *
     * @param i document index
     * @return the likelihood bound after the last iteration ({@code 0} if no iteration ran)
     */
    double lda_inference(int i) {
        double[] gamma = this.var_gamma[i];
        double[] oldPhi = new double[this.numTopics];
        double[] digammaGamma = new double[this.numTopics];

        // Initialisation: gamma_k = alpha + total/K and uniform phi.
        // The division is forced to floating point so the share is not truncated should getTotal
        // return an integer type (no effect if it already returns a double).
        double totalShare = this.corpus.getTotal(i) / (double) this.numTopics;
        for (int topic = 0; topic < this.numTopics; topic++) {
            gamma[topic] = this.alpha + totalShare;
            digammaGamma[topic] = digamma(gamma[topic]);
        }
        int docLength = this.corpus.getDocumentLength(i);
        for (int position = 0; position < docLength; position++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.phi[position][topic] = 1.0d / this.numTopics;
            }
        }

        double converged = 1.0d;
        double likelihood = 0.0d;
        double previousLikelihood = 0.0d;
        int iteration = 0;
        while (converged > this.VAR_CONVERGED && (iteration < this.VAR_MAX_ITER || this.VAR_MAX_ITER == -1)) {
            iteration++;
            int position = 0;
            Iterator<Integer> it = this.corpus.getWordIdsInDoc(i).iterator();
            while (it.hasNext()) {
                int wordId = it.next().intValue();
                double[] wordPhi = this.phi[position];

                // Unnormalised log phi and its log normaliser.
                double logNormalizer = 0.0d;
                for (int topic = 0; topic < this.numTopics; topic++) {
                    oldPhi[topic] = wordPhi[topic];
                    wordPhi[topic] = digammaGamma[topic] + this.log_prob_w[topic][wordId];
                    logNormalizer = topic > 0 ? log_sum(logNormalizer, wordPhi[topic]) : wordPhi[topic];
                }

                // Normalise phi and move gamma by the change in this word's responsibilities.
                double count = this.corpus.getWordFreqInDoc(i, wordId);
                for (int topic = 0; topic < this.numTopics; topic++) {
                    wordPhi[topic] = Math.exp(wordPhi[topic] - logNormalizer);
                    gamma[topic] = gamma[topic] + (count * (wordPhi[topic] - oldPhi[topic]));
                    digammaGamma[topic] = digamma(gamma[topic]);
                }
                position++;
            }
            likelihood = compute_likelihood(i);
            // previousLikelihood is 0 on the first pass, which makes the ratio infinite and keeps the loop going.
            converged = (previousLikelihood - likelihood) / previousLikelihood;
            previousLikelihood = likelihood;
        }
        return likelihood;
    }

    /**
     * Evaluates the variational likelihood bound of one document from the current {@link #var_gamma} row
     * and {@link #phi}.
     *
     * @param i document index
     * @return the likelihood bound
     */
    double compute_likelihood(int i) {
        double[] gamma = this.var_gamma[i];
        double[] dig = new double[this.numTopics];
        double gammaSum = 0.0d;
        for (int topic = 0; topic < this.numTopics; topic++) {
            dig[topic] = digamma(gamma[topic]);
            gammaSum += gamma[topic];
        }
        double digSum = digamma(gammaSum);

        double likelihood = (lgamma(this.alpha * this.numTopics) - (this.numTopics * lgamma(this.alpha)))
                - lgamma(gammaSum);
        for (int topic = 0; topic < this.numTopics; topic++) {
            double expectedLogTheta = dig[topic] - digSum;
            likelihood += (((this.alpha - 1.0d) * expectedLogTheta) + lgamma(gamma[topic]))
                    - ((gamma[topic] - 1.0d) * expectedLogTheta);

            int position = 0;
            Iterator<Integer> it = this.corpus.getWordIdsInDoc(i).iterator();
            while (it.hasNext()) {
                int wordId = it.next().intValue();
                double p = this.phi[position][topic];
                // Zero responsibilities contribute nothing and would otherwise produce log(0).
                if (p > 0.0d) {
                    double count = this.corpus.getWordFreqInDoc(i, wordId);
                    likelihood += count * p
                            * ((expectedLogTheta - Math.log(p)) + this.log_prob_w[topic][wordId]);
                }
                position++;
            }
        }
        return likelihood;
    }

    /** Objective maximised by {@link #opt_alpha}: a = alpha, ss = statistic, docs = documents, k = topics. */
    private double alhood(double a, double ss, int docs, int k) {
        return (docs * (lgamma(k * a) - (k * lgamma(a)))) + ((a - 1.0d) * ss);
    }

    /** First derivative of {@link #alhood} with respect to alpha. */
    private double d_alhood(double a, double ss, int docs, int k) {
        return (docs * ((k * digamma(k * a)) - (k * digamma(a)))) + ss;
    }

    /** Second derivative of {@link #alhood} with respect to alpha. */
    private double d2_alhood(double a, int docs, int k) {
        return docs * (((k * k) * trigamma(k * a)) - (k * trigamma(a)));
    }

    /** Trigamma function: asymptotic series at {@code x + 6}, then six recurrence steps back down to x. */
    private double trigamma(double x) {
        double shifted = x + 6.0d;
        double inv2 = 1.0d / (shifted * shifted);
        double result = (((((((((((0.075757575757576d * inv2) - 0.033333333333333d) * inv2) + 0.0238095238095238d) * inv2) - 0.033333333333333d) * inv2) + 0.166666666666667d) * inv2) + 1.0d) / shifted) + (0.5d * inv2);
        for (int step = 0; step < 6; step++) {
            shifted -= 1.0d;
            result = (1.0d / (shifted * shifted)) + result;
        }
        return result;
    }

    /** Digamma function: asymptotic series at {@code x + 6}, corrected back to x by the recurrence. */
    private double digamma(double x) {
        double shifted = x + 6.0d;
        double inv2 = 1.0d / (shifted * shifted);
        return ((((((((((((((0.004166666666667d * inv2) - 0.003968253986254d) * inv2) + 0.008333333333333d) * inv2) - 0.083333333333333d) * inv2) + Math.log(shifted)) - (0.5d / shifted)) - (1.0d / (shifted - 1.0d))) - (1.0d / (shifted - 2.0d))) - (1.0d / (shifted - 3.0d))) - (1.0d / (shifted - 4.0d))) - (1.0d / (shifted - 5.0d))) - (1.0d / (shifted - 6.0d));
    }

    /**
     * Log-gamma function: Stirling series at {@code x + 6}, corrected back to x by subtracting
     * {@code log(x) + log(x + 1) + ... + log(x + 5)}.
     */
    private double lgamma(double x) {
        double shifted = x + 6.0d;
        // The series is in powers of 1/shifted^2. Using 1/x^2 (the unshifted argument) gives a wrong
        // correction term, e.g. lgamma(1) != 0, and blows up for small x.
        double inv2 = 1.0d / (shifted * shifted);
        return ((((((((((shifted - 0.5d) * Math.log(shifted)) - shifted) + 0.918938533204673d) + ((((((((-5.95238095238E-4d) * inv2) + 7.93650793651E-4d) * inv2) - 0.002777777777778d) * inv2) + 0.083333333333333d) / shifted)) - Math.log(shifted - 1.0d)) - Math.log(shifted - 2.0d)) - Math.log(shifted - 3.0d)) - Math.log(shifted - 4.0d)) - Math.log(shifted - 5.0d)) - Math.log(shifted - 6.0d);
    }
}
