package com.rapidminer.kobra.opt;

import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.util.LinkedList;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:com/rapidminer/kobra/opt/MyOrthantWiseLimitedMemoryBFGS.class */
/**
 * Orthant-wise limited-memory BFGS ("OWL-BFGS") optimizer for a MALLET
 * {@link Optimizable.ByGradientValue}, with an optional L1 penalty and an optional shrinkage
 * (proximal) step applied after each trial point.
 *
 * <p>The wrapped optimizable is maximized: internally its value and gradient are negated and the
 * penalty {@code l1Weight * sum_i |w_i|} is added, so the code minimizes
 * {@code -f(w) + l1Weight * ||w||_1} with a backtracking (Armijo) line search.
 *
 * <p>NOTE(review): {@link #shift()} is never called, so the correction history ({@code s},
 * {@code y}, {@code rhos}) stays empty and {@link #mapDirByInverseHessian(double)} returns
 * immediately. As written, every iteration therefore moves along the orthant-projected
 * steepest-descent direction rather than a quasi-Newton direction. Confirm whether this is
 * intentional before relying on L-BFGS behaviour.
 */
public class MyOrthantWiseLimitedMemoryBFGS implements Optimizer {
    private static final Logger logger = MalletLogger.getLogger(MyOrthantWiseLimitedMemoryBFGS.class.getName());

    /** Set to true once {@link #optimize(int)} hits one of its termination conditions. */
    boolean converged;
    /** The objective being maximized; its value and gradient are negated internally. */
    Optimizable.ByGradientValue optimizable;
    /** Simple (last dotted segment) class name of {@link #optimizable}; used only for logging. */
    String optName;
    /** Iteration budget; checked against {@link #iterations}, which accumulates across calls. */
    final int maxIterations = 1000;
    /** Relative tolerance of the value-change termination test. */
    final double tolerance = 1.0E-6d;
    /** Threshold on the 2-norm of the (negated) gradient for the gradient termination test. */
    final double gradientTolerance = 1.0E-5d;
    /** Small constant used in the value-change test and to zero near-zero curvature in shift(). */
    final double eps = 1.0E-4d;
    /** Weight of the L1 penalty; 0 disables it. */
    double l1Weight;
    /** Maximum number of (s, y) correction pairs kept in the limited-memory history. */
    final int m = 6;
    /** Penalized value at the start of the most recent line search. */
    double oldValue;
    /** Current penalized value, -f(w) + l1Weight * ||w||_1. */
    double value;
    /** y.y of the newest correction pair, as returned by shift(); used to scale the direction. */
    double yDotY;
    /** Negated gradient of the optimizable at {@link #parameters} (zeroed for infinite params). */
    double[] grad;
    /** Copy of {@link #grad} taken before the most recent line search. */
    double[] oldGrad;
    /** Current search direction. */
    double[] direction;
    /** Orthant-projected steepest-descent direction, used to fix signs of {@link #direction}. */
    double[] steepestDescentDirection;
    /** Current parameter vector (mirrored into the optimizable via setParameters). */
    double[] parameters;
    /** Copy of {@link #parameters} taken before the most recent line search. */
    double[] oldParameters;
    /** History of parameter differences (s_k). */
    LinkedList<double[]> s;
    /** History of gradient differences (y_k). */
    LinkedList<double[]> y;
    /** History of s_k . y_k for each stored pair. */
    LinkedList<Double> rhos;
    /** Scratch space for the two-loop recursion in mapDirByInverseHessian(). */
    double[] alphas;
    /** Number of completed iterations, accumulated over all calls to optimize(). */
    int iterations;
    /** Together with {@link #gl}, enables the shrinkage step in getNextPoint(). */
    public boolean prox;
    /** Together with {@link #prox}, enables the shrinkage step in getNextPoint(). */
    public boolean gl;
    /** Threshold of the shrinkage step in getNextPoint() (default 0.1). */
    public double lambda;

