package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import com.rapidminer.kobra.topicmodels.SamplersHLDA;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.TIntArrayList;
import gnu.trove.map.hash.TDoubleIntHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/SamplersGompertzHLDA.class */
/**
 * hLDA sampler in which every tree node additionally carries a two-parameter Gompertz-type time
 * density ({@code pGompertz}).
 *
 * <p>When a token's level is resampled, the usual hLDA weight is multiplied by that density
 * evaluated at the document's timestamp. After each sweep, the density parameters of every node
 * are re-fitted from a histogram of the timestamps of the tokens assigned to that node.
 */
public class SamplersGompertzHLDA extends SamplersHLDA {
    /** Iteration budget handed to the per-node Gompertz optimiser in each sweep. */
    private static final int GOMPERTZ_OPTIMIZER_ITERATIONS = 100;

    /** One timestamp per document, indexed by document. */
    double[] times;
    /** Fitted Gompertz parameter pairs from the last sweep, indexed like {@link #allCurrentNodes}. */
    double[][] pGompertz;
    /** NOTE(review): not read or written in this class. */
    double[][] pi;
    /** Pre-order snapshot of the tree nodes, rebuilt at the start of each topic sweep. */
    ArrayList<SamplersHLDA.NCRPNode> allCurrentNodes = null;
    /** Node count produced by {@link #countNodes}. */
    int size = 0;
    /** Lazily allocated by {@link #getPredictions()}; never filled in this class. */
    double[] predictions = null;

    /**
     * Node variant that initialises its own bookkeeping and random Gompertz parameters.
     *
     * <p>NOTE(review): this class is not instantiated anywhere in this file. The constructor
     * re-assigns fields after {@code super(...)} and increments {@code totalNodes}. If the
     * superclass constructor already does the same, nodes are counted twice — confirm. It also
     * draws from {@link Math#random()}, so it ignores the fixed seed configured in {@code init}.
     */
    class GNCRPNode extends SamplersHLDA.NCRPNode {
        public GNCRPNode(SamplersHLDA.NCRPNode parentNode, int vocabularySize, int nodeLevel) {
            super(parentNode, vocabularySize, nodeLevel);
            this.customers = 0;
            this.parent = parentNode;
            this.children = new ArrayList<>();
            this.level = nodeLevel;
            this.totalTokens = 0;
            this.typeCounts = new int[vocabularySize];
            this.nodeID = SamplersGompertzHLDA.this.totalNodes;
            SamplersGompertzHLDA.this.totalNodes++;
            this.hsValues = new TDoubleIntHashMap();
            this.pGompertz = new double[]{Math.random(), Math.random()};
        }

        public GNCRPNode(int vocabularySize) {
            super(SamplersGompertzHLDA.this, vocabularySize);
        }
    }

    /**
     * Evaluates the shifted Gompertz probability density
     * {@code f(t) = b e^{-bt} exp(-eta e^{-bt}) (1 + eta (1 - e^{-bt}))}.
     *
     * @param t   point at which the density is evaluated (a document timestamp)
     * @param eta shape parameter (first entry of a node's {@code pGompertz})
     * @param b   scale parameter (second entry of a node's {@code pGompertz})
     * @return the density value; it may underflow to 0 when {@code b * t} is large
     */
    public static double dist(double t, double eta, double b) {
        // e^{-bt} appears twice in the formula; compute it once.
        double decay = Math.exp((-b) * t);
        return b * Math.exp(-((b * t) + (eta * decay))) * (1.0d + (eta * (1.0d - decay)));
    }

