package cc.mallet.classify;

import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.io.Serializable;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:cc/mallet/classify/MaxEntTrainer.class */
public class MaxEntTrainer extends ClassifierTrainer<MaxEnt> implements ClassifierTrainer.ByOptimization<MaxEnt>, Boostable, Serializable {
    private static final Logger logger = MalletLogger.getLogger(MaxEntTrainer.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntTrainer.class.getName() + "-pl");

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    static final double DEFAULT_L1_WEIGHT = 0.0d;
    static final Class DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;

    /** Iteration budget used by {@link #train(InstanceList)}; Integer.MAX_VALUE means "run to convergence". */
    int numIterations = Integer.MAX_VALUE;
    /** Prior variance handed to a newly built optimizable when {@link #l1Weight} is 0. */
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    /**
     * When nonzero, newly built optimizables use no prior. NOTE(review): no code in this class
     * passes this weight to the optimizer (plain LimitedMemoryBFGS is used), so confirm where,
     * if anywhere, an L1 penalty is actually applied.
     */
    double l1Weight = DEFAULT_L1_WEIGHT;
    // NOTE(review): not read anywhere in this class; kept for package-level compatibility.
    Class maximizerClass = DEFAULT_MAXIMIZER_CLASS;
    /** Instance list the current optimizable was requested for (null until first use). */
    InstanceList trainingSet = null;
    /** Classifier used to seed the next optimizable that gets built (may be null). */
    MaxEnt initialClassifier;
    /** Cached optimizable; rebuilt when the training list changes or after setClassifier(). */
    MaxEntOptimizableByLabelLikelihood optimizable = null;
    /** Cached optimizer over {@link #optimizable}; null forces a fresh one on next request. */
    Optimizer optimizer = null;

    /** Creates a trainer with default settings and no initial classifier. */
    public MaxEntTrainer() {
    }

    /**
     * Creates a trainer that seeds optimization from an existing classifier.
     *
     * @param initialClassifier classifier whose parameters seed training
     */
    public MaxEntTrainer(MaxEnt initialClassifier) {
        this.initialClassifier = initialClassifier;
    }

    /**
     * Creates a trainer with the given Gaussian prior variance.
     *
     * @param gaussianPriorVariance variance used when building the optimizable
     */
    public MaxEntTrainer(double gaussianPriorVariance) {
        this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /**
     * Returns the classifier held by the current optimizable, or the initial classifier when no
     * optimizable has been built yet (may be null).
     */
    @Override
    public MaxEnt getClassifier() {
        return this.optimizable != null ? this.optimizable.getClassifier() : this.initialClassifier;
    }

    /**
     * Replaces the initial classifier. If it actually changes, the cached optimizable and
     * optimizer are discarded so the next training run starts from the new classifier.
     *
     * @param maxEnt new initial classifier; when a training set is known its alphabets must match
     *               (checked only when assertions are enabled)
     */
    public void setClassifier(MaxEnt maxEnt) {
        assert this.trainingSet == null || Alphabet.alphabetsMatch(maxEnt, this.trainingSet);
        if (this.initialClassifier != maxEnt) {
            this.initialClassifier = maxEnt;
            this.optimizable = null;
            this.optimizer = null;
        }
    }

    /** Returns the cached optimizable without building one (may be null). */
    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    /**
     * Returns an optimizable for {@code instanceList}, seeded from {@link #getClassifier()}.
     *
     * @param instanceList training data
     */
    public MaxEntOptimizableByLabelLikelihood getOptimizable(InstanceList instanceList) {
        return getOptimizable(instanceList, getClassifier());
    }

    /**
     * Returns an optimizable for {@code instanceList}, building a new one when none is cached or
     * the cached one was built over a different list. A newly built optimizable gets the Gaussian
     * prior when {@link #l1Weight} is 0, and no prior otherwise; building one also drops the
     * cached optimizer.
     *
     * @param instanceList training data
     * @param maxEnt       classifier used to seed a newly built optimizable
     * @return the (possibly cached) optimizable
     */
    public MaxEntOptimizableByLabelLikelihood getOptimizable(InstanceList instanceList, MaxEnt maxEnt) {
        // The optimizable == null check matters after setClassifier(): that call clears the
        // optimizable but leaves trainingSet unchanged, so the reference checks alone would
        // return null here.
        if (this.optimizable == null || instanceList != this.trainingSet || this.initialClassifier != maxEnt) {
            this.trainingSet = instanceList;
            this.initialClassifier = maxEnt;
            if (this.optimizable == null || this.optimizable.trainingList != instanceList) {
                this.optimizable = new MaxEntOptimizableByLabelLikelihood(instanceList, maxEnt);
                if (this.l1Weight == 0.0d) {
                    this.optimizable.setGaussianPriorVariance(this.gaussianPriorVariance);
                } else {
                    this.optimizable.useNoPrior();
                }
                this.optimizer = null;
            }
        }
        return this.optimizable;
    }

    /**
     * Returns the cached optimizer. If none is cached but an optimizable exists, lazily creates a
     * ConjugateGradient over it. Returns null when no optimizable has been built.
     * NOTE(review): {@link #getOptimizer(InstanceList)} creates LimitedMemoryBFGS instead.
     */
    public Optimizer getOptimizer() {
        if (this.optimizer == null && this.optimizable != null) {
            this.optimizer = new ConjugateGradient(this.optimizable);
        }
        return this.optimizer;
    }

    /**
     * Returns a LimitedMemoryBFGS optimizer over the optimizable for {@code instanceList}. The
     * optimizable is rebuilt as needed, and a new optimizer is created whenever the list changes,
     * no optimizable exists, or no optimizer is cached.
     *
     * @param instanceList training data
     */
    public Optimizer getOptimizer(InstanceList instanceList) {
        if (instanceList != this.trainingSet || this.optimizable == null) {
            getOptimizable(instanceList);
            this.optimizer = null;
        }
        if (this.optimizer == null) {
            this.optimizer = new LimitedMemoryBFGS(this.optimizable);
        }
        return this.optimizer;
    }

    /**
     * Sets the iteration budget used by {@link #train(InstanceList)}.
     *
     * @param i maximum iterations; Integer.MAX_VALUE means "run to convergence"
     * @return this trainer, for chaining
     */
    public MaxEntTrainer setNumIterations(int i) {
        this.numIterations = i;
        return this;
    }

    /**
     * Returns 0 before any optimizable exists, and Integer.MAX_VALUE afterwards. This is not an
     * actual iteration count.
     */
    @Override
    public int getIteration() {
        return this.optimizable == null ? 0 : Integer.MAX_VALUE;
    }

    /**
     * Sets the Gaussian prior variance. It only affects optimizables built after this call.
     *
     * @param d prior variance
     * @return this trainer, for chaining
     */
    public MaxEntTrainer setGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
        return this;
    }

