package com.rapidminer.kobra.topicmodels;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SimpleExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.DataRow;
import com.rapidminer.example.table.DataRowFactory;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.kobra.data.CCSMatrix;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDirectory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import org.apache.commons.math3.geometry.VectorFormat;
import org.apache.lucene.analysis.pattern.PatternTokenizerFactory;

/* loaded from: input_file:com/rapidminer/kobra/topicmodels/DMRLDAOperator.class */
public class DMRLDAOperator extends Operator {
    int iters;
    int numTopics;
    double alpha;
    double beta;
    String path;
    private final InputPort input;
    private final InputPort inputGroup;
    private final InputPort inputDocFeatures;
    private final OutputPort outputWords;
    private final OutputPort outputDocs;
    private final OutputPort outputGroups;
    int topK;
    static String PARAMETER_NUMITERATIONS = "iterations";
    static String PARAMETER_NUMTOPICS = "number_of_topics";
    static String PARAMETER_ALPHA = "alpha";
    static String PARAMETER_BETA = "beta";
    static String PARAMETER_GROUP = PatternTokenizerFactory.GROUP;
    static String PARAMETER_SIGMA = "sigma";
    static String PARAMETER_LAMBDA = "lambda";
    static String PARAMETER_DFR = "dfr";
    static String PARAMETER_PATH = "path";

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/rapidminer/kobra/topicmodels/DMRLDAOperator$Word.class */
    public class Word implements Comparable<Word> {
        public String word = "";
        public double weight = 0.0d;
        public int id = -1;

        Word() {
        }

        @Override // java.lang.Comparable
        public int compareTo(Word word) {
            return this.weight == word.weight ? -this.word.compareTo(word.word) : this.weight < word.weight ? 1 : -1;
        }

        public String toString() {
            return this.word + "," + this.weight;
        }
    }