    /**
     * Creates an optimizer without an L1 penalty.
     *
     * @param byGradientValue the objective to maximize
     */
    public MyOrthantWiseLimitedMemoryBFGS(Optimizable.ByGradientValue byGradientValue) {
        this(byGradientValue, 0.0d);
    }

    /**
     * Creates an optimizer and evaluates the objective and gradient at the optimizable's current
     * parameters.
     *
     * @param byGradientValue the objective to maximize
     * @param d weight of the L1 penalty (0 disables it)
     */
    public MyOrthantWiseLimitedMemoryBFGS(Optimizable.ByGradientValue byGradientValue, double d) {
        // The final tuning constants are initialized at their declarations; re-assigning them
        // here (as the previous version did) is a compile error.
        this.converged = false;
        this.prox = false;
        this.gl = false;
        this.lambda = 0.1d;
        this.optimizable = byGradientValue;
        this.l1Weight = d;
        String[] split = this.optimizable.getClass().getName().split("\\.");
        this.optName = split[split.length - 1];
        this.iterations = 0;
        this.s = new LinkedList<>();
        this.y = new LinkedList<>();
        this.rhos = new LinkedList<>();
        this.alphas = new double[this.m]; // zero-initialized by the JVM
        this.yDotY = 0.0d;
        int numParameters = this.optimizable.getNumParameters();
        this.parameters = new double[numParameters];
        this.optimizable.getParameters(this.parameters);
        this.value = evalL1();
        this.grad = new double[numParameters];
        evalGradient();
        this.direction = new double[numParameters];
        this.steepestDescentDirection = new double[numParameters];
        this.oldParameters = new double[numParameters];
        this.oldGrad = new double[numParameters];
    }

    @Override // cc.mallet.optimize.Optimizer
    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    @Override // cc.mallet.optimize.Optimizer
    public boolean isConverged() {
        return this.converged;
    }

    /** Returns the number of iterations completed so far, across all calls to optimize(). */
    public int getIteration() {
        return this.iterations;
    }

    @Override // cc.mallet.optimize.Optimizer
    public boolean optimize() {
        return optimize(Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code numIterations} iterations.
     *
     * @param numIterations maximum number of iterations to run in this call
     * @return true if a termination condition was met (value change below tolerance, gradient
     *     norm below tolerance, or the cumulative iteration cap exceeded); false if the iteration
     *     count for this call ran out or an exception aborted the run
     */
    @Override // cc.mallet.optimize.Optimizer
    public boolean optimize(int numIterations) {
        logger.fine("Entering OWL-BFGS.optimize(). L1 weight=" + this.l1Weight + " Initial Value=" + this.value);
        for (int iter = 0; iter < numIterations; iter++) {
            makeSteepestDescDir();
            mapDirByInverseHessian(this.yDotY);
            fixDirSigns();
            // Back up the starting point; the line search steps relative to oldParameters.
            storeSrcInDest(this.parameters, this.oldParameters);
            storeSrcInDest(this.grad, this.oldGrad);
            try {
                backTrackingLineSearch();
                evalGradient();
                if (checkValueTerminationCondition()) {
                    logger.fine("Exiting OWL-BFGS on termination #1:");
                    logger.fine("value difference below tolerance (oldValue: " + this.oldValue + " newValue: " + this.value);
                    this.converged = true;
                    return true;
                }
                if (checkGradientTerminationCondition()) {
                    logger.fine("Exiting OWL-BFGS on termination #2:");
                    logger.fine("gradient=" + MatrixOps.twoNorm(this.grad) + " < " + this.gradientTolerance);
                    this.converged = true;
                    return true;
                }
                // NOTE(review): the curvature-history update (this.yDotY = shift()) is not
                // performed anywhere, so mapDirByInverseHessian() always sees an empty history.
                // Confirm whether this omission is intentional.
                this.iterations++;
                if (this.iterations > this.maxIterations) {
                    logger.fine("Too many iterations in OWL-BFGS. Continuing with current parameters.");
                    this.converged = true;
                    return true;
                }
            } catch (Exception e) {
                // Boundary catch: report the failure (with stack trace) through the class logger
                // instead of stderr, and signal non-convergence to the caller.
                logger.log(Level.WARNING, "OWL-BFGS aborted in iteration " + this.iterations, e);
                return false;
            }
        }
        return false;
    }

    /**
     * Evaluates the penalized objective -f(w) + l1Weight * sum|w_i| at the optimizable's current
     * parameters; infinite parameters are excluded from the penalty.
     */
    private double evalL1() {
        double d = -this.optimizable.getValue();
        double d2 = 0.0d;
        if (this.l1Weight > 0.0d) {
            for (double d3 : this.parameters) {
                if (!Double.isInfinite(d3)) {
                    d2 += Math.abs(d3) * this.l1Weight;
                }
            }
        }
        // Guarded: this runs on every line-search trial, so avoid building the message needlessly.
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("getValue() (" + this.optName + ".getValue() = " + d + " + |w|=" + d2 + ") = " + (d + d2));
        }
        return d + d2;
    }

