package com.rapidminer.operator.mfs;

import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.MappedExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.GenerateNewMDRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/mfs/EnsembleFeatureSelection.class */
/**
 * Ensemble feature selection operator.
 *
 * <p>The inner weighting subprocess is executed once per ensemble member, each time on a different
 * resampled view of the input example set: either cross-validation style "all subsets but one"
 * ({@link #SUBSET}) or bootstrap samples ({@link #BOOTSTRAP}). The per-round attribute weights are
 * then aggregated into one {@link AttributeWeights} result according to {@link #PARAMETER_METHOD}:
 * <ul>
 *   <li>{@code top-k}: count how often each attribute is among the k best of a round; afterwards
 *       either keep the counts of attributes selected at least {@code min_rounds} times (others 0),
 *       or, when {@code min_rounds} is 0, set the k most frequently selected attributes to 1 and
 *       all others to 0.</li>
 *   <li>{@code geq_w}: count how often each attribute has weight &gt;= w (|weight| &gt;= w when
 *       absolute weights are used).</li>
 *   <li>{@code accumulate_weights}: sum the raw weights of all rounds.</li>
 * </ul>
 * Additionally, a robustness value computed by {@code Util.averageJaccard} over the per-round
 * weights is delivered on the {@code robustness} port and exposed as a loggable value.
 */
public class EnsembleFeatureSelection extends AbstractWeightingChain {
    /** Output port delivering a PerformanceVector with the "Robustness" criterion. */
    protected final OutputPort robustnessOutput;
    public static final String PARAMETER_METHOD = "method";
    /** Index of "top-k" in {@link #METHODS}. */
    public static final int METHOD_TOPK = 0;
    /** Index of "geq_w" in {@link #METHODS}. */
    public static final int METHOD_WEIGHTS = 1;
    /** Index of "accumulate_weights" in {@link #METHODS}. */
    public static final int METHOD_ACCUMULATE_WEIGHTS = 2;
    public static final String PARAMETER_K = "k";
    public static final String PARAMETER_MIN_ROUNDS = "min_rounds";
    public static final String PARAMETER_W = "w";
    public static final String PARAMETER_BOOTSTRAP_OR_SUBSETS = "subsets_or_bootstrap";
    /** Index of "subsets" in {@link #BOOTSTRAP_OR_SUBSETS}. */
    public static final int SUBSET = 0;
    /** Index of "bootstrap" in {@link #BOOTSTRAP_OR_SUBSETS}. */
    public static final int BOOTSTRAP = 1;
    public static final String PARAMETER_BOOTSTRAP_RATIO = "ratio";
    public static final String PARAMETER_ENSEMBLE_SIZE = "ensemble_size";
    public static final String PARAMETER_NORMALIZE = "normalize_weights";
    public static final String PARAMETER_ABSOLUTE_WEIGHTS = "use_absolute_weights";
    public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
    /** Robustness of the last run; read by the "robustness" operator value. */
    private double robustness;
    public static final String[] METHODS = {"top-k", "geq_w", "accumulate_weights"};
    public static final String[] BOOTSTRAP_OR_SUBSETS = {"subsets", "bootstrap"};

    /* loaded from: input_file:com/rapidminer/operator/mfs/EnsembleFeatureSelection$Method.class */
    /** Not referenced within this class; kept for API compatibility. */
    public enum Method {
        Subset,
        Average
    }

    public EnsembleFeatureSelection(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.robustnessOutput = getOutputPorts().createPort("robustness");
        this.robustness = 0.0d;
        // Expose the most recently computed robustness as a loggable operator value.
        addValue(new ValueDouble("robustness", "Robustness of the inner Feature Selection operator") {
            @Override
            public double getDoubleValue() {
                return EnsembleFeatureSelection.this.robustness;
            }
        });
        // Meta data: the robustness port always produces a PerformanceVector.
        getTransformer().addRule(new GenerateNewMDRule(this.robustnessOutput, PerformanceVector.class));
    }

