package com.rapidminer.kobra.topicmodels;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SimpleExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.DataRow;
import com.rapidminer.example.table.DataRowFactory;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeString;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.pattern.PatternTokenizerFactory;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/MyLDAWordFeaturesOperator.class */
public class MyLDAWordFeaturesOperator extends Operator {
    static String PARAMETER_NUMITERATIONS = "iterations";
    static String PARAMETER_NUMTOPICS = "number_of_topics";
    static String PARAMETER_ALPHA = "alpha";
    static String PARAMETER_BETA = "beta";
    static String PARAMETER_GROUP = PatternTokenizerFactory.GROUP;
    int iters;
    int numTopics;
    double alpha;
    double beta;
    private final InputPort input;
    private final InputPort inputGroup;
    private final InputPort inputWords;
    private final OutputPort outputWords;
    private final OutputPort outputDocs;
    private final OutputPort outputGroups;

    public MyLDAWordFeaturesOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.iters = 2000;
        this.numTopics = 4;
        this.alpha = 0.25d;
        this.beta = 0.1d;
        this.input = getInputPorts().createPort("example set input");
        this.inputGroup = getInputPorts().createPort("example set input group");
        this.inputWords = getInputPorts().createPort("example set input word adjacency matrix");
        this.outputWords = getOutputPorts().createPort("example set words");
        this.outputDocs = getOutputPorts().createPort("example set docs");
        this.outputGroups = getOutputPorts().createPort("example set groups");
    }

    public void doWork() throws OperatorException {
        this.iters = getParameterAsInt(PARAMETER_NUMITERATIONS);
        this.numTopics = getParameterAsInt(PARAMETER_NUMTOPICS);
        this.alpha = getParameterAsDouble(PARAMETER_ALPHA);
        this.beta = getParameterAsDouble(PARAMETER_BETA);
        TIntArrayList tIntArrayList = new TIntArrayList();
        TIntArrayList tIntArrayList2 = new TIntArrayList();
        ExampleSet data = this.input.getData(ExampleSet.class);
        int size = data.getExample(0).getAttributes().size();
        String[] strArr = new String[size];
        for (int i = 0; i < data.size(); i++) {
            int i2 = i;
            Example example = data.getExample(i);
            int i3 = 0;
            for (Attribute attribute : example.getAttributes()) {
                strArr[i3] = attribute.getName();
                int i4 = i3;
                i3++;
                double value = example.getValue(attribute);
                if (value != 0.0d) {
                    for (int i5 = 0; i5 < ((int) value); i5++) {
                        tIntArrayList2.add(i2);
                        tIntArrayList.add(i4);
                    }
                }
            }
        }
        TIntArrayList[] tIntArrayListArr = new TIntArrayList[size];
        TDoubleArrayList[] tDoubleArrayListArr = new TDoubleArrayList[size];
        double[] dArr = new double[size];
        ExampleSet dataOrNull = this.inputWords.getDataOrNull(ExampleSet.class);
        double d = 0.0d;
        if (dataOrNull != null) {
            for (int i6 = 0; i6 < dataOrNull.size(); i6++) {
                Example example2 = dataOrNull.getExample(i6);
                int i7 = 0;
                Iterator it = example2.getAttributes().iterator();
                while (it.hasNext()) {
                    double value2 = example2.getValue((Attribute) it.next());
                    if (value2 != 0.0d) {
                        int i8 = i6;
                        dArr[i8] = dArr[i8] + value2;
                        d += value2;
                    }
                    i7++;
                }
            }
        }
        for (int i9 = 0; i9 < dArr.length; i9++) {
            int i10 = i9;
            dArr[i10] = dArr[i10] / d;
        }
        SamplersLDAMyWordFeatures samplersLDAMyWordFeatures = new SamplersLDAMyWordFeatures();
        samplersLDAMyWordFeatures.Phi = tIntArrayListArr;
        samplersLDAMyWordFeatures.p_v = dArr;
        samplersLDAMyWordFeatures.init(tIntArrayList2.toArray(), tIntArrayList.toArray(), this.numTopics, size, data.size(), this.iters, this.beta, this.alpha);
        samplersLDAMyWordFeatures.GibbsSampling();
        double[][] documentDistribution = samplersLDAMyWordFeatures.documentDistribution();
        ArrayList arrayList = new ArrayList();
        arrayList.add(AttributeFactory.createAttribute("Doc", 2));
        arrayList.add(AttributeFactory.createAttribute("Topic", 2));
        for (int i11 = 0; i11 < this.numTopics; i11++) {
            arrayList.add(AttributeFactory.createAttribute("Topic_" + i11, 2));
        }
        MemoryExampleTable memoryExampleTable = new MemoryExampleTable(arrayList);
        DataRowFactory dataRowFactory = new DataRowFactory(0, '.');
        for (int i12 = 0; i12 < data.size(); i12++) {
            DataRow create = dataRowFactory.create(memoryExampleTable.getNumberOfAttributes());
            memoryExampleTable.addDataRow(create);
            create.set((Attribute) arrayList.get(0), i12 + 1);
            int i13 = -1;
            double d2 = 0.0d;
            for (int i14 = 0; i14 < this.numTopics; i14++) {
                if (documentDistribution[i14][i12] > d2) {
                    d2 = documentDistribution[i14][i12];
                    i13 = i14;
                }
                create.set((Attribute) arrayList.get(2 + i14), documentDistribution[i14][i12]);
            }
            create.set((Attribute) arrayList.get(1), i13);
        }
        this.outputDocs.deliver(new SimpleExampleSet(memoryExampleTable));
        double[][] wordDistribution = samplersLDAMyWordFeatures.wordDistribution();
        double[][] betas = samplersLDAMyWordFeatures.getBetas();
        ArrayList arrayList2 = new ArrayList();
        arrayList2.add(AttributeFactory.createAttribute("Word", 5));
        arrayList2.add(AttributeFactory.createAttribute("Word_id", 2));
        arrayList2.add(AttributeFactory.createAttribute("Topic", 2));
        for (int i15 = 0; i15 < this.numTopics; i15++) {
            arrayList2.add(AttributeFactory.createAttribute("Topic_" + i15, 2));
        }
        for (int i16 = 0; i16 < this.numTopics; i16++) {
            arrayList2.add(AttributeFactory.createAttribute("Betas_" + i16, 2));
        }
        MemoryExampleTable memoryExampleTable2 = new MemoryExampleTable(arrayList2);
        DataRowFactory dataRowFactory2 = new DataRowFactory(0, '.');
        for (int i17 = 0; i17 < size; i17++) {
            DataRow create2 = dataRowFactory2.create(memoryExampleTable2.getNumberOfAttributes());
            memoryExampleTable2.addDataRow(create2);
            create2.set((Attribute) arrayList2.get(0), ((Attribute) arrayList2.get(0)).getMapping().mapString(strArr[i17]));
            create2.set((Attribute) arrayList2.get(1), i17 + 1);
            int i18 = -1;
            double d3 = 0.0d;
            for (int i19 = 0; i19 < this.numTopics; i19++) {
                if (wordDistribution[i19][i17] > d3) {
                    d3 = wordDistribution[i19][i17];
                    i18 = i19;
                }
                create2.set((Attribute) arrayList2.get(3 + i19), wordDistribution[i19][i17]);
                create2.set((Attribute) arrayList2.get(3 + this.numTopics + i19), betas[i19][i17]);
            }
            create2.set((Attribute) arrayList2.get(2), i18);
        }
        this.outputWords.deliver(new SimpleExampleSet(memoryExampleTable2));
        String parameter = getParameter(PARAMETER_GROUP);
        ExampleSet dataOrNull2 = this.inputGroup.getDataOrNull(ExampleSet.class);
        if (parameter == "" || dataOrNull2 == null || dataOrNull2.getAttributes().get(parameter) == null) {
            return;
        }
        Attribute attribute2 = dataOrNull2.getAttributes().get(parameter);
        String[] strArr2 = new String[data.size()];
        int[] tokenToTopic = samplersLDAMyWordFeatures.getTokenToTopic();
        TObjectIntHashMap tObjectIntHashMap = new TObjectIntHashMap();
        TIntObjectHashMap tIntObjectHashMap = new TIntObjectHashMap();
        int i20 = 1;
        for (int i21 = 0; i21 < data.size(); i21++) {
            String valueAsString = dataOrNull2.getExample(i21).getValueAsString(attribute2);
            if (tObjectIntHashMap.contains(valueAsString)) {
                int[] iArr = (int[]) tIntObjectHashMap.get(tObjectIntHashMap.get(valueAsString));
                iArr[tokenToTopic[i21]] = iArr[tokenToTopic[i21]] + 1;
            } else {
                int[] iArr2 = new int[this.numTopics];
                iArr2[tokenToTopic[i21]] = 1;
                tIntObjectHashMap.put(i20, iArr2);
                tObjectIntHashMap.put(valueAsString, i20);
                i20++;
            }
        }
        ArrayList arrayList3 = new ArrayList();
        arrayList3.add(AttributeFactory.createAttribute("Group", 5));
        for (int i22 = 0; i22 < this.numTopics; i22++) {
            arrayList3.add(AttributeFactory.createAttribute("Topic_" + (i22 + 1), 3));
        }
        MemoryExampleTable memoryExampleTable3 = new MemoryExampleTable(arrayList3);
        DataRowFactory dataRowFactory3 = new DataRowFactory(0, '.');
        for (int i23 = 0; i23 < tObjectIntHashMap.keys().length; i23++) {
            int[] iArr3 = (int[]) tIntObjectHashMap.get(tObjectIntHashMap.get((String) tObjectIntHashMap.keys()[i23]));
            DataRow create3 = dataRowFactory3.create(memoryExampleTable3.getNumberOfAttributes());
            memoryExampleTable3.addDataRow(create3);
            create3.set((Attribute) arrayList3.get(0), ((Attribute) arrayList3.get(0)).getMapping().mapString(r0));
            double d4 = 0.0d;
            for (int i24 : iArr3) {
                d4 += i24;
            }
            for (int i25 = 0; i25 < iArr3.length; i25++) {
                create3.set((Attribute) arrayList3.get(i25 + 1), iArr3[i25] / d4);
            }
        }
        this.outputGroups.deliver(new SimpleExampleSet(memoryExampleTable3));
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMITERATIONS, "Number of Iterations for Gibbs Sampling.", 1, Integer.MAX_VALUE, 2000));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMTOPICS, "Number of Topics.", 1, Integer.MAX_VALUE, 5));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_ALPHA, "Alpha", 0.0d, Double.MAX_VALUE, 0.25d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_BETA, "Beta", 0.0d, Double.MAX_VALUE, 0.1d));
        parameterTypes.add(new ParameterTypeString(PARAMETER_GROUP, "Attribute name for grouping the word counts."));
        return parameterTypes;
    }

    public static void main(String[] strArr) {
    }
}