    /** Fetches the optimizable's gradient into grad, zeroes entries of infinite params, negates. */
    private void evalGradient() {
        this.optimizable.getValueGradient(this.grad);
        adjustGradForInfiniteParams(this.grad);
        MatrixOps.timesEquals(this.grad, -1.0d);
    }

    /**
     * Sets direction (and its copy steepestDescentDirection) to the negative gradient, or, with an
     * L1 penalty, to the negative of the L1 subgradient chosen per orthant; coordinates at zero
     * whose gradient lies within [-l1Weight, l1Weight] get a zero direction.
     */
    private void makeSteepestDescDir() {
        if (this.l1Weight == 0.0d) {
            for (int i = 0; i < this.grad.length; i++) {
                this.direction[i] = -this.grad[i];
            }
        } else {
            for (int i2 = 0; i2 < this.grad.length; i2++) {
                if (this.parameters[i2] < 0.0d) {
                    this.direction[i2] = (-this.grad[i2]) + this.l1Weight;
                } else if (this.parameters[i2] > 0.0d) {
                    this.direction[i2] = (-this.grad[i2]) - this.l1Weight;
                } else if (this.grad[i2] < (-this.l1Weight)) {
                    this.direction[i2] = (-this.grad[i2]) - this.l1Weight;
                } else if (this.grad[i2] > this.l1Weight) {
                    this.direction[i2] = (-this.grad[i2]) + this.l1Weight;
                } else {
                    this.direction[i2] = 0.0d;
                }
            }
        }
        storeSrcInDest(this.direction, this.steepestDescentDirection);
    }

    /** Zeroes gradient entries whose parameter is infinite. */
    private void adjustGradForInfiniteParams(double[] dArr) {
        for (int i = 0; i < this.parameters.length; i++) {
            if (Double.isInfinite(this.parameters[i])) {
                dArr[i] = 0.0d;
            }
        }
    }

    /**
     * Applies the L-BFGS two-loop recursion to direction using the stored (s, y, rho) history,
     * scaling by rho_newest / d in between. Does nothing when the history is empty.
     *
     * <p>NOTE(review): shift() can store rho == 0, which would make the divisions below produce
     * Infinity/NaN if the history were ever populated.
     *
     * @param d y.y of the newest correction pair
     */
    private void mapDirByInverseHessian(double d) {
        if (this.s.size() == 0) {
            return;
        }
        int size = this.s.size();
        for (int i = size - 1; i >= 0; i--) {
            this.alphas[i] = (-MatrixOps.dotProduct(this.s.get(i), this.direction)) / this.rhos.get(i).doubleValue();
            MatrixOps.plusEquals(this.direction, this.y.get(i), this.alphas[i]);
        }
        double doubleValue = this.rhos.get(size - 1).doubleValue() / d;
        logger.fine("Direction multiplier = " + doubleValue);
        MatrixOps.timesEquals(this.direction, doubleValue);
        for (int i2 = 0; i2 < size; i2++) {
            MatrixOps.plusEquals(this.direction, this.s.get(i2), (-this.alphas[i2]) - (MatrixOps.dotProduct(this.y.get(i2), this.direction) / this.rhos.get(i2).doubleValue()));
        }
    }