    public DMRLDAOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.iters = 2000;
        this.numTopics = 4;
        this.alpha = 0.25d;
        this.beta = 0.1d;
        this.path = "/home/poelitz/work/Datasets/ResultData/acl2015/";
        this.input = getInputPorts().createPort("example set of documents as Bag-of-Words vectors with term occurrences");
        this.inputGroup = getInputPorts().createPort("example set of groups for each document (optional)");
        this.inputDocFeatures = getInputPorts().createPort("example set containing document features");
        this.outputWords = getOutputPorts().createPort("example set of word distributions for the topics");
        this.outputDocs = getOutputPorts().createPort("example set of topic distributions for the documents");
        this.outputGroups = getOutputPorts().createPort("example set of counts of topics assigned to groups");
        this.topK = 40;
    }

    public void doWork() throws OperatorException {
        boolean parameterAsBoolean = getParameterAsBoolean("use_local_random_seed");
        int parameterAsInt = getParameterAsInt("local_random_seed");
        this.iters = getParameterAsInt(PARAMETER_NUMITERATIONS);
        this.numTopics = getParameterAsInt(PARAMETER_NUMTOPICS);
        this.alpha = getParameterAsDouble(PARAMETER_ALPHA);
        this.beta = getParameterAsDouble(PARAMETER_BETA);
        this.path = getParameterAsString(PARAMETER_PATH);
        TIntArrayList tIntArrayList = new TIntArrayList();
        TIntArrayList tIntArrayList2 = new TIntArrayList();
        ExampleSet data = this.input.getData(ExampleSet.class);
        int size = data.getExample(0).getAttributes().size();
        String[] strArr = new String[size];
        for (int i = 0; i < data.size(); i++) {
            int i2 = i;
            Example example = data.getExample(i);
            int i3 = 0;
            for (Attribute attribute : example.getAttributes()) {
                strArr[i3] = attribute.getName();
                int i4 = i3;
                i3++;
                double value = example.getValue(attribute);
                if (value != 0.0d) {
                    for (int i5 = 0; i5 < ((int) value); i5++) {
                        tIntArrayList2.add(i2);
                        tIntArrayList.add(i4);
                    }
                }
            }
        }
        double[][] dArr = (double[][]) null;
        ExampleSet dataOrNull = this.inputDocFeatures.getDataOrNull(ExampleSet.class);
        if (dataOrNull != null) {
            dArr = new double[dataOrNull.size()][dataOrNull.getExample(0).getAttributes().size()];
            for (int i6 = 0; i6 < dataOrNull.size(); i6++) {
                Example example2 = dataOrNull.getExample(i6);
                int i7 = 0;
                Iterator it = example2.getAttributes().iterator();
                while (it.hasNext()) {
                    dArr[i6][i7] = example2.getValue((Attribute) it.next());
                    i7++;
                }
            }
        }
        SamplersDMRLDA samplersDMRLDA = new SamplersDMRLDA();
        samplersDMRLDA.features = dArr;
        samplersDMRLDA.sigma = getParameterAsDouble(PARAMETER_SIGMA);
        samplersDMRLDA.lambda = getParameterAsDouble(PARAMETER_LAMBDA);
        samplersDMRLDA.numFeatures = dArr[0].length;
        samplersDMRLDA.init(tIntArrayList2.toArray(), tIntArrayList.toArray(), this.numTopics, size, data.size(), this.iters, this.beta, this.alpha, parameterAsBoolean, parameterAsInt);
        samplersDMRLDA.GibbsSampling();
        double[][] documentDistribution = samplersDMRLDA.documentDistribution();
        boolean parameterAsBoolean2 = getParameterAsBoolean(PARAMETER_DFR);
        if (parameterAsBoolean2) {
            writeDT(documentDistribution);
        }
        ArrayList arrayList = new ArrayList();
        arrayList.add(AttributeFactory.createAttribute("Doc", 2));
        arrayList.add(AttributeFactory.createAttribute("Topic", 2));
        for (int i8 = 0; i8 < this.numTopics; i8++) {
            arrayList.add(AttributeFactory.createAttribute("Topic_" + i8, 2));
        }
        MemoryExampleTable memoryExampleTable = new MemoryExampleTable(arrayList);
        DataRowFactory dataRowFactory = new DataRowFactory(0, '.');
        for (int i9 = 0; i9 < data.size(); i9++) {
            DataRow create = dataRowFactory.create(memoryExampleTable.getNumberOfAttributes());
            memoryExampleTable.addDataRow(create);
            create.set((Attribute) arrayList.get(0), i9 + 1);
            int i10 = -1;
            double d = 0.0d;
            for (int i11 = 0; i11 < this.numTopics; i11++) {
                if (documentDistribution[i11][i9] > d) {
                    d = documentDistribution[i11][i9];
                    i10 = i11;
                }
                create.set((Attribute) arrayList.get(2 + i11), documentDistribution[i11][i9]);
            }
            create.set((Attribute) arrayList.get(1), i10);
        }
        this.outputDocs.deliver(new SimpleExampleSet(memoryExampleTable));
        double[][] wordDistribution = samplersDMRLDA.wordDistribution();
        if (parameterAsBoolean2) {
            writeTW(strArr, wordDistribution, this.alpha);
            writeTopWords(strArr, wordDistribution);
        }
        ArrayList arrayList2 = new ArrayList();
        arrayList2.add(AttributeFactory.createAttribute("Word", 5));
        arrayList2.add(AttributeFactory.createAttribute("Word_id", 2));
        arrayList2.add(AttributeFactory.createAttribute("Topic", 2));
        for (int i12 = 0; i12 < this.numTopics; i12++) {
            arrayList2.add(AttributeFactory.createAttribute("Topic_" + i12, 2));
        }
        MemoryExampleTable memoryExampleTable2 = new MemoryExampleTable(arrayList2);
        DataRowFactory dataRowFactory2 = new DataRowFactory(0, '.');
        for (int i13 = 0; i13 < size; i13++) {
            DataRow create2 = dataRowFactory2.create(memoryExampleTable2.getNumberOfAttributes());
            memoryExampleTable2.addDataRow(create2);
            create2.set((Attribute) arrayList2.get(0), ((Attribute) arrayList2.get(0)).getMapping().mapString(strArr[i13]));
            create2.set((Attribute) arrayList2.get(1), i13 + 1);
            int i14 = -1;
            double d2 = 0.0d;
            for (int i15 = 0; i15 < this.numTopics; i15++) {
                if (wordDistribution[i15][i13] > d2) {
                    d2 = wordDistribution[i15][i13];
                    i14 = i15;
                }
                create2.set((Attribute) arrayList2.get(3 + i15), wordDistribution[i15][i13]);
            }
            create2.set((Attribute) arrayList2.get(2), i14);
        }
        this.outputWords.deliver(new SimpleExampleSet(memoryExampleTable2));
        String parameter = getParameter(PARAMETER_GROUP);
        ExampleSet dataOrNull2 = this.inputGroup.getDataOrNull(ExampleSet.class);
        if (parameter == "" || dataOrNull2 == null || dataOrNull2.getAttributes().get(parameter) == null) {
            return;
        }
        Attribute attribute2 = dataOrNull2.getAttributes().get(parameter);
        String[] strArr2 = new String[data.size()];
        int[] tokenToTopic = samplersDMRLDA.getTokenToTopic();
        TObjectIntHashMap tObjectIntHashMap = new TObjectIntHashMap();
        TIntObjectHashMap tIntObjectHashMap = new TIntObjectHashMap();
        int i16 = 1;
        for (int i17 = 0; i17 < data.size(); i17++) {
            String valueAsString = dataOrNull2.getExample(i17).getValueAsString(attribute2);
            if (tObjectIntHashMap.contains(valueAsString)) {
                int[] iArr = (int[]) tIntObjectHashMap.get(tObjectIntHashMap.get(valueAsString));
                iArr[tokenToTopic[i17]] = iArr[tokenToTopic[i17]] + 1;
            } else {
                int[] iArr2 = new int[this.numTopics];
                iArr2[tokenToTopic[i17]] = 1;
                tIntObjectHashMap.put(i16, iArr2);
                tObjectIntHashMap.put(valueAsString, i16);
                i16++;
            }
        }
        ArrayList arrayList3 = new ArrayList();
        arrayList3.add(AttributeFactory.createAttribute("Group", 5));
        for (int i18 = 0; i18 < this.numTopics; i18++) {
            arrayList3.add(AttributeFactory.createAttribute("Topic_" + (i18 + 1), 3));
        }
        MemoryExampleTable memoryExampleTable3 = new MemoryExampleTable(arrayList3);
        DataRowFactory dataRowFactory3 = new DataRowFactory(0, '.');
        for (int i19 = 0; i19 < tObjectIntHashMap.keys().length; i19++) {
            int[] iArr3 = (int[]) tIntObjectHashMap.get(tObjectIntHashMap.get((String) tObjectIntHashMap.keys()[i19]));
            DataRow create3 = dataRowFactory3.create(memoryExampleTable3.getNumberOfAttributes());
            memoryExampleTable3.addDataRow(create3);
            create3.set((Attribute) arrayList3.get(0), ((Attribute) arrayList3.get(0)).getMapping().mapString(r0));
            double d3 = 0.0d;
            for (int i20 : iArr3) {
                d3 += i20;
            }
            for (int i21 = 0; i21 < iArr3.length; i21++) {
                create3.set((Attribute) arrayList3.get(i21 + 1), iArr3[i21] / d3);
            }
        }
        this.outputGroups.deliver(new SimpleExampleSet(memoryExampleTable3));
    }

    public void writeDT(int[] iArr, double[] dArr) {
        String str = "\"i\":[" + iArr[0];
        String str2 = "\"p\":[0";
        String str3 = "\"x\":[" + dArr[0];
        for (int i = 1; i < iArr.length; i++) {
            str = str + "," + iArr[i];
            str2 = str2 + "," + i;
            str3 = str3 + "," + ((int) (dArr[i] * 1000.0d));
        }
        System.out.println(VectorFormat.DEFAULT_PREFIX + str + "]," + str2 + "]," + str3 + "]}");
    }

    public void writeDT(double[][] dArr) {
        double[][] dArr2 = new double[dArr[0].length][dArr.length];
        for (int i = 0; i < dArr2.length; i++) {
            for (int i2 = 0; i2 < dArr2[i].length; i2++) {
                dArr2[i][i2] = (int) (dArr[i2][i] * 100.0d);
            }
        }
        CCSMatrix from2DArray = CCSMatrix.from2DArray(dArr2);
        int[] iArr = from2DArray.columnPointers;
        int[] iArr2 = from2DArray.rowIndices;
        double[] dArr3 = from2DArray.values;
        String str = this.path + "dt.json";
        String str2 = "\"i\": [" + iArr2[0] + " ";
        for (int i3 = 1; i3 < iArr2.length; i3++) {
            str2 = str2 + ", " + iArr2[i3];
        }
        String str3 = "\"p\": [" + iArr[0] + " ";
        for (int i4 = 1; i4 < iArr.length; i4++) {
            str3 = str3 + ", " + iArr[i4];
        }
        String str4 = "\"x\": [" + ((int) dArr3[0]) + " ";
        for (int i5 = 1; i5 < dArr3.length; i5++) {
            str4 = str4 + ", " + ((int) dArr3[i5]);
        }
        writeAndZip(VectorFormat.DEFAULT_PREFIX + str2 + "]," + str3 + "]," + str4 + "]}", str);
    }

    public void writeAndZip(String str, String str2) {
        BufferedWriter bufferedWriter = null;
        try {
            bufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(str2), "UTF-8"));
        } catch (IOException e) {
            e.printStackTrace();
        }
        try {
            bufferedWriter.write(str);
        } catch (IOException e2) {
            e2.printStackTrace();
        }
        try {
            bufferedWriter.close();
        } catch (IOException e3) {
            e3.printStackTrace();
        }
        try {
            ZipOutputStream zipOutputStream = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(str2 + ".zip")));
            byte[] bArr = new byte[2048];
            BufferedInputStream bufferedInputStream = new BufferedInputStream(new FileInputStream(str2), 2048);
            zipOutputStream.putNextEntry(new ZipEntry(str2.substring(str2.lastIndexOf("/") + 1, str2.length())));
            while (true) {
                int read = bufferedInputStream.read(bArr, 0, 2048);
                if (read == -1) {
                    bufferedInputStream.close();
                    zipOutputStream.close();
                    return;
                }
                zipOutputStream.write(bArr, 0, read);
            }
        } catch (Exception e4) {
            e4.printStackTrace();
        }
    }

    public void writeTopWords(String[] strArr, double[][] dArr) {
        Word[][] wordArr = new Word[dArr.length][dArr[0].length];
        for (int i = 0; i < dArr.length; i++) {
            for (int i2 = 0; i2 < dArr[i].length; i2++) {
                Word word = new Word();
                word.id = i2;
                word.weight = dArr[i][i2];
                word.word = strArr[i2];
                wordArr[i][i2] = word;
            }
        }
        for (int i3 = 0; i3 < dArr.length; i3++) {
            Arrays.sort(wordArr[i3]);
        }
        try {
            System.setOut(new PrintStream((OutputStream) new FileOutputStream(this.path + "topWords.txt", false), true, "UTF-8"));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (UnsupportedEncodingException e2) {
            e2.printStackTrace();
        }
        for (int i4 = 0; i4 < this.topK && i4 < strArr.length; i4++) {
            String str = "";
            for (int i5 = 0; i5 < dArr.length; i5++) {
                str = str + i5 + "," + wordArr[i5][i4].toString() + ",";
            }
            System.out.println(str);
        }
        System.setOut(System.out);
    }

    public void writeTW(String[] strArr, double[][] dArr, double d) {
        String str = "\"tw\":[";
        String str2 = "{\"alpha\":[" + d;
        for (int i = 1; i < dArr.length; i++) {
            str2 = str2 + "," + d;
        }
        for (int i2 = 0; i2 < dArr.length; i2++) {
            double[] dArr2 = new double[dArr[i2].length];
            for (int i3 = 0; i3 < dArr[i2].length; i3++) {
                dArr2[i3] = dArr[i2][i3];
            }
            Arrays.sort(dArr2);
            if (this.topK >= dArr2.length) {
            }
            double d2 = dArr2[dArr2.length - this.topK];
            String str3 = "\"words\":[";
            String str4 = "{\"weights\":[";
            for (int i4 = 0; i4 < dArr[i2].length; i4++) {
                if (dArr[i2][i4] >= d2) {
                    str3 = str3 + "\"" + strArr[i4] + "\",";
                    str4 = str4 + dArr[i2][i4] + ",";
                }
            }
            str = str + (str4.substring(0, str4.length() - 1) + "],") + (str3.substring(0, str3.length() - 1) + "]}") + ",";
        }
        writeAndZip((str2 + "],") + str.substring(0, str.length() - 1) + "]}", this.path + "tw.json");
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMITERATIONS, "Number of Iterations for Gibbs Sampling.", 1, Integer.MAX_VALUE, 2000));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMTOPICS, "Number of Topics.", 1, Integer.MAX_VALUE, 5));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_ALPHA, "Alpha metaparameter for Dirichlet", 0.0d, Double.MAX_VALUE, 0.25d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_BETA, "Beta metaparameter for Dirichlet", 0.0d, Double.MAX_VALUE, 0.1d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_LAMBDA, "lambda: weight for l1 regularization for DMR", 0.0d, Double.MAX_VALUE, 0.1d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_SIGMA, "sigma: variance of weights for document features in DMR", 0.0d, Double.MAX_VALUE, 0.1d));
        parameterTypes.add(new ParameterTypeString(PARAMETER_GROUP, "Attribute name for grouping the word counts."));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_DFR, "Write results out for dfr browser", true, false));
        parameterTypes.add(new ParameterTypeDirectory(PARAMETER_PATH, "Path for dfr files.", "/home/poelitz/work/Datasets/ResultData/acl2015/"));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    public static void main(String[] strArr) {
    }
}