    /**
     * Runs the inner weighting subprocess once per ensemble member, aggregates the resulting
     * weights according to the selected method and delivers the aggregated weights, the unchanged
     * input example set and the robustness performance vector.
     *
     * @throws OperatorException if a parameter is undefined or the inner subprocess fails
     */
    @Override // com.rapidminer.operator.mfs.AbstractWeightingChain
    public void doWork() throws OperatorException {
        this.iteration = 0;
        ExampleSet data = this.exampleSetInput.getData();
        boolean useBootstrap = getParameterAsInt(PARAMETER_BOOTSTRAP_OR_SUBSETS) == BOOTSTRAP;
        int method = getParameterAsInt(PARAMETER_METHOD);
        int samplingType = getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        // Leave-one-out only applies to the subset mode: one round per example.
        int rounds = (useBootstrap || !getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT))
                ? getParameterAsInt(PARAMETER_ENSEMBLE_SIZE)
                : data.size();
        getLogger().finer("Starting " + rounds + "-fold ensemble feature selection");
        int minRounds = getParameterAsInt(PARAMETER_MIN_ROUNDS);
        if (minRounds > rounds) {
            minRounds = rounds;
            getLogger().finer("Setting the minimum number of top-ranked rounds to number of rounds.");
        }

        AttributeWeights[] roundWeights = useBootstrap
                ? collectBootstrapWeights(data, rounds)
                : collectSubsetWeights(data, rounds, samplingType);

        // The aggregated result is keyed by the input's attribute names; all entries start at 0
        // (overwriting whatever the constructor initialised them to).
        AttributeWeights result = new AttributeWeights(data);
        String[] names = (String[]) result.getAttributeNames().toArray(new String[result.getAttributeNames().size()]);
        int k = getParameterAsInt(PARAMETER_K);
        double threshold = getParameterAsDouble(PARAMETER_W);
        for (String name : names) {
            result.setWeight(name, 0.0d);
        }
        for (int round = 0; round < rounds; round++) {
            roundWeights[round] = alignToAttributes(data, names, roundWeights[round]);
        }

        boolean useAbsolute = getParameterAsBoolean(PARAMETER_ABSOLUTE_WEIGHTS);
        // Comparator type passed to sortByWeight: 1 when absolute weights are used, otherwise 0.
        int comparatorType = useAbsolute ? 1 : 0;
        for (int round = 0; round < rounds; round++) {
            AttributeWeights current = roundWeights[round];
            if (method == METHOD_TOPK) {
                // Binarise this round (1 = among the k best) and count the selected attributes.
                markTopK(current, names, k, comparatorType);
                int selected = Math.min(k, names.length);
                for (int i = 0; i < selected; i++) {
                    result.setWeight(names[i], result.getWeight(names[i]) + 1.0d);
                }
            } else if (method == METHOD_WEIGHTS) {
                // Binarise this round (1 = weight >= w) and count the selected attributes.
                for (String name : names) {
                    double weight = current.getWeight(name);
                    double score = useAbsolute ? Math.abs(weight) : weight;
                    if (score < threshold) {
                        current.setWeight(name, 0.0d);
                    } else {
                        current.setWeight(name, 1.0d);
                        result.setWeight(name, result.getWeight(name) + 1.0d);
                    }
                }
            } else if (method == METHOD_ACCUMULATE_WEIGHTS) {
                for (String name : names) {
                    result.setWeight(name, result.getWeight(name) + current.getWeight(name));
                }
            }
        }

        if (method == METHOD_TOPK) {
            if (minRounds > 0) {
                // Keep the selection counts of attributes that were top-ranked often enough.
                for (String name : names) {
                    if (result.getWeight(name) < minRounds) {
                        result.setWeight(name, 0.0d);
                    }
                }
            } else {
                // No round constraint: pick the k most frequently selected attributes.
                markTopK(result, names, k, comparatorType);
            }
        }
        if (getParameterAsBoolean(PARAMETER_NORMALIZE)) {
            result.normalize();
        }