    /**
     * Resamples the tree level of every token of one document along its fixed root-to-leaf path.
     *
     * <p>Each level is weighted by
     * {@code (alpha + n_level) * (eta + n_word|node) / (etaSum + n_node) * dist(time, node)}.
     * After a token has been re-assigned, the document's timestamp is counted in the chosen
     * node's {@code hsValues} histogram.
     *
     * @param i document index
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersHLDA
    public void sampleTopics(int i) {
        final int[] tokens = this.sequences[i];
        final int[] tokenLevels = this.levels[i];
        final double time = this.times[i];

        // Walk from the document's leaf up to the root to recover its path.
        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        SamplersHLDA.NCRPNode cursor = this.documentLeaves[i];
        for (int level = this.numLevels - 1; level >= 0; level--) {
            path[level] = cursor;
            cursor = cursor.parent;
        }

        // The time density depends only on the node and the document's timestamp. Neither
        // changes while this document's tokens are resampled, so evaluate it once per level.
        double[] timeDensity = new double[this.numLevels];
        for (int level = 0; level < this.numLevels; level++) {
            double[] params = path[level].pGompertz;
            timeDensity[level] = dist(time, params[0], params[1]);
        }

        int[] levelCounts = new int[this.numLevels];
        for (int pos = 0; pos < tokens.length; pos++) {
            levelCounts[tokenLevels[pos]]++;
        }

        double[] weights = new double[this.numLevels];
        for (int pos = 0; pos < tokens.length; pos++) {
            int word = tokens[pos];

            // Remove the token from its current level before computing the conditional.
            int oldLevel = tokenLevels[pos];
            levelCounts[oldLevel]--;
            SamplersHLDA.NCRPNode oldNode = path[oldLevel];
            oldNode.typeCounts[word]--;
            oldNode.totalTokens--;

            double total = 0.0d;
            for (int level = 0; level < this.numLevels; level++) {
                SamplersHLDA.NCRPNode candidate = path[level];
                weights[level] = (((this.alpha + levelCounts[level])
                        * (this.eta + candidate.typeCounts[word]))
                        / (this.etaSum + candidate.totalTokens)) * timeDensity[level];
                total += weights[level];
            }

            // randomIndex expects probabilities summing to 1. If every time density
            // under/overflowed, the sum is 0, infinite or NaN and plain division would yield NaN
            // weights, so fall back to a uniform choice over the levels instead.
            if (total > 0.0d && total < Double.POSITIVE_INFINITY) {
                for (int level = 0; level < this.numLevels; level++) {
                    weights[level] /= total;
                }
            } else {
                for (int level = 0; level < this.numLevels; level++) {
                    weights[level] = 1.0d / this.numLevels;
                }
            }

            int newLevel = RandomGenerator.getGlobalRandomGenerator().randomIndex(weights);
            tokenLevels[pos] = newLevel;
            levelCounts[newLevel]++;
            SamplersHLDA.NCRPNode newNode = path[newLevel];
            newNode.typeCounts[word]++;
            newNode.totalTokens++;

            // Record this token's timestamp in the new node's histogram.
            TDoubleIntHashMap histogram = newNode.hsValues;
            int seen = histogram.contains(time) ? histogram.get(time) : 0;
            histogram.put(time, seen + 1);
        }
    }

    /**
     * Initialises the sampler state: priors, the random source, a random flat-topic assignment
     * for every token, shuffled per-document token sequences, and an initial tree path plus
     * random level for every token.
     *
     * @param docIds         document index of every token (parallel to {@code wordIds})
     * @param wordIds        vocabulary index of every token
     * @param docTimes       one timestamp per document
     * @param topicCount     number of flat topics used for the initial random assignment
     * @param vocabularySize number of distinct words
     * @param docCount       number of documents
     * @param iterations     number of sweeps run by {@link #GibbsSampling()}
     * @param beta           word prior (stored as both {@code BETA} and {@code eta})
     * @param alphaPrior     level prior (stored as both {@code ALPHA} and {@code alpha})
     * @param useFixedSeed   whether {@code rn} is seeded with {@code seedValue}
     * @param seedValue      seed used when {@code useFixedSeed} is true
     */
    public void init(int[] docIds, int[] wordIds, double[] docTimes, int topicCount, int vocabularySize, int docCount, int iterations, double beta, double alphaPrior, boolean useFixedSeed, int seedValue) {
        this.times = docTimes;
        this.maxIter = iterations;
        this.BETA = beta;
        this.ALPHA = alphaPrior;
        this.alpha = alphaPrior;
        this.eta = beta;
        this.etaSum = vocabularySize * beta;
        this.numTokens = wordIds.length;
        this.numTopics = topicCount;
        this.numDocs = docCount;
        this.numWords = vocabularySize;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[vocabularySize * topicCount];
        this.doctopiccounts = new int[docCount * topicCount];
        this.topiccounts = new int[topicCount];
        this.words = wordIds;
        this.docs = docIds;
        if (useFixedSeed) {
            this.seed = seedValue;
            this.rn = new Random(seedValue);
        } else {
            this.rn = new Random();
        }

        // Jagged arrays: one row per document, each row sized further below.
        this.levels = new int[docCount][];
        this.documentLeaves = new SamplersHLDA.NCRPNode[docCount];
        this.sequences = new int[docCount][];
        SamplersHLDA.NCRPNode[] path = new SamplersHLDA.NCRPNode[this.numLevels];
        this.rootNode = new SamplersHLDA.NCRPNode(this, vocabularySize);

        // Random flat-topic assignment, collecting each document's words as we go.
        TIntArrayList[] docTokens = new TIntArrayList[docCount];
        for (int doc = 0; doc < docCount; doc++) {
            docTokens[doc] = new TIntArrayList();
        }
        for (int token = 0; token < wordIds.length; token++) {
            int word = this.words[token];
            int doc = this.docs[token];
            docTokens[doc].add(word);
            int topic = this.rn.nextInt(topicCount);
            this.topics[token] = topic;
            this.wordtopiccounts[(word * topicCount) + topic]++;
            this.doctopiccounts[(doc * topicCount) + topic]++;
            this.topiccounts[topic]++;
        }
        for (int doc = 0; doc < docCount; doc++) {
            docTokens[doc].shuffle(this.rn);
            this.sequences[doc] = docTokens[doc].toNativeArray();
        }

        // Draw an initial root-to-leaf path per document and scatter its tokens over the levels.
        for (int doc = 0; doc < docCount; doc++) {
            int[] tokens = this.sequences[doc];
            path[0] = this.rootNode;
            this.rootNode.customers++;
            for (int level = 1; level < this.numLevels; level++) {
                path[level] = path[level - 1].select();
                path[level].customers++;
            }
            this.node = path[this.numLevels - 1];
            this.levels[doc] = new int[tokens.length];
            this.documentLeaves[doc] = this.node;
            for (int pos = 0; pos < tokens.length; pos++) {
                int word = tokens[pos];
                int level = this.rn.nextInt(this.numLevels);
                this.levels[doc][pos] = level;
                this.node = path[level];
                this.node.totalTokens++;
                this.node.typeCounts[word]++;
            }
        }
        this.allCurrentNodes = new ArrayList<>();
        this.NodeIdToTopic = new TIntIntHashMap();
        System.out.println(this.totalNodes);
    }

