package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.Optimizable;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import org.apache.commons.math3.special.Gamma;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/MyWordFeatOptimizable.class */
/**
 * Objective over a {@code k x v} matrix of per-topic, per-word weights, exposed to MALLET
 * optimizers through {@link Optimizable.ByGradientValue}.
 *
 * <p>For topic {@code t} and word {@code w} let
 * {@code alpha(t, w) = exp(parameters[t][w]) * p_v[w]} and
 * {@code S(t) = sum over w of alpha(t, w)}. {@link #getValue()} returns
 *
 * <pre>
 *   sum_t [ logGamma(S(t)) - logGamma(S(t) + n_k[t]) ]
 * + sum_{t, w : n_kv &gt; 0} [ logGamma(alpha(t, w) + n_kv) - logGamma(alpha(t, w)) ]
 * - sum_{t, w} parameters[t][w]^2 / (2 * sigma^2)
 * </pre>
 *
 * which has the form of a Dirichlet-multinomial log-likelihood with a quadratic penalty on the
 * weights. {@link #getValueGradient(double[])} returns the partial derivatives of exactly that
 * quantity.
 *
 * <p>The flat parameter vector seen by the optimizer is topic-major
 * ({@code index = topic * v + word}); {@link #n_kv} is word-major
 * ({@code index = word * k + topic}).
 *
 * <p>{@link #n_k}, {@link #n_kv} must be assigned by the caller before the value or gradient is
 * requested.
 */
public class MyWordFeatOptimizable implements Optimizable.ByGradientValue {
    /** Number of topics (rows of {@link #parameters}). */
    int k;
    /** Vocabulary size (columns of {@link #parameters}). */
    int v;
    /** Weights being optimised, indexed {@code [topic][word]}. */
    double[][] parameters;
    /** Not read or written by this class. */
    public TIntArrayList[] Phi;
    /** Per-word multiplier applied to {@code exp(weight)}; initialised to {@code 1 / v} for every word. */
    public double[] p_v;
    /** Per-topic counts, indexed by topic; must be set by the caller. */
    public int[] n_k;
    /** Per-(word, topic) counts stored flat as {@code word * k + topic}; must be set by the caller. */
    public int[] n_kv;
    /** Not read or written by this class. */
    double[] m = null;
    /** Not read or written by this class. */
    double[] b = null;
    /** Not read or written by this class. */
    double[][] documentFeatures = null;
    /**
     * Not read by this class: the quadratic penalty and its derivative both use {@link #sigma}.
     * The field is retained so code that assigns it keeps compiling.
     */
    public double lambda = 1.0d;
    /**
     * Width of the quadratic penalty {@code weight^2 / (2 * sigma^2)}; also the half-width of the
     * uniform interval the initial weights are drawn from.
     */
    public double sigma = 2.0d;

    /**
     * Creates the objective with weights drawn uniformly from {@code [-sigma, sigma)} using
     * {@link Math#random()}.
     *
     * @param numTopics number of topics ({@code k})
     * @param vocabSize vocabulary size ({@code v})
     */
    public MyWordFeatOptimizable(int numTopics, int vocabSize) {
        allocate(numTopics, vocabSize);
        for (int topic = 0; topic < numTopics; topic++) {
            for (int word = 0; word < vocabSize; word++) {
                this.parameters[topic][word] = scaleToSigmaRange(Math.random());
            }
        }
    }

    /**
     * Creates the objective with weights drawn uniformly from {@code [-sigma, sigma)} using the
     * supplied generator, so that initialisation can be made reproducible.
     *
     * @param numTopics number of topics ({@code k})
     * @param vocabSize vocabulary size ({@code v})
     * @param random source of the uniform draws
     */
    public MyWordFeatOptimizable(int numTopics, int vocabSize, Random random) {
        allocate(numTopics, vocabSize);
        for (int topic = 0; topic < numTopics; topic++) {
            for (int word = 0; word < vocabSize; word++) {
                this.parameters[topic][word] = scaleToSigmaRange(random.nextDouble());
            }
        }
    }

    /** Records the dimensions, allocates the weight matrix and sets {@link #p_v} to {@code 1 / v}. */
    private void allocate(int numTopics, int vocabSize) {
        this.k = numTopics;
        this.v = vocabSize;
        this.parameters = new double[numTopics][vocabSize];
        this.p_v = new double[vocabSize];
        for (int word = 0; word < vocabSize; word++) {
            this.p_v[word] = 1.0d / vocabSize;
        }
    }

    /** Maps a draw {@code u} from {@code [0, 1)} onto {@code [-sigma, sigma)}. */
    private double scaleToSigmaRange(double u) {
        return ((2.0d * u) * this.sigma) - this.sigma;
    }