        this.robustness = Util.averageJaccard(names, roundWeights, rounds);
        getLogger().finest("Robustness of the ensemble feature selection: " + this.robustness);
        PerformanceVector performanceVector = new PerformanceVector();
        performanceVector.addCriterion(new EstimatedPerformance("Robustness", this.robustness, names.length, false));
        performanceVector.setMainCriterionName("Robustness");
        this.weightsOutput.deliver(result);
        this.exampleSetOutput.deliver(data);
        this.robustnessOutput.deliver(performanceVector);
    }

    /**
     * Runs the inner subprocess {@code rounds} times on bootstrap samples of {@code data} whose size
     * is {@code round(data.size() * ratio)}.
     *
     * @return the weights delivered by the inner subprocess, one entry per round
     */
    private AttributeWeights[] collectBootstrapWeights(ExampleSet data, int rounds) throws OperatorException {
        AttributeWeights[] weights = new AttributeWeights[rounds];
        RandomGenerator randomGenerator = RandomGenerator.getRandomGenerator(this);
        for (this.iteration = 0; this.iteration < rounds; this.iteration++) {
            inApplyLoop();
            int sampleSize = (int) Math.round(data.size() * getParameterAsDouble(PARAMETER_BOOTSTRAP_RATIO));
            this.weightingProcessExampleSetOutput.deliver(new MappedExampleSet(data, MappedExampleSet.createBootstrappingMapping(data, sampleSize, randomGenerator), true));
            getSubprocess(0).execute();
            weights[this.iteration] = (AttributeWeights) this.weightingProcessWeightsInput.getData();
        }
        return weights;
    }

    /**
     * Splits {@code data} into {@code rounds} subsets and runs the inner subprocess once per
     * subset on all subsets but that one.
     *
     * @return the weights delivered by the inner subprocess, one entry per round
     */
    private AttributeWeights[] collectSubsetWeights(ExampleSet data, int rounds, int samplingType) throws OperatorException {
        AttributeWeights[] weights = new AttributeWeights[rounds];
        // NOTE(review): "use_local_random_seed" is not declared in getParameterTypes() of this
        // class; presumably it is declared by AbstractWeightingChain - confirm.
        SplittedExampleSet splittedExampleSet = new SplittedExampleSet(data, rounds, samplingType, getParameterAsBoolean("use_local_random_seed"), getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED));
        for (this.iteration = 0; this.iteration < rounds; this.iteration++) {
            inApplyLoop();
            splittedExampleSet.selectAllSubsetsBut(this.iteration);
            this.weightingProcessExampleSetOutput.deliver(splittedExampleSet);
            getSubprocess(0).execute();
            weights[this.iteration] = (AttributeWeights) this.weightingProcessWeightsInput.getData();
        }
        return weights;
    }

    /**
     * Ensures {@code weights} holds a numeric weight for every name in {@code names}.
     *
     * <p>A name whose weight is NaN (missing from the inner result) would otherwise be treated as
     * selected by the {@code geq_w} comparison and would poison the accumulated sums. If any name
     * is affected, a new weights object over {@code data} is built with NaN replaced by 0;
     * otherwise {@code weights} is returned unchanged.
     */
    private static AttributeWeights alignToAttributes(ExampleSet data, String[] names, AttributeWeights weights) {
        boolean complete = true;
        for (String name : names) {
            if (Double.isNaN(weights.getWeight(name))) {
                complete = false;
                break;
            }
        }
        if (complete) {
            return weights;
        }
        AttributeWeights aligned = new AttributeWeights(data);
        for (String name : names) {
            double weight = weights.getWeight(name);
            aligned.setWeight(name, Double.isNaN(weight) ? 0.0d : weight);
        }
        return aligned;
    }

    /**
     * Reorders {@code names} by decreasing weight (sortByWeight direction -1), then sets the
     * weight of the first {@code k} names to 1 and of all remaining names to 0.
     */
    private static void markTopK(AttributeWeights weights, String[] names, int k, int comparatorType) {
        weights.sortByWeight(names, -1, comparatorType);
        for (int i = 0; i < names.length; i++) {
            weights.setWeight(names[i], i < k ? 1.0d : 0.0d);
        }
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeInt ensembleSize = new ParameterTypeInt(PARAMETER_ENSEMBLE_SIZE, "Ensemble size.", 2, Integer.MAX_VALUE, 10);
        ensembleSize.setExpert(false);
        parameterTypes.add(ensembleSize);
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_METHOD, "Select the top-k attributes, those with a weight of at least w, or accumulate the weights of all rounds", METHODS, METHOD_TOPK));
        ParameterTypeInt k = new ParameterTypeInt(PARAMETER_K, "Parameter k, as in top-k", 0, Integer.MAX_VALUE, 100);
        k.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_METHOD, METHODS, false, new int[]{METHOD_TOPK}));
        parameterTypes.add(k);
        ParameterTypeInt minRounds = new ParameterTypeInt(PARAMETER_MIN_ROUNDS, "The minimum number of rounds an attribute has to be top ranked, to be included in the final subset. 0 = no constraint.", 0, Integer.MAX_VALUE, 0);
        minRounds.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_METHOD, METHODS, false, new int[]{METHOD_TOPK}));
        parameterTypes.add(minRounds);
        // NOTE(review): Double.MIN_VALUE is the smallest *positive* double, so w must be > 0.
        // If zero or negative thresholds should be allowed, use -Double.MAX_VALUE instead.
        ParameterTypeDouble w = new ParameterTypeDouble(PARAMETER_W, "Parameter w, as in weights geq w", Double.MIN_VALUE, Double.MAX_VALUE, 0.5d);
        w.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_METHOD, METHODS, false, new int[]{METHOD_WEIGHTS}));
        parameterTypes.add(w);
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_ABSOLUTE_WEIGHTS, "Use absolute weights", false));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_NORMALIZE, "Normalize weights", true));
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_BOOTSTRAP_OR_SUBSETS, "Use subsets or bootstraps to generate ensemble diversity.", BOOTSTRAP_OR_SUBSETS, SUBSET));
        ParameterTypeDouble ratio = new ParameterTypeDouble(PARAMETER_BOOTSTRAP_RATIO, "Relative size of the bootstrapped example sets.", 1.0E-4d, 1.0d, 1.0d);
        ratio.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_BOOTSTRAP_OR_SUBSETS, BOOTSTRAP_OR_SUBSETS, false, new int[]{BOOTSTRAP}));
        ratio.setExpert(false);
        parameterTypes.add(ratio);
        ParameterTypeBoolean leaveOneOut = new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the ensemble size to the number of examples. If set to true, ensemble_size is ignored", false);
        leaveOneOut.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_BOOTSTRAP_OR_SUBSETS, BOOTSTRAP_OR_SUBSETS, false, new int[]{SUBSET}));
        leaveOneOut.setExpert(false);
        parameterTypes.add(leaveOneOut);
        // Default index 2 is presumably "stratified", following the order given in the description.
        ParameterTypeCategory samplingType = new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2);
        samplingType.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_BOOTSTRAP_OR_SUBSETS, BOOTSTRAP_OR_SUBSETS, false, new int[]{SUBSET}));
        samplingType.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        parameterTypes.add(samplingType);
        ParameterTypeInt localRandomSeed = new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1);
        localRandomSeed.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_BOOTSTRAP_OR_SUBSETS, BOOTSTRAP_OR_SUBSETS, false, new int[]{SUBSET}));
        localRandomSeed.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        // Only shown for the randomised sampling types (indices 1 and 2).
        localRandomSeed.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_SAMPLING_TYPE, SplittedExampleSet.SAMPLING_NAMES, false, new int[]{1, 2}));
        parameterTypes.add(localRandomSeed);
        return parameterTypes;
    }
}