    /**
     * Adds the number of nodes in the subtree rooted at {@code nCRPNode} to {@link #size}.
     *
     * @param nCRPNode subtree root
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersHLDA
    public void countNodes(SamplersHLDA.NCRPNode nCRPNode) {
        this.size++;
        for (SamplersHLDA.NCRPNode child : nCRPNode.children) {
            countNodes(child);
        }
    }

    /**
     * Appends the subtree rooted at {@code nCRPNode} to {@link #allCurrentNodes} in pre-order.
     * It also copies each child's {@code hsValues} entries into its parent.
     *
     * <p>NOTE(review): {@code putAll} overwrites the parent's count for a shared timestamp
     * instead of adding to it — confirm this is intended.
     *
     * @param nCRPNode subtree root
     */
    public void numerateNodes(SamplersHLDA.NCRPNode nCRPNode) {
        this.allCurrentNodes.add(nCRPNode);
        for (SamplersHLDA.NCRPNode child : nCRPNode.children) {
            numerateNodes(child);
            nCRPNode.hsValues.putAll(child.hsValues);
        }
    }

    /**
     * Runs {@code maxIter} sweeps. Each sweep does the following:
     * <ol>
     *   <li>resample every document's path;</li>
     *   <li>snapshot the tree and set {@code numTopics} to its node count;</li>
     *   <li>resample every token's level;</li>
     *   <li>re-fit each node's Gompertz parameters from its timestamp histogram;</li>
     *   <li>print the node count.</li>
     * </ol>
     */
    @Override // com.rapidminer.kobra.topicmodels.SamplersHLDA, com.rapidminer.kobra.topicmodels.SamplersLDA
    public void GibbsSampling() {
        for (int iteration = 0; iteration < this.maxIter; iteration++) {
            for (int doc = 0; doc < this.numDocs; doc++) {
                samplePath(doc, iteration);
            }
            this.next = 0;
            this.allCurrentNodes = new ArrayList<>();
            numerateNodes(this.rootNode);
            this.numTopics = this.allCurrentNodes.size();
            for (int doc = 0; doc < this.numDocs; doc++) {
                sampleTopics(doc);
            }
            this.pGompertz = new double[this.numTopics][];
            for (int topic = 0; topic < this.numTopics; topic++) {
                SamplersHLDA.NCRPNode current = this.allCurrentNodes.get(topic);
                fitGompertz(current);
                // The row aliases the node's own parameter array.
                this.pGompertz[topic] = current.pGompertz;
            }
            this.size = 0;
            countNodes(this.rootNode);
            System.out.println(this.size);
        }
    }