    /**
     * Fails with a descriptive message when the externally supplied arrays have not been assigned.
     *
     * @throws IllegalStateException if {@link #n_k}, {@link #n_kv} or {@link #p_v} is {@code null}
     */
    private void requireCounts() {
        if (this.n_k == null || this.n_kv == null || this.p_v == null) {
            throw new IllegalStateException(
                    "n_k, n_kv and p_v must be assigned before the value or gradient is requested");
        }
    }

    /** @return {@code k * v}, the length of the flat parameter vector */
    @Override
    public int getNumParameters() {
        return this.k * this.v;
    }

    /**
     * Copies the weights into {@code buffer} in topic-major order.
     *
     * @param buffer destination with at least {@code k * v} entries
     */
    @Override
    public void getParameters(double[] buffer) {
        for (int topic = 0; topic < this.k; topic++) {
            System.arraycopy(this.parameters[topic], 0, buffer, topic * this.v, this.v);
        }
    }

    /**
     * @param index flat, topic-major index
     * @return the weight stored at {@code [index / v][index % v]}
     */
    @Override
    public double getParameter(int index) {
        return this.parameters[index / this.v][index % this.v];
    }

    /**
     * Overwrites every weight from a flat, topic-major vector.
     *
     * @param values source with at least {@code k * v} entries
     */
    @Override
    public void setParameters(double[] values) {
        for (int topic = 0; topic < this.k; topic++) {
            System.arraycopy(values, topic * this.v, this.parameters[topic], 0, this.v);
        }
    }

    /**
     * @param index flat, topic-major index
     * @param value new weight for {@code [index / v][index % v]}
     */
    @Override
    public void setParameter(int index, double value) {
        this.parameters[index / this.v][index % this.v] = value;
    }

    /**
     * Writes the gradient of {@link #getValue()} with respect to every weight into {@code buffer}
     * (topic-major). Entries of {@code buffer} beyond {@code k * v} are set to zero.
     *
     * <p>Because {@code d alpha / d weight = alpha}, each entry is
     * {@code alpha * (digamma(S) - digamma(S + n_k) [+ digamma(alpha + n_kv) - digamma(alpha)])
     * - weight / sigma^2}, where the bracketed part is included only when {@code n_kv > 0}.
     *
     * @param buffer destination with at least {@code k * v} entries
     * @throws IllegalStateException if the count arrays have not been assigned
     */
    @Override
    public void getValueGradient(double[] buffer) {
        requireCounts();
        final double variance = this.sigma * this.sigma;
        final double[] alpha = new double[this.v];

        for (int topic = 0; topic < this.k; topic++) {
            final double[] weights = this.parameters[topic];

            // One pass to get alpha(topic, .) and its sum; both are reused below.
            double alphaSum = 0.0d;
            for (int word = 0; word < this.v; word++) {
                alpha[word] = Math.exp(weights[word]) * this.p_v[word];
                alphaSum += alpha[word];
            }

            // Derivative of logGamma(S) - logGamma(S + n_k) with respect to S; shared by all words.
            final double topicSlope =
                    Gamma.digamma(alphaSum) - Gamma.digamma(alphaSum + this.n_k[topic]);

            final int rowStart = topic * this.v;
            for (int word = 0; word < this.v; word++) {
                final double a = alpha[word];
                final int count = this.n_kv[(word * this.k) + topic];
                double slope = topicSlope;
                if (count > 0) {
                    // Derivative of logGamma(alpha + n_kv) - logGamma(alpha) with respect to alpha.
                    slope += Gamma.digamma(a + count) - Gamma.digamma(a);
                }
                buffer[rowStart + word] = (slope * a) - (weights[word] / variance);
            }
        }

        for (int i = this.k * this.v; i < buffer.length; i++) {
            buffer[i] = 0.0d;
        }
    }

    /**
     * Evaluates the objective described in the class documentation at the current weights.
     *
     * @return the objective value (larger is better for a maximising optimizer)
     * @throws IllegalStateException if the count arrays have not been assigned
     */
    @Override
    public double getValue() {
        requireCounts();
        final double twiceVariance = (2.0d * this.sigma) * this.sigma;

        // The three parts are accumulated with the sign of their negation and flipped once at the end.
        double penalty = 0.0d;
        double topicPart = 0.0d;
        double wordPart = 0.0d;

        for (int topic = 0; topic < this.k; topic++) {
            final double[] weights = this.parameters[topic];
            double alphaSum = 0.0d;
            for (int word = 0; word < this.v; word++) {
                final double weight = weights[word];
                final double a = Math.exp(weight) * this.p_v[word];
                alphaSum += a;
                penalty += (weight * weight) / twiceVariance;

                final int count = this.n_kv[(word * this.k) + topic];
                if (count > 0) {
                    wordPart += Gamma.logGamma(a) - Gamma.logGamma(a + count);
                }
            }
            topicPart += Gamma.logGamma(alphaSum + this.n_k[topic]) - Gamma.logGamma(alphaSum);
        }
        return -(topicPart + wordPart + penalty);
    }
}
