package com.rapidminer.operator.mfs;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.similarity.DistanceMeasure;
import com.rapidminer.tools.math.similarity.numerical.EuclideanDistance;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:com/rapidminer/operator/mfs/SAMModel.class */
/**
 * Two-class nearest-centroid prediction model (reported by {@link #toString()} as
 * "SAM model for two classes").
 *
 * <p>Training computes, per attribute, the mean value over the positive-class examples
 * ({@link #M1}) and over the negative-class examples ({@link #M2}). Prediction assigns each
 * example to the class whose centroid has the smaller squared Euclidean distance.
 */
public class SAMModel extends PredictionModel {
    private static final long serialVersionUID = -424389057824068747L;

    /** Never assigned within this class. */
    protected Example C1;

    /** Never assigned within this class. */
    protected Example C2;

    /**
     * Initialised to a {@link EuclideanDistance} by the constructor; not consulted by
     * {@link #performPrediction}, which computes squared distances inline.
     */
    protected DistanceMeasure distFunc;

    /** Per-attribute mean over the training examples carrying the positive label. */
    protected Map<Attribute, Double> M1;

    /** Per-attribute mean over the training examples carrying the negative label. */
    protected Map<Attribute, Double> M2;

    /** Name of the positive class, taken from the label's mapping. */
    protected String posLabelString;

    /** Name of the negative class, taken from the label's mapping. */
    protected String negLabelString;

    /** Index of the positive class, taken from the label's mapping. */
    protected double posLabelId;

    /** Index of the negative class, taken from the label's mapping. */
    protected double negLabelId;

    /**
     * Trains the model by computing both class centroids from {@code exampleSet}.
     *
     * <p>Examples whose label equals neither the positive nor the negative index are ignored;
     * a mismatch between the class sizes and the example count is only logged. If a class has
     * no examples, a message is logged and that class's centroid entries become {@code NaN}
     * (0.0 divided by 0).
     *
     * @param exampleSet the training data; its label attribute supplies the positive/negative
     *                   indices and names
     */
    public SAMModel(ExampleSet exampleSet) {
        super(exampleSet);
        this.distFunc = new EuclideanDistance();

        Attributes attributes = exampleSet.getAttributes();
        Attribute label = attributes.getLabel();
        this.posLabelId = label.getMapping().getPositiveIndex();
        this.negLabelId = label.getMapping().getNegativeIndex();
        this.posLabelString = label.getMapping().getPositiveString();
        this.negLabelString = label.getMapping().getNegativeString();

        // Every attribute returned by allAttributes() starts at 0.0 in both maps; only the
        // attributes yielded by iterating over `attributes` are accumulated and averaged below.
        // NOTE(review): allAttributes() presumably also yields special attributes such as the
        // label, which would then stay at 0.0 - confirm against the Attributes API.
        int attributeCount = attributes.size();
        this.M1 = new HashMap<Attribute, Double>(attributeCount);
        this.M2 = new HashMap<Attribute, Double>(attributeCount);
        Iterator<Attribute> everyAttribute = attributes.allAttributes();
        while (everyAttribute.hasNext()) {
            Attribute current = everyAttribute.next();
            this.M1.put(current, 0.0d);
            this.M2.put(current, 0.0d);
        }

        // Single pass: count the members of each class and accumulate their attribute sums.
        int positiveCount = 0;
        int negativeCount = 0;
        for (Example example : exampleSet) {
            double labelValue = example.getLabel();
            Map<Attribute, Double> sums;
            if (labelValue == this.posLabelId) {
                positiveCount++;
                sums = this.M1;
            } else if (labelValue == this.negLabelId) {
                negativeCount++;
                sums = this.M2;
            } else {
                // Neither class (e.g. a missing label): it is not counted in either divisor,
                // so it must not contribute to either sum.
                continue;
            }
            for (Attribute feature : attributes) {
                sums.put(feature, sums.get(feature) + example.getValue(feature));
            }
        }

        log("n1=" + positiveCount + ", n2=" + negativeCount);
        if (positiveCount + negativeCount != exampleSet.size()) {
            // "Sum of the class sizes differs from the number of examples."
            log("Summe der Klassengr\u00f6\u00dfen ungleich Anzahl Beispiele.");
        }
        if (positiveCount == 0 || negativeCount == 0) {
            // "A class with 0 elements."
            log("Eine Klasse mit 0 Elementen");
        }

        // Turn the sums into means; a class count of zero yields NaN here.
        for (Attribute feature : attributes) {
            this.M1.put(feature, this.M1.get(feature) / positiveCount);
            this.M2.put(feature, this.M2.get(feature) / negativeCount);
        }
    }

    /**
     * Labels every example with the class of the nearer centroid and sets both confidences.
     *
     * <p>The confidence of a class is the squared distance to the <em>other</em> centroid
     * divided by the sum of both squared distances, so the nearer (predicted) class receives
     * the higher confidence. When both distances are zero each class gets 0.5. Ties are
     * predicted as the negative class.
     *
     * <p>An attribute of {@code exampleSet} that has no entry in the centroid maps causes a
     * {@link NullPointerException} when its centroid value is unboxed.
     *
     * @param exampleSet the examples to classify; modified in place
     * @param attribute  not used by this implementation
     * @return the same {@code exampleSet} instance that was passed in
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute attribute) {
        Attributes attributes = exampleSet.getAttributes();
        for (Example example : exampleSet) {
            double posDistance = 0.0d;
            double negDistance = 0.0d;
            for (Attribute feature : attributes) {
                double value = example.getNumericalValue(feature);
                double posDelta = value - this.M1.get(feature);
                posDistance += posDelta * posDelta;
                double negDelta = value - this.M2.get(feature);
                negDistance += negDelta * negDelta;
            }

            if (posDistance < negDistance) {
                example.setPredictedLabel(this.posLabelId);
            } else {
                example.setPredictedLabel(this.negLabelId);
            }

            double total = posDistance + negDistance;
            double posConfidence;
            double negConfidence;
            if (total == 0.0d) {
                // The example coincides with both centroids: avoid 0/0 and report no preference.
                posConfidence = 0.5d;
                negConfidence = 0.5d;
            } else {
                // A smaller distance to a centroid must give a larger confidence for its class.
                posConfidence = negDistance / total;
                negConfidence = posDistance / total;
            }
            example.setConfidence(this.posLabelString, posConfidence);
            example.setConfidence(this.negLabelString, negConfidence);
        }
        return exampleSet;
    }

    /**
     * @return always {@code false}
     */
    public boolean isUpdatable() {
        return false;
    }

    /**
     * @return a multi-line description listing the positive and the negative class centroid
     */
    public String toString() {
        String newline = Tools.getLineSeparator();
        StringBuilder text = new StringBuilder();
        text.append("SAM model for two classes").append(newline).append(newline);
        text.append("Centroid of positive class: ").append(newline);
        text.append(this.M1.toString()).append(newline).append(newline);
        text.append("Centroid of negative class: ").append(newline);
        text.append(this.M2.toString());
        return text.toString();
    }
}