    /** With an L1 penalty, zeroes direction entries whose sign disagrees with the steepest-descent direction. */
    private void fixDirSigns() {
        if (this.l1Weight > 0.0d) {
            for (int i = 0; i < this.direction.length; i++) {
                if (this.direction[i] * this.steepestDescentDirection[i] <= 0.0d) {
                    this.direction[i] = 0.0d;
                }
            }
        }
    }

    /**
     * Directional derivative of the penalized objective along direction; for coordinates at zero
     * the L1 term's sign follows the sign of the direction.
     */
    private double dirDeriv() {
        if (this.l1Weight == 0.0d) {
            return MatrixOps.dotProduct(this.direction, this.grad);
        }
        double d = 0.0d;
        for (int i = 0; i < this.direction.length; i++) {
            if (this.direction[i] != 0.0d) {
                if (this.parameters[i] < 0.0d) {
                    d += this.direction[i] * (this.grad[i] - this.l1Weight);
                } else if (this.parameters[i] > 0.0d) {
                    d += this.direction[i] * (this.grad[i] + this.l1Weight);
                } else if (this.direction[i] < 0.0d) {
                    d += this.direction[i] * (this.grad[i] - this.l1Weight);
                } else if (this.direction[i] > 0.0d) {
                    d += this.direction[i] * (this.grad[i] + this.l1Weight);
                }
            }
        }
        return d;
    }

    /**
     * Appends the newest (s, y, rho) correction pair, evicting the oldest once m pairs are held,
     * then snapshots parameters/grad into oldParameters/oldGrad. Currently has no call site.
     *
     * @return y.y of the new pair
     * @throws InvalidOptimizableException if s.y is negative (beyond the eps dead-zone)
     */
    private double shift() {
        double[] removeFirst;
        double[] removeFirst2;
        if (this.s.size() < this.m) {
            removeFirst = new double[this.parameters.length];
            removeFirst2 = new double[this.parameters.length];
        } else {
            // Recycle the evicted arrays to avoid reallocating.
            removeFirst = this.s.removeFirst();
            removeFirst2 = this.y.removeFirst();
            this.rhos.removeFirst();
        }
        double d = 0.0d;
        double d2 = 0.0d;
        for (int i = 0; i < this.parameters.length; i++) {
            // Same-signed infinities would give NaN when subtracted; treat them as "no change".
            if (Double.isInfinite(this.parameters[i]) && Double.isInfinite(this.oldParameters[i]) && this.parameters[i] * this.oldParameters[i] > 0.0d) {
                removeFirst[i] = 0.0d;
            } else {
                removeFirst[i] = this.parameters[i] - this.oldParameters[i];
            }
            if (Double.isInfinite(this.grad[i]) && Double.isInfinite(this.oldGrad[i]) && this.grad[i] * this.oldGrad[i] > 0.0d) {
                removeFirst2[i] = 0.0d;
            } else {
                removeFirst2[i] = this.grad[i] - this.oldGrad[i];
            }
            d += removeFirst[i] * removeFirst2[i];
            d2 += removeFirst2[i] * removeFirst2[i];
        }
        if (d <= this.eps && d >= -this.eps) {
            d = 0.0d;
        }
        logger.fine("rho=" + d);
        if (d < 0.0d) {
            throw new InvalidOptimizableException("rho = " + d + " < 0: Invalid hessian inverse. Gradient change should be opposite of parameter change.");
        }
        this.s.addLast(removeFirst);
        this.y.addLast(removeFirst2);
        this.rhos.addLast(Double.valueOf(d));
        storeSrcInDest(this.parameters, this.oldParameters);
        storeSrcInDest(this.grad, this.oldGrad);
        return d2;
    }

    /** Copies all of src into dest (dest must be at least as long as src). */
    private void storeSrcInDest(double[] src, double[] dest) {
        System.arraycopy(src, 0, dest, 0, src.length);
    }

