package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.Optimizable;
import gnu.trove.list.array.TIntArrayList;
import org.apache.commons.math3.special.Gamma;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/MyWordSimilarityOptimizableOld2.class */
/**
 * Objective function over a {@code (k + 1) x v} weight matrix plus one scalar bias, exposed to
 * MALLET optimizers through {@link Optimizable.ByGradientValue}.
 *
 * <p>Flat parameter layout (length {@code k * v + v + 1}):
 * <ul>
 *   <li>indices {@code [row * v, row * v + v)} for {@code row < k}: {@code parameters[row]};</li>
 *   <li>indices {@code [k * v, k * v + v)}: the per-column default weights {@code parameters[k]};</li>
 *   <li>index {@code k * v + v}: {@link #bias}.</li>
 * </ul>
 *
 * <p>For every row {@code t < k} and column {@code w < v} the class derives a positive weight
 * <pre>
 *   beta[t][w] = exp( sum_{f in Phi[w]} parameters[t][f] + parameters[k][w] + bias )
 * </pre>
 * and {@link #getValue()} evaluates
 * <pre>
 *   sum_t [ logGamma(S_t + n_k[t]) - logGamma(S_t) ]
 *     + sum_{t,w} [ logGamma(beta[t][w]) - logGamma(beta[t][w] + n_kv[w * k + t]) ]
 *     - ( k * v / sqrt(pi * lambda^2) - sum(theta^2) / (2 * lambda^2) )
 * </pre>
 * where {@code S_t = sum_w beta[t][w]} and {@code theta} ranges over every weight and the bias.
 * {@link #getValueGradient(double[])} is the exact gradient of that expression.
 *
 * <p>NOTE(review): given the package name, {@code k} is presumably the number of topics,
 * {@code v} the vocabulary size and {@code n_k}/{@code n_kv} topic / topic-word counts; the
 * expression has the shape of a negated Dirichlet-multinomial log-posterior. Confirm against the
 * caller, in particular the sign convention expected by the optimizer in use.
 */
public class MyWordSimilarityOptimizableOld2 implements Optimizable.ByGradientValue {
    /** Number of rows that carry per-column weights (the matrix has one extra default row). */
    int k;
    /** Number of columns of the weight matrix. */
    int v;
    /** Weight matrix of shape {@code [k + 1][v]}; row {@code k} holds the per-column defaults. */
    double[][] parameters;
    /** For column {@code w}, the column indices whose weights enter the exponent of {@code w}; may be null. */
    public TIntArrayList[] Phi;
    /** Per-row counts, indexed by row; must be assigned before evaluation. */
    public int[] n_k;
    /** Per-(column, row) counts stored flat at {@code w * k + row}; must be assigned before evaluation. */
    public int[] n_kv;
    /** Scalar offset added to every exponent; stored as the last flat parameter. */
    double bias = 0.0d;
    /** Exponentiated weights computed by the most recent {@link #getValueGradient(double[])} call. */
    double[][] b = null;
    /** Not used inside this class. */
    double[][] documentFeatures = null;
    /** Scale of the quadratic penalty; the penalty is {@code theta^2 / (2 * lambda^2)}. */
    double lambda = 1.0d;

    /**
     * Allocates the weight matrix and initialises the first {@code i} rows uniformly in
     * {@code [-1, 1)}. The default row and the bias start at zero.
     *
     * @param i  number of weight rows ({@code k})
     * @param i2 number of columns ({@code v})
     */
    public MyWordSimilarityOptimizableOld2(int i, int i2) {
        this.k = i;
        this.v = i2;
        this.parameters = new double[i + 1][i2];
        for (int row = 0; row < i; row++) {
            for (int col = 0; col < i2; col++) {
                this.parameters[row][col] = (2.0d * Math.random()) - 1.0d;
            }
        }
    }

