package cc.mallet.classify;

import cc.mallet.classify.evaluate.ConfusionMatrix;
import cc.mallet.pipe.Classification2ConfidencePredictingFeatureVector;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.InstanceList;
import cc.mallet.types.PerLabelInfoGain;
import cc.mallet.util.MalletLogger;
import java.util.Objects;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/classify/ConfidencePredictingClassifierTrainer.class */
/**
 * Trainer that wraps another {@link ClassifierTrainer} and additionally learns a {@link MaxEnt}
 * model that predicts whether the wrapped classifier's output is correct or incorrect.
 *
 * <p>{@link #train(InstanceList)} works as follows:
 * <ol>
 *   <li>Fits the underlying classifier on the training instances.</li>
 *   <li>Classifies the validation set with that classifier.</li>
 *   <li>Pipes every resulting {@link Classification} through {@link #confidencePredictingPipe}.</li>
 *   <li>Trains a {@link MaxEntTrainer} on the piped instances.</li>
 * </ol>
 * Both models are combined into a {@link ConfidencePredictingClassifier}.
 *
 * <p>A validation set is required. It is supplied through the constructor.
 */
public class ConfidencePredictingClassifierTrainer extends ClassifierTrainer<ConfidencePredictingClassifier> implements Boostable {
    private static final Logger logger = MalletLogger.getLogger(ConfidencePredictingClassifierTrainer.class.getName());

    /**
     * Confusion matrix of the most recently trained underlying classifier, measured on that
     * classifier's own training instances.
     *
     * <p>The field is static, so it is shared by every instance of this trainer. It is
     * overwritten by each call to {@link #train(InstanceList)} and is {@code null} until the
     * first call. Kept package-visible for compatibility with existing readers.
     */
    static ConfusionMatrix confusionMatrix = null;

    /** Trainer for the base classifier whose correctness is being predicted. */
    ClassifierTrainer<?> underlyingClassifierTrainer;
    /** Trainer for the correct/incorrect predictor. */
    MaxEntTrainer confidencePredictingClassifierTrainer;
    /** Pipe that turns a {@link Classification} into a training instance for the predictor. */
    Pipe confidencePredictingPipe;
    /** Result of the most recent {@link #train(InstanceList)} call; {@code null} before that. */
    ConfidencePredictingClassifier classifier;

    /**
     * Returns the classifier produced by the most recent call to {@link #train(InstanceList)}.
     *
     * @return the last trained classifier, or {@code null} if {@code train} has not been called
     */
    @Override // cc.mallet.classify.ClassifierTrainer
    public ConfidencePredictingClassifier getClassifier() {
        return this.classifier;
    }

    /**
     * Creates a trainer with an explicit pipe for the confidence-prediction instances.
     *
     * @param classifierTrainer trainer for the underlying classifier
     * @param instanceList      validation set used to generate the confidence-prediction data;
     *                          {@link #train(InstanceList)} fails if this is {@code null}
     * @param pipe              pipe that converts each validation {@link Classification} into an
     *                          instance for the {@link MaxEntTrainer}
     */
    public ConfidencePredictingClassifierTrainer(ClassifierTrainer<?> classifierTrainer, InstanceList instanceList, Pipe pipe) {
        this.confidencePredictingPipe = pipe;
        this.confidencePredictingClassifierTrainer = new MaxEntTrainer();
        this.validationSet = instanceList;
        this.underlyingClassifierTrainer = classifierTrainer;
    }

    /**
     * Creates a trainer that uses {@link Classification2ConfidencePredictingFeatureVector} as the
     * pipe for the confidence-prediction instances.
     *
     * @param classifierTrainer trainer for the underlying classifier
     * @param instanceList      validation set used to generate the confidence-prediction data
     */
    public ConfidencePredictingClassifierTrainer(ClassifierTrainer<?> classifierTrainer, InstanceList instanceList) {
        this(classifierTrainer, instanceList, new Classification2ConfidencePredictingFeatureVector());
    }

    /**
     * Trains the underlying classifier, then trains a MaxEnt model on the underlying
     * classifier's validation-set classifications.
     *
     * <p>As a side effect, this method overwrites the static {@link #confusionMatrix} and the
     * {@link #classifier} field.
     *
     * @param instanceList training instances for the underlying classifier
     * @return the combined {@link ConfidencePredictingClassifier}
     * @throws NullPointerException  if {@code instanceList} is {@code null}
     * @throws IllegalStateException if no validation set was provided
     */
    @Override // cc.mallet.classify.ClassifierTrainer
    public ConfidencePredictingClassifier train(InstanceList instanceList) {
        Objects.requireNonNull(instanceList, "instanceList");
        // Check up front rather than with `assert`. Assertions are off by default, and the old
        // check also ran only after the (potentially costly) underlying training had finished.
        if (this.validationSet == null) {
            throw new IllegalStateException("This ClassifierTrainer requires a validation set.");
        }

        logger.fine("Training underlying classifier");
        Classifier underlying = this.underlyingClassifierTrainer.train(instanceList);
        confusionMatrix = new ConfusionMatrix(new Trial(underlying, instanceList));

        InstanceList confidenceTrainingList = buildConfidenceTrainingList(underlying);

        logger.info("Begin training ConfidencePredictingClassifier . . . ");
        MaxEnt confidencePredictor = this.confidencePredictingClassifierTrainer.train(confidenceTrainingList);
        logger.info("Accuracy at predicting correct/incorrect in training = " + confidencePredictor.getAccuracy(confidenceTrainingList));

        this.classifier = new ConfidencePredictingClassifier(underlying, confidencePredictor);
        return this.classifier;
    }

    /**
     * Classifies the validation set with {@code underlying} and pipes each classification into a
     * new instance list. Each instance keeps the name and source of the original instance.
     */
    private InstanceList buildConfidenceTrainingList(Classifier underlying) {
        Trial validationTrial = new Trial(underlying, this.validationSet);
        InstanceList confidenceTrainingList = new InstanceList(this.confidencePredictingPipe);
        logger.fine("Creating confidence prediction instance list");
        for (int i = 0; i < validationTrial.size(); i++) {
            Classification classification = validationTrial.get(i);
            confidenceTrainingList.add(classification, null, classification.getInstance().getName(), classification.getInstance().getSource());
        }
        return confidenceTrainingList;
    }
}
