package cc.mallet.types.tests;

import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.FeatureSequence2FeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.iterator.RandomTokenSequenceIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.InstanceList;
import cc.mallet.types.PagedInstanceList;
import cc.mallet.util.Randoms;
import java.io.File;
import jregex.WildcardPattern;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import weka.gui.GenericObjectEditorHistory;

/* loaded from: input_file:cc/mallet/types/tests/TestPagedInstanceList.class */
/**
 * JUnit test comparing a {@link PagedInstanceList} against an ordinary {@link InstanceList}.
 *
 * <p>Each list is filled with random instances generated from the same fixed seed. A MaxEnt
 * classifier is trained on each list, and the two test-set accuracies must agree within 0.01.
 */
public class TestPagedInstanceList extends TestCase {

    /**
     * Creates a test case that runs the named test method.
     *
     * @param name name of the test method to run
     */
    public TestPagedInstanceList(String name) {
        super(name);
    }

    /**
     * Builds the suite of all test methods in this class.
     *
     * @return a suite containing every test method declared by this class
     */
    public static Test suite() {
        return new TestSuite(TestPagedInstanceList.class);
    }

    /**
     * Builds an alphabet by calling {@code lookupIndex} on "feature0" through
     * "feature{size-1}", in order. Presumably each lookup inserts the missing entry.
     *
     * @param size number of feature names to look up
     * @return the populated alphabet
     */
    private static Alphabet dictOfSize(int size) {
        Alphabet dict = new Alphabet();
        int next = 0;
        while (next < size) {
            dict.lookupIndex("feature" + next);
            next++;
        }
        return dict;
    }

    /**
     * Trains MaxEnt on random data held in an in-memory list and in a paged list.
     * It then asserts that the two test-set accuracies differ by at most 0.01.
     */
    public void testRandomTrained() {
        Pipe[] stages = {
            new TokenSequence2FeatureSequence(),
            new FeatureSequence2FeatureVector(),
            new Target2Label()
        };
        SerialPipes pipe = new SerialPipes(stages);

        double inMemoryAccuracy = testRandomTrainedOn(new InstanceList(pipe));

        // NOTE(review): GenericObjectEditorHistory.MAX_HISTORY_LENGTH (weka) and
        // WildcardPattern.ANY_CHAR (jregex) look like decompiler substitutions for a plain
        // int literal and a directory path. Presumably the path is the current directory.
        // Confirm the original values and replace them with literals to drop the spurious
        // weka/jregex dependencies.
        File pageDirectory = new File(WildcardPattern.ANY_CHAR);
        InstanceList paged = new PagedInstanceList(
                pipe, 700, GenericObjectEditorHistory.MAX_HISTORY_LENGTH, pageDirectory);
        double pagedAccuracy = testRandomTrainedOn(paged);

        assertEquals(inMemoryAccuracy, pagedAccuracy, 0.01d);
    }

    /**
     * Fills {@code training} with random instances and trains a MaxEnt classifier on it.
     * The classifier is then evaluated on a second random set pushed through the same pipe.
     * Progress and accuracies are printed to standard output.
     *
     * @param training list (with a pipe attached) that receives the training instances
     * @return accuracy of the trained classifier on the separately generated test set
     */
    private double testRandomTrainedOn(InstanceList training) {
        MaxEntTrainer trainer = new MaxEntTrainer();
        Alphabet vocabulary = dictOfSize(3);
        String[] labels = {"class0", "class1", "class2"};
        // Seed is fixed per call, so every call generates from the same random stream.
        Randoms rng = new Randoms(1);

        // Training data is drawn first, then testing data, from the same generator.
        training.addThruPipe(randomSequences(rng, vocabulary, labels));
        InstanceList testing = new InstanceList(training.getPipe());
        testing.addThruPipe(randomSequences(rng, vocabulary, labels));

        System.out.println("Training set size = " + training.size());
        System.out.println("Testing set size = " + testing.size());

        MaxEnt classifier = trainer.train(training);
        String modelName = classifier.getClass().getName();

        System.out.println("Accuracy on training set:");
        System.out.println(modelName + ": " + new Trial(classifier, training).getAccuracy());

        System.out.println("Accuracy on testing set:");
        double testAccuracy = new Trial(classifier, testing).getAccuracy();
        System.out.println(modelName + ": " + testAccuracy);
        return testAccuracy;
    }

    /**
     * Creates a random token-sequence iterator.
     *
     * <p>It draws from {@code rng} using a {@link Dirichlet} over {@code vocabulary} with
     * parameter 2.0, plus the fixed numeric arguments 30.0, 0.0, 10.0 and 200.0.
     *
     * @param rng random source shared between the training and testing draws
     * @param vocabulary alphabet the Dirichlet is built over
     * @param labels class label names passed to the iterator
     * @return a new iterator over randomly generated instances
     */
    private static RandomTokenSequenceIterator randomSequences(
            Randoms rng, Alphabet vocabulary, String[] labels) {
        return new RandomTokenSequenceIterator(
                rng, new Dirichlet(vocabulary, 2.0d), 30.0d, 0.0d, 10.0d, 200.0d, labels);
    }

    /**
     * Command-line entry point. It runs the test methods named in {@code args}, or the
     * whole suite when no names are given.
     *
     * @param args optional test method names
     * @throws Throwable propagated from the test runner
     */
    public static void main(String[] args) throws Throwable {
        TestSuite toRun;
        if (args.length == 0) {
            toRun = (TestSuite) suite();
        } else {
            toRun = new TestSuite();
            for (int k = 0; k < args.length; k++) {
                toRun.addTest(new TestPagedInstanceList(args[k]));
            }
        }
        TestRunner.run(toRun);
    }
}
