package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import si.ijs.straw.CSVWriter;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: classes2.dex */
/**
 * Learns a simple linear regression model. Picks the attribute that results in the lowest
 * squared error. Can only deal with numeric attributes.
 *
 * <p>Instances with a missing class value are ignored during training. For every candidate
 * attribute, instances whose attribute value is missing are predicted with the (weighted) class
 * mean of exactly those instances, and that error is included when comparing attributes.
 *
 * <p>Supported option: {@code -additional-stats} (output additional statistics; only available
 * on unweighted data).
 */
public class SimpleLinearRegression extends AbstractClassifier implements WeightedInstancesHandler {
    /** For serialization. Field names and types below must stay stable for saved models. */
    static final long serialVersionUID = 1679336022895414137L;

    /** The chosen attribute, or {@code null} when the model predicts a constant. */
    private Attribute m_attribute;

    /** Index of the chosen attribute (0 when no attribute was chosen). */
    private int m_attributeIndex;

    /** Prediction used when the chosen attribute's value is missing. */
    private double m_classMeanForMissing;

    /** Degrees of freedom reported with the additional statistics. */
    private int m_df;

    /** Intercept of the fitted line (or the constant prediction when no attribute was chosen). */
    private double m_intercept;

    /** Whether additional regression statistics are computed and printed. */
    protected boolean m_outputAdditionalStats;

    /** Slope of the fitted line. */
    private double m_slope;

    /** Standard error of the slope (additional statistics only). */
    private double m_seSlope = Double.NaN;

    /** Standard error of the intercept (additional statistics only). */
    private double m_seIntercept = Double.NaN;

    /** t-statistic of the slope (additional statistics only). */
    private double m_tstatSlope = Double.NaN;

    /** t-statistic of the intercept (additional statistics only). */
    private double m_tstatIntercept = Double.NaN;

    /** R^2 value (additional statistics only). */
    private double m_rsquared = Double.NaN;

    /** Adjusted R^2 value (additional statistics only). */
    private double m_rsquaredAdj = Double.NaN;

    /** F-statistic (additional statistics only). */
    private double m_fstat = Double.NaN;

    /** If true, the "no useful attribute found" message is not written to stderr. */
    private boolean m_suppressErrorMessage = false;