    /** @return {@code k * v + v + 1}: all matrix entries plus the bias. */
    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return (this.k * this.v) + this.v + 1;
    }

    /**
     * Copies the current parameters into {@code dArr} using the flat layout described on the class.
     * The bias is written only when {@code dArr} has room for it, so callers that pass a buffer
     * sized for the matrix alone keep working.
     *
     * @param dArr destination buffer of length at least {@code (k + 1) * v}
     */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] dArr) {
        for (int row = 0; row <= this.k; row++) {
            System.arraycopy(this.parameters[row], 0, dArr, row * this.v, this.v);
        }
        int biasAt = biasIndex();
        if (dArr.length > biasAt) {
            dArr[biasAt] = this.bias;
        }
    }

    /**
     * @param i flat parameter index in {@code [0, getNumParameters())}
     * @return the matrix entry at that index, or the bias for the last index
     */
    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        // Test the bias slot first: it lies one past the matrix and would otherwise index row k + 1.
        if (i == biasIndex()) {
            return this.bias;
        }
        return this.parameters[i / this.v][i % this.v];
    }

    /**
     * Overwrites the parameters from {@code dArr} using the flat layout described on the class.
     * The bias is read only when {@code dArr} is long enough to contain it.
     *
     * @param dArr source buffer of length at least {@code (k + 1) * v}
     */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] dArr) {
        for (int row = 0; row <= this.k; row++) {
            System.arraycopy(dArr, row * this.v, this.parameters[row], 0, this.v);
        }
        int biasAt = biasIndex();
        if (dArr.length > biasAt) {
            this.bias = dArr[biasAt];
        }
    }

    /**
     * @param i flat parameter index in {@code [0, getNumParameters())}
     * @param d new value for the matrix entry at that index, or for the bias at the last index
     */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        if (i == biasIndex()) {
            this.bias = d;
            return;
        }
        this.parameters[i / this.v][i % this.v] = d;
    }

    /**
     * Writes the gradient of {@link #getValue()} into {@code dArr} and caches the exponentiated
     * weights in {@link #b}.
     *
     * @param dArr gradient buffer of length at least {@code getNumParameters()}; fully overwritten
     * @throws NullPointerException if {@code n_k} or {@code n_kv} has not been assigned
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] dArr) {
        requireCounts();
        final double inverseVariance = 1.0d / (this.lambda * this.lambda);
        final int defaultOffset = this.k * this.v;
        final int biasAt = biasIndex();

        java.util.Arrays.fill(dArr, 0.0d);
        this.b = computeExponentiatedWeights();

        // Quadratic-penalty part: theta / lambda^2 for every matrix entry and for the bias.
        for (int row = 0; row <= this.k; row++) {
            for (int col = 0; col < this.v; col++) {
                dArr[(row * this.v) + col] = inverseVariance * this.parameters[row][col];
            }
        }
        dArr[biasAt] = inverseVariance * this.bias;

        for (int row = 0; row < this.k; row++) {
            double rowTotal = 0.0d;
            for (int col = 0; col < this.v; col++) {
                rowTotal += this.b[row][col];
            }
            // d/dS of logGamma(S + n_k) - logGamma(S); shared by every column of this row.
            double totalTerm = Gamma.digamma(rowTotal + this.n_k[row]) - Gamma.digamma(rowTotal);

            for (int col = 0; col < this.v; col++) {
                double beta = this.b[row][col];
                int count = this.n_kv[(col * this.k) + row];
                // Chain rule: d(value)/d(exponent) = beta * d(value)/d(beta).
                double contribution =
                        beta * (totalTerm + Gamma.digamma(beta) - Gamma.digamma(beta + count));

                // Every weight that appears in this column's exponent receives the contribution.
                TIntArrayList linked = linkedColumns(col);
                if (linked != null) {
                    for (int j = 0; j < linked.size(); j++) {
                        dArr[(row * this.v) + linked.get(j)] += contribution;
                    }
                }
                dArr[defaultOffset + col] += contribution;
                dArr[biasAt] += contribution;
            }
        }
    }

    /**
     * Evaluates the objective described on the class for the current parameters.
     *
     * @return the objective value
     * @throws NullPointerException if {@code n_k} or {@code n_kv} has not been assigned
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        requireCounts();

        double squaredNorm = this.bias * this.bias;
        for (int row = 0; row <= this.k; row++) {
            for (int col = 0; col < this.v; col++) {
                double weight = this.parameters[row][col];
                squaredNorm += weight * weight;
            }
        }
        // The first summand does not depend on the parameters; it is kept so values stay comparable.
        double penaltyTerm = ((this.k * this.v) / Math.sqrt((Math.PI * this.lambda) * this.lambda))
                - (squaredNorm / ((2.0d * this.lambda) * this.lambda));

        double[][] beta = computeExponentiatedWeights();
        double dataTerm = 0.0d;
        for (int row = 0; row < this.k; row++) {
            double rowTotal = 0.0d;
            for (int col = 0; col < this.v; col++) {
                rowTotal += beta[row][col];
            }
            dataTerm += Gamma.logGamma(rowTotal + this.n_k[row]) - Gamma.logGamma(rowTotal);

            for (int col = 0; col < this.v; col++) {
                int count = this.n_kv[(col * this.k) + row];
                // A zero count contributes logGamma(x) - logGamma(x) = 0, so skip the evaluation.
                if (count > 0) {
                    dataTerm += Gamma.logGamma(beta[row][col]) - Gamma.logGamma(beta[row][col] + count);
                }
            }
        }
        return dataTerm - penaltyTerm;
    }

    /** Flat index of the bias: one past the last matrix entry. */
    private int biasIndex() {
        return (this.k + 1) * this.v;
    }

    /** Returns the column list attached to {@code col}, or null when none is configured. */
    private TIntArrayList linkedColumns(int col) {
        return this.Phi == null ? null : this.Phi[col];
    }

    /** Computes {@code exp(sum of linked weights + default weight + bias)} for every (row, column). */
    private double[][] computeExponentiatedWeights() {
        double[][] weights = new double[this.k][this.v];
        for (int row = 0; row < this.k; row++) {
            for (int col = 0; col < this.v; col++) {
                double linkedSum = 0.0d;
                TIntArrayList linked = linkedColumns(col);
                if (linked != null) {
                    for (int j = 0; j < linked.size(); j++) {
                        linkedSum += this.parameters[row][linked.get(j)];
                    }
                }
                weights[row][col] = Math.exp(linkedSum + this.parameters[this.k][col] + this.bias);
            }
        }
        return weights;
    }

    /** Fails fast with a descriptive message when the count arrays have not been assigned. */
    private void requireCounts() {
        java.util.Objects.requireNonNull(this.n_k, "n_k must be assigned before evaluating the objective");
        java.util.Objects.requireNonNull(this.n_kv, "n_kv must be assigned before evaluating the objective");
    }
}