    /**
     * Fits one node's Gompertz parameters to its {@code hsValues} histogram and then clears the
     * histogram for the next sweep.
     *
     * <p>The optimiser's start values are drawn from {@code rn} in a fixed order. This keeps
     * seeded runs reproducible.
     */
    private void fitGompertz(SamplersHLDA.NCRPNode target) {
        MyHashGompertzOptimizable optimizable = new MyHashGompertzOptimizable();
        // Aliases the node's histogram; the clear() below therefore resets the node itself.
        optimizable.vals = target.hsValues;
        optimizable.alpha = this.rn.nextDouble();
        optimizable.beta = this.rn.nextDouble();
        optimizable.a = this.rn.nextDouble();
        optimizable.b = this.rn.nextDouble();
        MyOrthantWiseLimitedMemoryBFGS optimizer = new MyOrthantWiseLimitedMemoryBFGS(optimizable);
        try {
            optimizer.optimize(GOMPERTZ_OPTIMIZER_ITERATIONS);
        } catch (OptimizationException e) {
            // Deliberately non-fatal: keep whatever parameters the optimizable currently holds.
            e.printStackTrace();
        }
        double[] fitted = new double[2];
        optimizable.getParameters(fitted);
        // The parameters are exponentiated here, presumably because they are optimised in
        // log space to keep them positive — confirm against MyHashGompertzOptimizable.
        target.pGompertz[0] = Math.exp(fitted[0]);
        target.pGompertz[1] = Math.exp(fitted[1]);
        optimizable.vals.clear();
    }

    /**
     * Returns the Gompertz parameter pairs fitted in the last sweep.
     *
     * @return one {@code {eta, b}} pair per node, indexed like {@link #allCurrentNodes}; the rows
     *     alias the nodes' own arrays, and the result is {@code null} before any sweep
     */
    public double[][] getPi() {
        return this.pGompertz;
    }

    /**
     * Returns the prediction array, allocating it on first use.
     *
     * <p>NOTE(review): nothing in this class writes to it, so it holds zeros unless a caller or
     * subclass fills it.
     *
     * @return an array of length {@code numDocs}
     */
    public double[] getPredictions() {
        if (this.predictions == null) {
            this.predictions = new double[this.numDocs];
        }
        return this.predictions;
    }
}