    /**
     * Command-line entry point; delegates to {@code runClassifier}.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new SimpleLinearRegression(), strArr);
    }

    /**
     * Builds the model: fits a one-attribute least-squares line for every non-class attribute and
     * keeps the attribute with the lowest total squared error. If no attribute has any variation
     * among its known values, the model falls back to predicting a constant.
     *
     * @param instances the training data
     * @throws Exception if the data fails the capabilities test, or if additional statistics are
     *     requested and some instance has a weight other than 1
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        if (this.m_outputAdditionalStats && !hasOnlyUnitWeights(instances)) {
            throw new Exception("Can only compute additional statistics on unweighted data");
        }

        final int numAttributes = instances.numAttributes();
        final int numInstances = instances.numInstances();
        final int classIndex = instances.classIndex();

        // Pass 1: per-attribute weighted sums, split by whether the attribute value is known.
        double[] valueSum = new double[numAttributes];
        double[] knownWeight = new double[numAttributes];
        double[] classSumWhenMissing = new double[numAttributes];
        double[] classSquaredSumWhenMissing = new double[numAttributes];
        double classCount = 0.0d;
        double classSum = 0.0d;
        for (int i = 0; i < numInstances; i++) {
            Instance instance = instances.instance(i);
            if (instance.classIsMissing()) {
                continue;
            }
            for (int j = 0; j < numAttributes; j++) {
                if (instance.isMissing(j)) {
                    classSumWhenMissing[j] += instance.classValue() * instance.weight();
                    classSquaredSumWhenMissing[j] +=
                            instance.classValue() * instance.classValue() * instance.weight();
                } else {
                    valueSum[j] += instance.weight() * instance.value(j);
                    knownWeight[j] += instance.weight();
                }
            }
            classCount += instance.weight();
            classSum += instance.weight() * instance.classValue();
        }

        // Pass 2: attribute means and class means (for known and for missing attribute values).
        double[] valueMean = new double[numAttributes];
        double[] classMeanWhenMissing = new double[numAttributes];
        double[] classMeanWhenKnown = new double[numAttributes];
        for (int j = 0; j < numAttributes; j++) {
            if (j == classIndex) {
                continue;
            }
            if (knownWeight[j] > 0.0d) {
                valueMean[j] = valueSum[j] / knownWeight[j];
                classMeanWhenKnown[j] = (classSum - classSumWhenMissing[j]) / knownWeight[j];
            }
            double missingWeight = classCount - knownWeight[j];
            if (missingWeight > 0.0d) {
                classMeanWhenMissing[j] = classSumWhenMissing[j] / missingWeight;
            }
        }

        // Pass 3: weighted sums of centered products over instances with a known attribute value.
        double[] sumXY = new double[numAttributes];
        double[] sumXX = new double[numAttributes];
        double[] sumYY = new double[numAttributes];
        for (int i = 0; i < numInstances; i++) {
            Instance instance = instances.instance(i);
            if (instance.classIsMissing()) {
                continue;
            }
            for (int j = 0; j < numAttributes; j++) {
                if (instance.isMissing(j) || j == classIndex) {
                    continue;
                }
                double yDiff = instance.classValue() - classMeanWhenKnown[j];
                double weightedYDiff = instance.weight() * yDiff;
                double xDiff = instance.value(j) - valueMean[j];
                double weightedXDiff = instance.weight() * xDiff;
                sumXY[j] += weightedYDiff * xDiff;
                sumXX[j] += weightedXDiff * xDiff;
                sumYY[j] += weightedYDiff * yDiff;
            }
        }

        // Select the attribute with the smallest total squared error.
        this.m_attribute = null;
        int chosen = -1;
        double chosenSlope = Double.NaN;
        double chosenIntercept = Double.NaN;
        double chosenMeanForMissing = Double.NaN;
        double lowestError = Double.MAX_VALUE;
        for (int j = 0; j < numAttributes; j++) {
            // Attributes without variation among their known values cannot define a slope.
            if (j == classIndex || sumXX[j] == 0.0d) {
                continue;
            }
            // Squared error of predicting the class mean for instances missing this attribute.
            double errorWhenMissing =
                    classSquaredSumWhenMissing[j] - (classSumWhenMissing[j] * classMeanWhenMissing[j]);
            double slope = sumXY[j] / sumXX[j];
            double intercept = classMeanWhenKnown[j] - (slope * valueMean[j]);
            // Residual sum of squares of the fitted line plus the missing-value error.
            double totalError = (sumYY[j] - (slope * sumXY[j])) + errorWhenMissing;
            if (totalError < lowestError) {
                chosen = j;
                chosenSlope = slope;
                chosenIntercept = intercept;
                chosenMeanForMissing = classMeanWhenMissing[j];
                lowestError = totalError;
            }
        }

        if (chosen == -1) {
            if (!this.m_suppressErrorMessage) {
                System.err.println("----- no useful attribute found");
            }
            this.m_attribute = null;
            this.m_attributeIndex = 0;
            this.m_slope = 0.0d;
            // Weighted class mean; evaluates to NaN when the total class weight is zero.
            this.m_intercept = classSum / classCount;
            this.m_classMeanForMissing = 0.0d;
            return;
        }

        this.m_attribute = instances.attribute(chosen);
        this.m_attributeIndex = chosen;
        this.m_slope = chosenSlope;
        this.m_intercept = chosenIntercept;
        this.m_classMeanForMissing = chosenMeanForMissing;
        if (this.m_outputAdditionalStats) {
            computeAdditionalStats(instances);
        }
    }

    /** Returns true if every instance in the given data has a weight of exactly 1. */
    private static boolean hasOnlyUnitWeights(Instances instances) {
        for (int i = 0; i < instances.numInstances(); i++) {
            if (instances.instance(i).weight() != 1.0d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes standard errors, t-statistics, R^2, adjusted R^2 and the F-statistic for the
     * already-fitted model, using only instances with a known class and chosen-attribute value.
     */
    private void computeAdditionalStats(Instances instances) throws Exception {
        Instances complete = new Instances(instances, instances.numInstances());
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            if (!instance.classIsMissing() && !instance.isMissing(this.m_attributeIndex)) {
                complete.add(instance);
            }
        }

        // The 2 here and below presumably counts the fitted coefficients (slope, intercept).
        this.m_df = complete.numInstances() - 2;
        double[] stdErrors = RegressionAnalysis.calculateStdErrorOfCoef(
                complete, this.m_attribute, this.m_slope, this.m_intercept, this.m_df);
        this.m_seSlope = stdErrors[0];
        this.m_seIntercept = stdErrors[1];

        double[] coefficients = new double[]{this.m_slope, this.m_intercept};
        double[] tStats = RegressionAnalysis.calculateTStats(coefficients, stdErrors, 2);
        this.m_tstatSlope = tStats[0];
        this.m_tstatIntercept = tStats[1];

        double ssr = RegressionAnalysis.calculateSSR(
                complete, this.m_attribute, this.m_slope, this.m_intercept);
        this.m_rsquared = RegressionAnalysis.calculateRSquared(complete, ssr);
        this.m_rsquaredAdj =
                RegressionAnalysis.calculateAdjRSquared(this.m_rsquared, complete.numInstances(), 2);
        this.m_fstat = RegressionAnalysis.calculateFStat(this.m_rsquared, complete.numInstances(), 2);
    }

    /**
     * Predicts the class value for the given instance.
     *
     * @param instance the instance to predict
     * @return the intercept if no attribute was chosen; the stored class mean if the chosen
     *     attribute's value is missing; otherwise {@code intercept + slope * value}
     * @throws Exception declared by the overridden method; not thrown by this implementation
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        if (this.m_attribute == null) {
            return this.m_intercept;
        }
        if (instance.isMissing(this.m_attributeIndex)) {
            return this.m_classMeanForMissing;
        }
        return this.m_intercept + (this.m_slope * instance.value(this.m_attributeIndex));
    }

    /**
     * Returns whether the last build selected an attribute.
     *
     * @return true if an attribute was chosen, false if the model predicts a constant
     */
    public boolean foundUsefulAttribute() {
        return this.m_attribute != null;
    }

    /**
     * Returns the index of the chosen attribute.
     *
     * @return the attribute index (0 when no attribute was chosen)
     */
    public int getAttributeIndex() {
        return this.m_attributeIndex;
    }

    /**
     * Returns the capabilities of this classifier: numeric and date attributes, numeric and date
     * class, and missing values in both attributes and class.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();

        // attributes
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);

        // class
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return capabilities;
    }

    /**
     * Returns the intercept of the model.
     *
     * @return the intercept
     */
    public double getIntercept() {
        return this.m_intercept;
    }

    /**
     * Returns the current option settings: {@code -additional-stats} if enabled, followed by the
     * options of the superclass.
     *
     * @return the options as an array of strings
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (getOutputAdditionalStats()) {
            options.add("-additional-stats");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[options.size()]);
    }

    /**
     * Returns whether additional statistics are computed and printed.
     *
     * @return true if additional statistics are enabled
     */
    public boolean getOutputAdditionalStats() {
        return this.m_outputAdditionalStats;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11130 $");
    }

    /**
     * Returns the slope of the model.
     *
     * @return the slope
     */
    public double getSlope() {
        return this.m_slope;
    }

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Learns a simple linear regression model. Picks the attribute that results in the lowest squared error. Can only deal with numeric attributes.";
    }

    /**
     * Lists the available options: {@code -additional-stats} followed by the superclass options.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tOutput additional statistics.", "additional-stats", 0, "-additional-stats"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Returns the tip text for the additional-statistics property.
     *
     * @return the tip text
     */
    public String outputAdditionalStatsTipText() {
        return "Output additional statistics (such as std deviation of coefficients and t-statistics)";
    }

    /**
     * Parses the given options: reads the {@code -additional-stats} flag, passes the rest to the
     * superclass, then checks for leftover options.
     *
     * @param strArr the options to parse
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        setOutputAdditionalStats(Utils.getFlag("additional-stats", strArr));
        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Sets whether additional statistics are computed and printed.
     *
     * @param z true to enable additional statistics
     */
    public void setOutputAdditionalStats(boolean z) {
        this.m_outputAdditionalStats = z;
    }

    /**
     * Sets whether the "no useful attribute found" message on stderr is suppressed.
     *
     * @param z true to suppress the message
     */
    public void setSuppressErrorMessage(boolean z) {
        this.m_suppressErrorMessage = z;
    }

    /**
     * Returns a textual description of the model, including the regression analysis table when
     * additional statistics are enabled.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (this.m_attribute == null) {
            text.append("Predicting constant ").append(this.m_intercept);
        } else {
            String attributeName = this.m_attribute.name();
            text.append("Linear regression on ").append(attributeName).append("\n\n");
            text.append(Utils.doubleToString(this.m_slope, 2)).append(" * ").append(attributeName);
            // A zero intercept is rendered as "+ 0" rather than as a negated value.
            if (this.m_intercept >= 0.0d) {
                text.append(" + ").append(Utils.doubleToString(this.m_intercept, 2));
            } else {
                text.append(" - ").append(Utils.doubleToString(-this.m_intercept, 2));
            }
            text.append("\n\nPredicting ")
                    .append(Utils.doubleToString(this.m_classMeanForMissing, 2))
                    .append(" if attribute value is missing.");
            if (this.m_outputAdditionalStats) {
                appendAdditionalStats(text, attributeName);
            }
        }
        text.append(CSVWriter.DEFAULT_LINE_END);
        return text.toString();
    }

    /** Appends the regression analysis table and summary statistics to the given buffer. */
    private void appendAdditionalStats(StringBuilder text, String attributeName) {
        // First column is wide enough for the attribute name, with a minimum width of 11.
        int columnWidth = Math.max(attributeName.length() + 3, 11);

        text.append("\n\nRegression Analysis:\n\n")
                .append(Utils.padRight("Variable", columnWidth))
                .append("  Coefficient     SE of Coef        t-Stat");

        text.append(CSVWriter.DEFAULT_LINE_END).append(Utils.padRight(attributeName, columnWidth));
        text.append(Utils.doubleToString(this.m_slope, 12, 4));
        text.append("   ").append(Utils.doubleToString(this.m_seSlope, 12, 5));
        text.append("   ").append(Utils.doubleToString(this.m_tstatSlope, 12, 5));

        // Width is one larger because the padded string includes the leading newline.
        text.append(Utils.padRight("\nconst", columnWidth + 1));
        text.append(Utils.doubleToString(this.m_intercept, 12, 4));
        text.append("   ").append(Utils.doubleToString(this.m_seIntercept, 12, 5));
        text.append("   ").append(Utils.doubleToString(this.m_tstatIntercept, 12, 5));

        text.append("\n\nDegrees of freedom = ").append(Integer.toString(this.m_df));
        text.append("\nR^2 value = ").append(Utils.doubleToString(this.m_rsquared, 5));
        text.append("\nAdjusted R^2 = ").append(Utils.doubleToString(this.m_rsquaredAdj, 5));
        text.append("\nF-statistic = ").append(Utils.doubleToString(this.m_fstat, 5));
    }
}
