package com.rapidminer.kobra.topicmodels;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SimpleExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.DataRow;
import com.rapidminer.example.table.DataRowFactory;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

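/**
 * Operator that reads documents plus per-word and per-topic count tables from its input ports and
 * delivers a left-to-right estimate of the documents' log-likelihood, one row per repetition.
 *
 * <p>Loaded from: input_file:com/rapidminer/kobra/topicmodels/LDAEvaluationOperator.class
 */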
public class LDAEvaluationOperator extends Operator {
    /** Parameter key: number of left-to-right passes averaged per document (read into {@link #iters}). */
    static final String PARAMETER_NUMITERATIONS = "iterations";
    /** Parameter key: number of topics (read in doWork, then replaced by the width of the word table). */
    static final String PARAMETER_NUMTOPICS = "number_of_topics";
    /** Parameter key: how many times the complete evaluation is repeated. */
    static final String PARAMETER_NUMTESTS = "tests";
    /** Parameter key: per-topic document smoothing value. */
    static final String PARAMETER_ALPHA = "alpha";
    /** Parameter key: per-word topic smoothing value. */
    static final String PARAMETER_BETA = "beta";
    /** Parameter key: name of the string attribute holding raw text; empty selects any string attribute. */
    static final String PARAMETER_TEXT_ATTRIBUTE = "text_attribute";

    /** Number of left-to-right passes per document; initialised to 2000 by the constructor. */
    int iters;
    /** Number of topics; sizes all per-topic arrays. */
    protected int numTopics;
    /** Smoothing value added to every document-topic count. */
    protected double alpha;
    /** Normaliser used together with the number of tokens seen so far when recording a word probability. */
    protected double alphaSum;
    /** Smoothing value for the topic-word distribution. */
    protected double beta;
    /** {@link #beta} multiplied by the number of rows of {@link #typeTopicCounts}. */
    protected double betaSum;
    /** Sum over topics of {@code alpha * beta / (tokensPerTopic[t] + betaSum)}. */
    protected double smoothingOnlyMass;
    /** Per-topic value {@code (alpha + documentTopicCount) / (tokensPerTopic[t] + betaSum)}. */
    protected double[] cachedCoefficients;
    /** Counts indexed as {@code [type][topic]}. */
    protected int[][] typeTopicCounts;
    /** Total token count per topic. */
    protected int[] tokensPerTopic;
    /** Generator used when sampling topics. */
    protected Random random;
    /** Input port delivering the documents to evaluate. */
    private final InputPort input;
    /** Input port delivering one row of topic counts per word. */
    private final InputPort inputWords;
    /** Input port delivering the per-topic token totals. */
    private final InputPort inputTopics;
    /** Output port receiving one result row per repetition. */
    private final OutputPort output;
    /** Generator created in doWork (optionally seeded); {@code null} until then. */
    Random rn;

    /**
     * Creates the operator, registers its three input ports and one output port, and sets the
     * default state (2000 passes, no smoothing mass, no random generator yet).
     *
     * @param operatorDescription description handed through to the {@link Operator} constructor
     */
    public LDAEvaluationOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);

        // Ports first, in their fixed order: documents, word table, topic table, result.
        input = getInputPorts().createPort("example set input");
        inputWords = getInputPorts().createPort("example set words assignments");
        inputTopics = getInputPorts().createPort("example set topic assignments");
        output = getOutputPorts().createPort("output neg log likelihoods");

        // Plain defaults; the real values are read from the parameters when the operator runs.
        iters = 2000;
        smoothingOnlyMass = 0.0d;
        rn = null;
    }

    /**
     * Installs the model statistics used by the estimator and precomputes the smoothing-only
     * quantities.
     *
     * @param i      number of topics
     * @param d      alpha smoothing value
     * @param d2     alpha normaliser stored as {@code alphaSum}
     * @param d3     beta smoothing value
     * @param iArr   counts indexed as {@code [type][topic]}; its length scales {@code betaSum}
     * @param iArr2  total token count per topic
     */
    public void MarginalProbEstimator(int i, double d, double d2, double d3, int[][] iArr, int[] iArr2) {
        numTopics = i;
        typeTopicCounts = iArr;
        tokensPerTopic = iArr2;
        alpha = d;
        alphaSum = d2;
        beta = d3;
        betaSum = d3 * iArr.length;
        random = new Random();
        cachedCoefficients = new double[i];

        // With an empty document every topic only contributes its smoothing term.
        smoothingOnlyMass = 0.0d;
        int topic = 0;
        while (topic < i) {
            double denominator = iArr2[topic] + betaSum;
            smoothingOnlyMass += (d * d3) / denominator;
            cachedCoefficients[topic] = d / denominator;
            topic++;
        }
    }

    /**
     * Returns the per-topic token totals currently installed (the internal array, not a copy).
     *
     * @return the per-topic totals, or {@code null} if none have been set
     */
    public int[] getTokensPerTopic() {
        final int[] totals = tokensPerTopic;
        return totals;
    }

    /**
     * Returns the {@code [type][topic]} count table currently installed (the internal array, not a copy).
     *
     * @return the count table, or {@code null} if none has been set
     */
    public int[][] getTypeTopicCounts() {
        final int[][] counts = typeTopicCounts;
        return counts;
    }

    /**
     * Estimates the total log-likelihood of a set of documents by averaging {@code i} independent
     * left-to-right passes per document.
     *
     * <p>For every token position the per-pass probabilities are summed, and
     * {@code log(sum) - log(i)} is added to the document's score. Positions whose summed
     * probability is not positive (for example out-of-vocabulary tokens) are skipped.
     *
     * <p>Side effect: every document list is shuffled in place before it is evaluated.
     *
     * @param tIntArrayListArr one list of type indices per document
     * @param i                number of left-to-right passes per document
     * @param z                whether earlier tokens are resampled before each new token
     * @param printStream      if not {@code null}, receives one line per document with its score
     * @return the sum of the per-document log-likelihood estimates
     */
    public double evaluateLeftToRight(TIntArrayList[] tIntArrayListArr, int i, boolean z, PrintStream printStream) {
        // Prefer the operator's (optionally seeded) generator; if it has not been created yet,
        // keep whatever generator is already installed instead of shuffling with null.
        if (this.rn != null) {
            this.random = this.rn;
        }
        final double logPasses = Math.log(i);
        double totalLogLikelihood = 0.0d;
        for (TIntArrayList document : tIntArrayListArr) {
            document.shuffle(this.random);
            final int[] tokens = document.toArray();

            // One row of per-position probabilities for each pass.
            final double[][] passProbabilities = new double[i][];
            for (int pass = 0; pass < i; pass++) {
                passProbabilities[pass] = leftToRight(tokens, z);
            }

            double documentLogLikelihood = 0.0d;
            for (int position = 0; position < tokens.length; position++) {
                double summed = 0.0d;
                for (int pass = 0; pass < i; pass++) {
                    summed += passProbabilities[pass][position];
                }
                if (summed > 0.0d) {
                    // log of the mean over passes
                    documentLogLikelihood += Math.log(summed) - logPasses;
                }
            }
            if (printStream != null) {
                printStream.println(documentLogLikelihood);
            }
            totalLogLikelihood += documentLogLikelihood;
        }
        return totalLogLikelihood;
    }

    /**
     * Runs one left-to-right pass over a document and returns, for every token position, the
     * probability of that token given the tokens before it (summed over all topics).
     *
     * <p>For each position {@code limit}: if {@code z} is set, the topics of all earlier tokens are
     * resampled first; then the probability of the token at {@code limit} is recorded and a topic
     * is sampled for it. Tokens whose type is outside {@link #typeTopicCounts} (or whose row is
     * {@code null}) are skipped and keep probability 0.
     *
     * <p>{@link #cachedCoefficients} is modified during the pass and reset to its smoothing-only
     * values for all topics still in use before returning.
     *
     * @param iArr type index of every token of the document
     * @param z    whether earlier tokens are resampled before each new token is scored
     * @return per-position token probabilities, same length as {@code iArr}
     */
    protected double[] leftToRight(int[] iArr, boolean z) {
        final int docLength = iArr.length;
        final int[] oneDocTopics = new int[docLength];
        final double[] wordProbabilities = new double[docLength];
        final int[] localTopicCounts = new int[this.numTopics];
        // Sorted list of the topics that currently have a non-zero count in this document.
        final int[] localTopicIndex = new int[this.numTopics];
        final double[] topicTermScores = new double[this.numTopics];
        int nonZeroTopics = 0;
        // Tokens scored so far, not counting out-of-vocabulary ones.
        int tokensSoFar = 0;
        double topicBetaMass = 0.0d;

        for (int limit = 0; limit < docLength; limit++) {
            // With resampling, revisit every earlier position before handling the new token.
            final int firstPosition = z ? 0 : limit;
            for (int position = firstPosition; position <= limit; position++) {
                final int type = iArr[position];
                if (type >= this.typeTopicCounts.length || this.typeTopicCounts[type] == null) {
                    continue;
                }
                final boolean resampling = position < limit;

                if (resampling) {
                    // Take the token's current topic out of the document statistics.
                    final int oldTopic = oneDocTopics[position];
                    topicBetaMass -= topicBetaTerm(localTopicCounts[oldTopic], oldTopic);
                    localTopicCounts[oldTopic]--;
                    if (localTopicCounts[oldTopic] == 0) {
                        nonZeroTopics = removeFromDenseIndex(localTopicIndex, nonZeroTopics, oldTopic);
                    }
                    topicBetaMass += topicBetaTerm(localTopicCounts[oldTopic], oldTopic);
                    this.cachedCoefficients[oldTopic] = docTopicCoefficient(localTopicCounts[oldTopic], oldTopic);
                }

                final double topicTermMass = scoreTopics(this.typeTopicCounts[type], topicTermScores);

                if (!resampling) {
                    // (alphaSum + tokensSoFar) is the part of the normaliser not yet folded in.
                    wordProbabilities[limit] +=
                            ((this.smoothingOnlyMass + topicBetaMass) + topicTermMass) / (this.alphaSum + tokensSoFar);
                    tokensSoFar++;
                }

                final int newTopic = sampleTopic(topicTermMass, topicBetaMass, topicTermScores,
                        localTopicCounts, localTopicIndex, nonZeroTopics);
                oneDocTopics[position] = newTopic;

                // Put the sampled topic into the document statistics.
                topicBetaMass -= topicBetaTerm(localTopicCounts[newTopic], newTopic);
                localTopicCounts[newTopic]++;
                if (localTopicCounts[newTopic] == 1) {
                    nonZeroTopics = insertIntoDenseIndex(localTopicIndex, nonZeroTopics, newTopic);
                }
                this.cachedCoefficients[newTopic] = docTopicCoefficient(localTopicCounts[newTopic], newTopic);
                topicBetaMass += topicBetaTerm(localTopicCounts[newTopic], newTopic);
            }
        }

        // Restore the smoothing-only coefficients so the next document starts clean.
        for (int dense = 0; dense < nonZeroTopics; dense++) {
            final int topic = localTopicIndex[dense];
            this.cachedCoefficients[topic] = this.alpha / (this.tokensPerTopic[topic] + this.betaSum);
        }
        return wordProbabilities;
    }

    /** Returns {@code beta * count / (tokensPerTopic[topic] + betaSum)}. */
    private double topicBetaTerm(int count, int topic) {
        return (this.beta * count) / (this.tokensPerTopic[topic] + this.betaSum);
    }

    /** Returns {@code (alpha + count) / (tokensPerTopic[topic] + betaSum)}. */
    private double docTopicCoefficient(int count, int topic) {
        return (this.alpha + count) / (this.tokensPerTopic[topic] + this.betaSum);
    }

    /**
     * Fills {@code topicTermScores[t]} with {@code cachedCoefficients[t] * topicCounts[t]} and
     * returns the sum. Rows are indexed directly by topic, so every topic is visited; a
     * non-positive count contributes a score of 0 instead of ending the scan.
     */
    private double scoreTopics(int[] topicCounts, double[] topicTermScores) {
        final int topics = Math.min(topicCounts.length, Math.min(topicTermScores.length, this.cachedCoefficients.length));
        double mass = 0.0d;
        for (int topic = 0; topic < topics; topic++) {
            final int count = topicCounts[topic];
            final double score = count > 0 ? this.cachedCoefficients[topic] * count : 0.0d;
            topicTermScores[topic] = score;
            mass += score;
        }
        return mass;
    }

    /**
     * Draws a topic from the three-part mixture (topic-term scores, document-topic counts,
     * smoothing only). Falls back to the last topic, with a message on stderr, when no topic is hit.
     */
    private int sampleTopic(double topicTermMass, double topicBetaMass, double[] topicTermScores,
            int[] localTopicCounts, int[] localTopicIndex, int nonZeroTopics) {
        final double origSample =
                this.random.nextDouble() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
        double sample = origSample;
        int newTopic = -1;

        if (sample >= topicTermMass) {
            sample -= topicTermMass;
            if (sample >= topicBetaMass) {
                // Smoothing-only bucket: walk all topics.
                sample = (sample - topicBetaMass) / this.beta;
                newTopic = 0;
                sample -= this.alpha / (this.tokensPerTopic[newTopic] + this.betaSum);
                // The upper bound only matters when round-off leaves a tiny positive remainder.
                while (sample > 0.0d && newTopic < this.numTopics - 1) {
                    newTopic++;
                    sample -= this.alpha / (this.tokensPerTopic[newTopic] + this.betaSum);
                }
            } else {
                // Document-topic bucket: walk only the topics present in this document.
                sample /= this.beta;
                for (int dense = 0; dense < nonZeroTopics; dense++) {
                    final int topic = localTopicIndex[dense];
                    sample -= localTopicCounts[topic] / (this.tokensPerTopic[topic] + this.betaSum);
                    if (sample <= 0.0d) {
                        newTopic = topic;
                        break;
                    }
                }
            }
        } else {
            // Topic-term bucket: walk the per-topic scores of the current word.
            int index = -1;
            while (sample > 0.0d && index + 1 < topicTermScores.length) {
                index++;
                sample -= topicTermScores[index];
            }
            newTopic = index;
        }

        if (newTopic == -1) {
            System.err.println("sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
            newTopic = this.numTopics - 1;
        }
        return newTopic;
    }

    /** Removes {@code topic} from the sorted dense index and returns the new number of entries. */
    private static int removeFromDenseIndex(int[] localTopicIndex, int nonZeroTopics, int topic) {
        int dense = 0;
        // The topic is known to be present, so no bounds check is needed here.
        while (localTopicIndex[dense] != topic) {
            dense++;
        }
        // Shift the remaining entries one slot to the left.
        while (dense < nonZeroTopics) {
            if (dense < localTopicIndex.length - 1) {
                localTopicIndex[dense] = localTopicIndex[dense + 1];
            }
            dense++;
        }
        return nonZeroTopics - 1;
    }

    /** Inserts {@code topic} into the sorted dense index and returns the new number of entries. */
    private static int insertIntoDenseIndex(int[] localTopicIndex, int nonZeroTopics, int topic) {
        int dense = nonZeroTopics;
        while (dense > 0 && localTopicIndex[dense - 1] > topic) {
            localTopicIndex[dense] = localTopicIndex[dense - 1];
            dense--;
        }
        localTopicIndex[dense] = topic;
        return nonZeroTopics + 1;
    }

    /**
     * Reads the parameters and the three inputs, runs the left-to-right estimate the configured
     * number of times and delivers one row per run in a single-attribute example set.
     *
     * <p>Inputs: the documents (either raw text in a string attribute or one count attribute per
     * word), a word table with one row per word and one {@code Topic_*} attribute per topic, and
     * a topic table holding the per-topic token totals.
     *
     * <p>NOTE(review): the result attribute is called {@code negloglikelihood}, but the value stored
     * is the sum of log-probabilities returned by {@code evaluateLeftToRight}, without negation —
     * confirm which sign consumers expect.
     *
     * @throws OperatorException if a parameter or an input cannot be obtained
     */
    @Override
    public void doWork() throws OperatorException {
        this.iters = getParameterAsInt(PARAMETER_NUMITERATIONS);
        this.numTopics = getParameterAsInt(PARAMETER_NUMTOPICS);
        this.alpha = getParameterAsDouble(PARAMETER_ALPHA);
        this.beta = getParameterAsDouble(PARAMETER_BETA);
        final int numTests = getParameterAsInt(PARAMETER_NUMTESTS);
        final boolean useLocalSeed = getParameterAsBoolean("use_local_random_seed");
        final int localSeed = getParameterAsInt("local_random_seed");
        this.rn = useLocalSeed ? new Random(localSeed) : new Random();

        // Result table with a single numerical column.
        final List<Attribute> resultAttributes = new ArrayList<Attribute>();
        final Attribute resultAttribute = AttributeFactory.createAttribute("negloglikelihood", Ontology.NUMERICAL);
        resultAttributes.add(resultAttribute);
        final MemoryExampleTable resultTable = new MemoryExampleTable(resultAttributes);
        final DataRowFactory rowFactory = new DataRowFactory(DataRowFactory.TYPE_DOUBLE_ARRAY, '.');

        final ExampleSet data = this.input.getData(ExampleSet.class);
        final TIntArrayList[] documents = readDocuments(data, getParameterAsString(PARAMETER_TEXT_ATTRIBUTE));

        // The topic count is taken from the word table: all its attributes except two.
        // NOTE(review): presumably those two are non-topic columns — confirm against the producer.
        final ExampleSet wordTopicData = this.inputWords.getData(ExampleSet.class);
        this.numTopics = wordTopicData.getExample(0).getAttributes().size() - 2;
        final int[][] wordTopicCounts = readTypeTopicCounts(wordTopicData, this.numTopics);

        final ExampleSet topicData = this.inputTopics.getData(ExampleSet.class);
        final int[] topicTotals = readTokensPerTopic(topicData, this.numTopics);

        MarginalProbEstimator(this.numTopics, this.alpha, this.numTopics * this.alpha, this.beta, wordTopicCounts, topicTotals);

        for (int test = 0; test < numTests; test++) {
            final double logLikelihood = evaluateLeftToRight(documents, this.iters, true, null);
            // Console output kept as-is from the existing implementation.
            System.out.println(0);
            System.out.println(logLikelihood);
            final DataRow row = rowFactory.create(resultTable.getNumberOfAttributes());
            resultTable.addDataRow(row);
            row.set(resultAttribute, logLikelihood);
        }
        this.output.deliver(new SimpleExampleSet(resultTable));
    }

    /**
     * Converts every example into a list of word-type indices. If a matching string attribute
     * exists, its text is split on single spaces and mapped through the other attributes' names;
     * otherwise each attribute value is treated as the count of the word at that attribute position.
     */
    private TIntArrayList[] readDocuments(ExampleSet data, String textAttributeName) {
        final TIntArrayList[] documents = new TIntArrayList[data.size()];
        if (documents.length == 0) {
            return documents;
        }
        final String wantedName = textAttributeName == null ? "" : textAttributeName;

        // Pick the text attribute: the named string attribute, or the last string attribute if no name is given.
        Attribute textAttribute = null;
        for (Attribute candidate : data.getExample(0).getAttributes()) {
            if (!Ontology.ATTRIBUTE_VALUE_TYPE.isA(candidate.getValueType(), Ontology.STRING)) {
                continue;
            }
            if (wantedName.equals("") || wantedName.equals(candidate.getName())) {
                textAttribute = candidate;
            }
        }

        if (textAttribute != null) {
            // Vocabulary: every other attribute name, numbered in iteration order.
            final TObjectIntHashMap<String> vocabulary = new TObjectIntHashMap<String>();
            int nextIndex = 0;
            for (Attribute wordAttribute : data.getExample(0).getAttributes()) {
                if (wordAttribute != textAttribute) {
                    vocabulary.put(wordAttribute.getName().trim().toLowerCase(), nextIndex);
                    nextIndex++;
                }
            }
            for (int row = 0; row < documents.length; row++) {
                // Each document comes from its own row.
                final Example example = data.getExample(row);
                final TIntArrayList document = new TIntArrayList();
                for (String word : example.getValueAsString(textAttribute).split(" ")) {
                    final String key = word.trim().toLowerCase();
                    if (vocabulary.contains(key)) {
                        document.add(vocabulary.get(key));
                    }
                }
                document.shuffle(this.rn);
                documents[row] = document;
            }
        } else {
            for (int row = 0; row < documents.length; row++) {
                final Example example = data.getExample(row);
                final TIntArrayList document = new TIntArrayList();
                int typeIndex = 0;
                for (Attribute countAttribute : example.getAttributes()) {
                    // The value is truncated to an integer number of occurrences.
                    final int occurrences = (int) example.getValue(countAttribute);
                    for (int k = 0; k < occurrences; k++) {
                        document.add(typeIndex);
                    }
                    typeIndex++;
                }
                documents[row] = document;
            }
        }
        return documents;
    }

    /** Reads one row of counts per word from the attributes whose name contains "Topic_". */
    private static int[][] readTypeTopicCounts(ExampleSet wordTopicData, int topics) {
        final int[][] counts = new int[wordTopicData.size()][topics];
        for (int row = 0; row < wordTopicData.size(); row++) {
            final Example example = wordTopicData.getExample(row);
            int topic = 0;
            for (Attribute attribute : example.getAttributes()) {
                if (attribute.getName().contains("Topic_")) {
                    counts[row][topic] = (int) example.getValue(attribute);
                    topic++;
                }
            }
        }
        return counts;
    }

    /**
     * Reads the per-topic totals from the attributes whose name contains "Topic_".
     * Every row overwrites the previous one, so the last row determines the result.
     * NOTE(review): presumably the table has exactly one row — confirm.
     */
    private static int[] readTokensPerTopic(ExampleSet topicData, int topics) {
        final int[] totals = new int[topics];
        for (int row = 0; row < topicData.size(); row++) {
            final Example example = topicData.getExample(row);
            int topic = 0;
            for (Attribute attribute : example.getAttributes()) {
                if (attribute.getName().contains("Topic_")) {
                    totals[topic] = (int) example.getValue(attribute);
                    topic++;
                }
            }
        }
        return totals;
    }

    /**
     * Draws {@code i} indices from the discrete distribution given by the weights in {@code dArr}.
     *
     * <p>Side effect: {@code dArr} is normalised in place so that its entries sum to 1.
     *
     * @param i    number of samples to draw
     * @param dArr non-negative weights; overwritten with the normalised probabilities
     * @return {@code i} sampled indices into {@code dArr}
     */
    public int[] getDiscrete(int i, double[] dArr) {
        double total = 0.0d;
        for (double weight : dArr) {
            total += weight;
        }
        for (int k = 0; k < dArr.length; k++) {
            dArr[k] = dArr[k] / total;
        }

        final int[] samples = new int[i];
        final int lastIndex = dArr.length - 1;
        for (int s = 0; s < i; s++) {
            final double draw = this.rn.nextDouble();
            int index = 0;
            double cumulative = dArr[0];
            // Walk the cumulative distribution until it reaches the draw. The index bound
            // covers the case where round-off leaves the final sum slightly below the draw.
            while (cumulative < draw && index < lastIndex) {
                index++;
                cumulative += dArr[index];
            }
            samples[s] = index;
        }
        return samples;
    }

    /**
     * Declares the operator's parameters: number of passes, number of repetitions, number of
     * topics, the two smoothing values, the text attribute name and the random generator
     * parameters supplied by {@link RandomGenerator}.
     *
     * @return the inherited parameter types followed by the ones added here
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMITERATIONS, "Number of Iterations for Samplings.", 1, Integer.MAX_VALUE, 2000));
        // This value controls how often the whole evaluation is repeated, not the sampling itself.
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMTESTS, "Number of times the evaluation is repeated.", 1, Integer.MAX_VALUE, 20));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMTOPICS, "Number of Topics.", 1, Integer.MAX_VALUE, 5));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_ALPHA, "Alpha", 0.0d, Double.MAX_VALUE, 0.25d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_BETA, "Beta", 0.0d, Double.MAX_VALUE, 0.1d));
        parameterTypes.add(new ParameterTypeString(PARAMETER_TEXT_ATTRIBUTE, "Attribute name of text columns of interest.", ""));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    /**
     * Empty entry point; it performs no work.
     *
     * @param strArr command-line arguments (ignored)
     */
    public static void main(String[] strArr) {
    }
}
