package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Dirichlet;
import java.util.Random;
import org.apache.commons.math3.special.Gamma;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/MyDMROptimizable.class */
/**
 * Objective function over the feature-to-topic weights of a Dirichlet-multinomial regression
 * (DMR) topic model, exposed through MALLET's {@link Optimizable.ByGradientValue} interface.
 *
 * <p>For document {@code d} with feature vector {@code x_d}, the Dirichlet parameter of topic
 * {@code t} is {@code alpha_dt = exp(x_d . lambda_t)}, where {@code lambda_t} is row {@code t}
 * of {@link #parameters}. The objective is the sum of:
 * <ul>
 *   <li>a zero-mean Gaussian log-prior {@code -sum(lambda^2) / (2 sigma^2)} on every weight, and
 *   <li>the Dirichlet-multinomial log-likelihood of the per-document topic counts
 *       {@code logG(A_d) - logG(A_d + n_d) + sum_t [logG(alpha_dt + n_dt) - logG(alpha_dt)]},
 *       with {@code A_d = sum_t alpha_dt}.
 * </ul>
 * {@link #getValue()} and {@link #getValueGradient(double[])} both report this objective
 * <em>negated</em> and divided by the number of documents {@code n}.
 *
 * <p>NOTE(review): MALLET's optimizers maximize {@code getValue()}, so returning the negated
 * objective makes them minimize the posterior. Confirm that the calling optimizer expects a
 * quantity to minimize.
 *
 * <p>Callers must assign {@link #documentFeatures}, {@link #n_d} and {@link #n_td}, and call one
 * of the {@code init} methods, before evaluating the objective.
 */
public class MyDMROptimizable implements Optimizable.ByGradientValue {
    /** Number of document features. */
    int f;
    /** Number of topics. */
    int k;
    /** Number of documents. */
    int n;
    /**
     * Standard deviation of the Gaussian prior on each weight. Also the half-width of the
     * uniform range used by the {@code init} methods.
     */
    double sigma = 1.0d;
    /** Regression weights, indexed {@code [topic][feature]}. */
    double[][] parameters = null;
    /** NOTE(review): not read by this class; kept for callers elsewhere in the package. */
    double[] b = null;
    /** Document feature vectors, indexed {@code [document][feature]}. */
    double[][] documentFeatures = null;
    /**
     * Per-document count, indexed {@code [document]}. Presumably the sum over topics of
     * {@link #n_td} for that document — verify against the caller.
     */
    int[] n_d = null;
    /** Per-document topic counts, flattened row-major as {@code [document * k + topic]}. */
    int[] n_td = null;

    /**
     * Creates an objective for the given problem dimensions. Weights are not allocated until
     * one of the {@code init} methods is called.
     *
     * @param numFeatures  number of document features ({@code f})
     * @param numTopics    number of topics ({@code k})
     * @param numDocuments number of documents ({@code n})
     */
    public MyDMROptimizable(int numFeatures, int numTopics, int numDocuments) {
        this.f = numFeatures;
        this.k = numTopics;
        this.n = numDocuments;
    }

    /**
     * Allocates the weight matrix and fills it uniformly in {@code [-sigma, sigma)}, drawing from
     * {@link Math#random()}.
     */
    public void init() {
        this.parameters = new double[this.k][this.f];
        for (double[] row : this.parameters) {
            for (int j = 0; j < row.length; j++) {
                row[j] = uniformWeight(Math.random());
            }
        }
    }

    /**
     * Allocates the weight matrix and fills it uniformly in {@code [-sigma, sigma)}, drawing from
     * the supplied generator so that runs are reproducible.
     *
     * @param random source of uniform doubles; must not be {@code null}
     */
    public void init(Random random) {
        this.parameters = new double[this.k][this.f];
        for (double[] row : this.parameters) {
            for (int j = 0; j < row.length; j++) {
                row[j] = uniformWeight(random.nextDouble());
            }
        }
    }