    /**
     * Backtracking line search from oldParameters along direction until the Armijo
     * sufficient-decrease condition holds, the step underflows to zero, or the objective
     * evaluates to NaN (in which case the starting point and its value are restored).
     *
     * @throws InvalidOptimizableException if direction is not a descent direction
     */
    private void backTrackingLineSearch() {
        // Armijo sufficient-decrease constant.
        final double c1 = 1.0E-4d;
        double dirDeriv = dirDeriv();
        if (dirDeriv >= 0.0d) {
            throw new InvalidOptimizableException("L-BFGS chose a non-ascent direction: check your gradient!");
        }
        double alpha = 1.0d;
        double backoff = 0.5d;
        if (this.iterations == 0) {
            // First iteration: start from a unit-length step and back off more aggressively.
            alpha = 1.0d / Math.sqrt(MatrixOps.dotProduct(this.direction, this.direction));
            backoff = 0.1d;
        }
        this.oldValue = this.value;
        logger.fine("*** Starting line search iter=" + this.iterations);
        logger.fine("iter[" + this.iterations + "] Value at start of line search = " + this.value);
        while (true) {
            getNextPoint(alpha);
            this.value = evalL1();
            if (Double.isNaN(this.value)) {
                // The objective is undefined at the trial point. Return to the line-search start
                // so that value and parameters stay consistent (the previous version moved to
                // oldParameters - alpha*direction, i.e. the opposite way, and kept a mismatched value).
                restoreLineSearchStart();
                return;
            }
            // Guarded: the two norms are O(n) and would otherwise be computed on every trial.
            if (logger.isLoggable(Level.FINE)) {
                logger.fine("iter[" + this.iterations + "] Using alpha = " + alpha + " new value = " + this.value + " |grad|=" + MatrixOps.twoNorm(this.grad) + " |x|=" + MatrixOps.twoNorm(this.parameters));
            }
            if (this.value <= this.oldValue + (c1 * dirDeriv * alpha) || alpha == 0.0d) {
                return;
            }
            alpha *= backoff;
        }
    }

    /**
     * Moves to oldParameters + d * direction, clipping coordinates that would cross zero (when an
     * L1 penalty is active), optionally applying the shrinkage step, and pushes the result into
     * the optimizable.
     */
    private void getNextPoint(double d) {
        for (int i = 0; i < this.parameters.length; i++) {
            this.parameters[i] = this.oldParameters[i] + (this.direction[i] * d);
            if (this.l1Weight > 0.0d && this.oldParameters[i] * this.parameters[i] < 0.0d) {
                this.parameters[i] = 0.0d;
            }
        }
        if (this.prox && this.gl) {
            // Block soft-thresholding: w <- w * max(0, 1 - lambda / ||w||_2).
            // NOTE(review): lambda is not scaled by the step size d; confirm this is intended.
            double d2 = 0.0d;
            for (int i2 = 0; i2 < this.parameters.length; i2++) {
                d2 += this.parameters[i2] * this.parameters[i2];
            }
            double sqrt = Math.sqrt(d2);
            if (sqrt != 0.0d) {
                double d3 = Math.max(0.0d, 1.0d - (this.lambda / sqrt));
                for (int i3 = 0; i3 < this.parameters.length; i3++) {
                    this.parameters[i3] = d3 * this.parameters[i3];
                }
            }
        }
        this.optimizable.setParameters(this.parameters);
    }

    /** Restores the line-search starting point (oldParameters, oldValue) into this and the optimizable. */
    private void restoreLineSearchStart() {
        this.value = this.oldValue;
        storeSrcInDest(this.oldParameters, this.parameters);
        this.optimizable.setParameters(this.parameters);
    }

    /** True when the relative change in value is within tolerance. */
    private boolean checkValueTerminationCondition() {
        return 2.0d * Math.abs(this.value - this.oldValue) <= this.tolerance * ((Math.abs(this.value) + Math.abs(this.oldValue)) + this.eps);
    }

    /** True when the 2-norm of the gradient is below gradientTolerance. */
    private boolean checkGradientTerminationCondition() {
        return MatrixOps.twoNorm(this.grad) < this.gradientTolerance;
    }
}