    /**
     * Sets the L1 weight. It only affects optimizables built after this call; a nonzero value
     * causes new optimizables to use no prior.
     *
     * @param d L1 weight
     * @return this trainer, for chaining
     */
    public MaxEntTrainer setL1Weight(double d) {
        this.l1Weight = d;
        return this;
    }

    /** Trains on {@code instanceList} using the configured {@link #numIterations}. */
    @Override
    public MaxEnt train(InstanceList instanceList) {
        return train(instanceList, this.numIterations);
    }

    /**
     * Trains on {@code instanceList} for at most {@code maxIterations} single-step optimizer calls,
     * stopping early once the optimizer reports convergence or fails.
     *
     * When {@code maxIterations} is Integer.MAX_VALUE, a fresh optimizer is then created and run
     * with an unbounded {@code optimize()}. Presumably this restarts without the earlier
     * optimizer's accumulated state.
     *
     * Optimizer exceptions are logged and treated as convergence rather than propagated.
     *
     * @param instanceList  training data
     * @param maxIterations maximum number of single-step optimization calls
     * @return the classifier held by the optimizable after training
     */
    @Override
    public MaxEnt train(InstanceList instanceList, int maxIterations) {
        logger.fine("trainingSet.size() = " + instanceList.size());
        getOptimizer(instanceList);
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            this.finishedTraining = optimizeTreatingFailureAsConverged(true);
            if (this.finishedTraining) {
                break;
            }
        }
        if (maxIterations == Integer.MAX_VALUE) {
            this.optimizer = null;
            getOptimizer(instanceList);
            this.finishedTraining = optimizeTreatingFailureAsConverged(false);
        }
        progressLogger.info(IOUtils.LINE_SEPARATOR_UNIX);
        return this.optimizable.getClassifier();
    }

    /**
     * Runs the cached optimizer once, mapping optimizer exceptions to "converged".
     *
     * @param singleStep true to call {@code optimize(1)}, false to call {@code optimize()}
     * @return the optimizer's convergence flag, or true if it threw an
     *         InvalidOptimizableException or OptimizationException
     */
    private boolean optimizeTreatingFailureAsConverged(boolean singleStep) {
        try {
            return singleStep ? this.optimizer.optimize(1) : this.optimizer.optimize();
        } catch (InvalidOptimizableException e) {
            e.printStackTrace();
            logger.warning("Catching InvalidOptimizatinException! saying converged.");
            return true;
        } catch (OptimizationException e) {
            e.printStackTrace();
            logger.info("Catching OptimizationException; saying converged.");
            return true;
        }
    }

    /**
     * Describes the trainer. The iteration count is included only when it is bounded, followed by
     * either the L1 weight (if nonzero) or the Gaussian prior variance.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("MaxEntTrainer");
        if (this.numIterations < Integer.MAX_VALUE) {
            sb.append(",numIterations=").append(this.numIterations);
        }
        if (this.l1Weight != 0.0d) {
            sb.append(",l1Weight=").append(this.l1Weight);
        } else {
            sb.append(",gaussianPriorVariance=").append(this.gaussianPriorVariance);
        }
        return sb.toString();
    }
}