    /** Maps a uniform draw {@code u} in {@code [0, 1)} to {@code [-sigma, sigma)}. */
    private double uniformWeight(double u) {
        return ((2.0d * u) * this.sigma) - this.sigma;
    }

    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.f * this.k;
    }

    /** Copies the weights into {@code buffer}, flattened row-major as {@code [topic * f + feature]}. */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] buffer) {
        for (int t = 0; t < this.k; t++) {
            System.arraycopy(this.parameters[t], 0, buffer, t * this.f, this.f);
        }
    }

    /** Returns the weight at flattened index {@code index = topic * f + feature}. */
    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int index) {
        return this.parameters[index / this.f][index % this.f];
    }

    /** Loads the weights from {@code buffer}, flattened row-major as {@code [topic * f + feature]}. */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] buffer) {
        for (int t = 0; t < this.k; t++) {
            System.arraycopy(buffer, t * this.f, this.parameters[t], 0, this.f);
        }
    }

    /** Sets the weight at flattened index {@code index = topic * f + feature}. */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int index, double value) {
        this.parameters[index / this.f][index % this.f] = value;
    }

    /**
     * Writes the gradient of {@link #getValue()} into {@code gradient}, flattened row-major as
     * {@code [topic * f + feature]}. Any entries beyond {@code k * f} are set to zero.
     *
     * @param gradient output buffer of length at least {@code k * f}
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] gradient) {
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] = 0.0d;
        }
        // Gaussian prior: d/dlambda [-lambda^2 / (2 sigma^2)] = -lambda / sigma^2.
        double variance = this.sigma * this.sigma;
        for (int t = 0; t < this.k; t++) {
            for (int j = 0; j < this.f; j++) {
                gradient[(t * this.f) + j] = (-this.parameters[t][j]) / variance;
            }
        }
        for (int doc = 0; doc < this.n; doc++) {
            double[] alphas = documentAlphas(doc);
            double alphaSum = 0.0d;
            for (double alpha : alphas) {
                alphaSum += alpha;
            }
            // Shared by every topic of this document, so compute it once.
            double docTerm = Dirichlet.digamma(alphaSum) - Dirichlet.digamma(alphaSum + this.n_d[doc]);
            double[] features = this.documentFeatures[doc];
            for (int t = 0; t < this.k; t++) {
                double alpha = alphas[t];
                // Chain rule: dalpha_dt / dlambda_tj = x_dj * alpha_dt.
                double topicFactor = alpha
                        * (docTerm
                                + Dirichlet.digamma(alpha + this.n_td[(doc * this.k) + t])
                                - Dirichlet.digamma(alpha));
                int offset = t * this.f;
                for (int j = 0; j < this.f; j++) {
                    // Features are indexed by document, not by topic.
                    gradient[offset + j] += features[j] * topicFactor;
                }
            }
        }
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] = (-gradient[i]) / this.n;
        }
    }

    /**
     * Returns the negated log-posterior (Gaussian log-prior plus Dirichlet-multinomial
     * log-likelihood) divided by the number of documents.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        double value = 0.0d;
        // Gaussian log-prior, accumulated over all weights; its derivative matches the gradient.
        double twoVariance = 2.0d * this.sigma * this.sigma;
        for (double[] row : this.parameters) {
            for (double weight : row) {
                value -= (weight * weight) / twoVariance;
            }
        }
        for (int doc = 0; doc < this.n; doc++) {
            double[] alphas = documentAlphas(doc);
            double alphaSum = 0.0d;
            for (double alpha : alphas) {
                alphaSum += alpha;
            }
            value += Gamma.logGamma(alphaSum) - Gamma.logGamma(alphaSum + this.n_d[doc]);
            for (int t = 0; t < this.k; t++) {
                value += Gamma.logGamma(alphas[t] + this.n_td[(doc * this.k) + t])
                        - Gamma.logGamma(alphas[t]);
            }
        }
        return (-value) / this.n;
    }

    /** Computes {@code alpha[t] = exp(x_doc . lambda_t)} for every topic of one document. */
    private double[] documentAlphas(int doc) {
        double[] features = this.documentFeatures[doc];
        double[] alphas = new double[this.k];
        for (int t = 0; t < this.k; t++) {
            double[] weights = this.parameters[t];
            double dot = 0.0d;
            for (int j = 0; j < this.f; j++) {
                dot += features[j] * weights[j];
            }
            alphas[t] = Math.exp(dot);
        }
        return alphas;
    }
}
