package weka.classifiers.meta;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.Array;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import si.ijs.straw.CSVWriter;
import weka.classifiers.Classifier;
import weka.classifiers.CostMatrix;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: classes2.dex */
/**
 * A meta-classifier that makes its base classifier cost sensitive.
 *
 * <p>Cost sensitivity is introduced in one of two ways, selected with
 * {@link #setMinimizeExpectedCost(boolean)}:
 * <ul>
 *   <li>{@code false} (the default): the training data is passed through
 *       {@link CostMatrix#applyCostMatrix} before the base classifier is built;</li>
 *   <li>{@code true}: the base classifier is trained on the data as given, and each predicted
 *       distribution is replaced by a one-hot vector on the class with minimum expected cost.</li>
 * </ul>
 *
 * <p>The cost matrix is either supplied explicitly ({@link #MATRIX_SUPPLIED}) or loaded at build
 * time from {@code <onDemandDirectory>/<relation name>} + {@link CostMatrix#FILE_EXTENSION}
 * ({@link #MATRIX_ON_DEMAND}).
 */
public class CostSensitiveClassifier extends RandomizableSingleClassifierEnhancer implements OptionHandler, Drawable, BatchPredictor, WeightedInstancesHandler {
    /** Matrix source: load the cost matrix from a file named after the relation at build time. */
    public static final int MATRIX_ON_DEMAND = 1;
    /** Matrix source: use the explicitly supplied cost matrix (or the cost file given with -C). */
    public static final int MATRIX_SUPPLIED = 2;
    /** Tags describing the possible cost matrix sources. */
    public static final Tag[] TAGS_MATRIX_SOURCE = {
        new Tag(MATRIX_ON_DEMAND, "Load cost matrix on demand"),
        new Tag(MATRIX_SUPPLIED, "Use explicit cost matrix")
    };
    static final long serialVersionUID = -110658209263002404L;

    /** Name of the cost file given with -C, or null when the matrix did not come from -C. */
    protected String m_CostFile;
    /** Whether to predict the minimum-expected-cost class instead of reweighting the data. */
    protected boolean m_MinimizeExpectedCost;
    /** Where the cost matrix comes from: {@link #MATRIX_ON_DEMAND} or {@link #MATRIX_SUPPLIED}. */
    protected int m_MatrixSource = MATRIX_ON_DEMAND;
    /** Directory searched for on-demand cost files; defaults to the working directory. */
    protected File m_OnDemandDirectory = new File(System.getProperty("user.dir"));
    /** The cost matrix in use; null while an old-format -C file is still waiting to be loaded. */
    protected CostMatrix m_CostMatrix = new CostMatrix(1);

    /** Creates a cost-sensitive classifier wrapping {@link ZeroR}. */
    public CostSensitiveClassifier() {
        this.m_Classifier = new ZeroR();
    }

    /**
     * Command-line entry point.
     *
     * @param strArr command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new CostSensitiveClassifier(), strArr);
    }

    @Override // weka.classifiers.AbstractClassifier
    public String batchSizeTipText() {
        return "Batch size to use if base learner is a BatchPredictor";
    }

    /**
     * Builds the base classifier on a cost-adjusted copy of {@code instances}.
     *
     * @param instances the training data (not modified)
     * @throws Exception if no base classifier is set, the data is unsuitable, or the cost matrix
     *     cannot be obtained
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        if (this.m_Classifier == null) {
            throw new Exception("No base classifier has been set!");
        }
        getCapabilities().testWithFail(instances);

        // Work on a copy so the caller's data is left untouched.
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        if (this.m_MatrixSource == MATRIX_ON_DEMAND) {
            // Assign the field directly: setCostMatrix() would switch the source to
            // MATRIX_SUPPLIED, and later builds would stop loading their own cost file.
            this.m_CostMatrix = readCostMatrix(onDemandCostFile(data));
        } else if (this.m_CostMatrix == null) {
            // -C named a file that did not parse in the current format; try the old one.
            this.m_CostMatrix = readOldFormatCostMatrix(data.numClasses());
        }

        if (!this.m_MinimizeExpectedCost) {
            // A Random is only supplied when the base learner cannot handle instance weights.
            Random random = this.m_Classifier instanceof WeightedInstancesHandler ? null : new Random(this.m_Seed);
            data = this.m_CostMatrix.applyCostMatrix(data, random);
        } else if (!data.allInstanceWeightsIdentical() && !(this.m_Classifier instanceof WeightedInstancesHandler)) {
            Random random = data.numInstances() > 0 ? data.getRandomNumberGenerator(getSeed()) : new Random(getSeed());
            data = data.resampleWithWeights(random);
        }
        this.m_Classifier.buildClassifier(data);
    }

    /**
     * Returns the on-demand cost file for {@code data}: the relation name plus
     * {@link CostMatrix#FILE_EXTENSION}, resolved against the on-demand directory.
     *
     * @throws Exception if that file does not exist
     */
    private File onDemandCostFile(Instances data) throws Exception {
        File file = new File(getOnDemandDirectory(), data.relationName() + CostMatrix.FILE_EXTENSION);
        if (!file.exists()) {
            throw new Exception("On-demand cost file doesn't exist: " + file);
        }
        return file;
    }

    /**
     * Loads an old-format cost matrix from {@link #m_CostFile}. The result is only returned once
     * reading succeeds, so a failed read leaves no half-initialised matrix behind.
     *
     * @param numClasses number of classes in the training data
     * @throws Exception if no cost file is set or it cannot be read
     */
    private CostMatrix readOldFormatCostMatrix(int numClasses) throws Exception {
        if (this.m_CostFile == null) {
            throw new Exception("No cost matrix or cost file has been set!");
        }
        CostMatrix matrix = new CostMatrix(numClasses);
        BufferedReader reader = new BufferedReader(new FileReader(this.m_CostFile));
        try {
            matrix.readOldFormat(reader);
        } finally {
            reader.close();
        }
        return matrix;
    }

    /** Parses a cost matrix from {@code file}, always closing the file afterwards. */
    private static CostMatrix readCostMatrix(File file) throws Exception {
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            return new CostMatrix(reader);
        } finally {
            reader.close();
        }
    }

    /**
     * Replaces {@code dArr} in place with a one-hot vector on the class of minimum expected
     * cost.
     *
     * @param dArr the base classifier's distribution (overwritten)
     * @param instance the instance the distribution belongs to
     * @return {@code dArr}, after modification
     * @throws Exception if the expected costs cannot be computed
     */
    protected double[] convertDistribution(double[] dArr, Instance instance) throws Exception {
        int minIndex = Utils.minIndex(this.m_CostMatrix.expectedCosts(dArr, instance));
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = (i == minIndex) ? 1.0d : 0.0d;
        }
        return dArr;
    }

    public String costMatrixSourceTipText() {
        return "Sets where to get the cost matrix. The two options are to use the supplied explicit cost matrix (the setting of the costMatrix property), or to load a cost matrix from a file when required (this file will be loaded from the directory set by the onDemandDirectory property and will be named relation_name" + CostMatrix.FILE_EXTENSION + ").";
    }

    public String costMatrixTipText() {
        return "Sets the cost matrix explicitly. This matrix is used if the costMatrixSource property is set to \"Supplied\".";
    }

    @Override // weka.classifiers.SingleClassifierEnhancer
    protected String defaultClassifierString() {
        return "weka.classifiers.rules.ZeroR";
    }

    /**
     * Returns the base classifier's distribution. When minimizing expected cost, that
     * distribution is first converted to a one-hot vector.
     */
    @Override // weka.classifiers.AbstractClassifier, weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = this.m_Classifier.distributionForInstance(instance);
        if (this.m_MinimizeExpectedCost) {
            return convertDistribution(distribution, instance);
        }
        return distribution;
    }

    /**
     * Computes distributions for every instance. The base learner's batch prediction is used
     * when it is a {@link BatchPredictor}; otherwise instances are predicted one at a time.
     */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public double[][] distributionsForInstances(Instances instances) throws Exception {
        if (!(getClassifier() instanceof BatchPredictor)) {
            double[][] result = new double[instances.numInstances()][];
            for (int i = 0; i < result.length; i++) {
                result[i] = distributionForInstance(instances.instance(i));
            }
            return result;
        }
        double[][] distributions = ((BatchPredictor) getClassifier()).distributionsForInstances(instances);
        if (this.m_MinimizeExpectedCost) {
            for (int i = 0; i < distributions.length; i++) {
                distributions[i] = convertDistribution(distributions[i], instances.instance(i));
            }
        }
        return distributions;
    }

    /** Delegates to the base classifier when it is a {@link BatchPredictor}. */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public String getBatchSize() {
        return getClassifier() instanceof BatchPredictor ? ((BatchPredictor) getClassifier()).getBatchSize() : super.getBatchSize();
    }

    /**
     * Starts from the superclass capabilities (presumably derived from the base classifier) and
     * restricts the class attribute to nominal.
     */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        return capabilities;
    }

    /**
     * Returns the base classifier's class name, followed by its options when it is an
     * {@link OptionHandler}. (The decompiler widened this from protected to public.)
     */
    @Override // weka.classifiers.SingleClassifierEnhancer
    public String getClassifierSpec() {
        Classifier classifier = getClassifier();
        if (!(classifier instanceof OptionHandler)) {
            return classifier.getClass().getName();
        }
        // NOTE(review): decompiler-substituted constant, presumably a single-space separator.
        return classifier.getClass().getName() + TestInstances.DEFAULT_SEPARATORS + Utils.joinOptions(((OptionHandler) classifier).getOptions());
    }

    /** @return the current cost matrix; may be null while an old-format file is pending */
    public CostMatrix getCostMatrix() {
        return this.m_CostMatrix;
    }

    /** @return the cost matrix source as a tag from {@link #TAGS_MATRIX_SOURCE} */
    public SelectedTag getCostMatrixSource() {
        return new SelectedTag(this.m_MatrixSource, TAGS_MATRIX_SOURCE);
    }

    /** @return whether predictions minimize expected misclassification cost */
    public boolean getMinimizeExpectedCost() {
        return this.m_MinimizeExpectedCost;
    }

    /** @return the directory searched for on-demand cost files */
    public File getOnDemandDirectory() {
        return this.m_OnDemandDirectory;
    }

    /**
     * Returns the current options. Exactly one of -N, -C or -cost-matrix is emitted to describe
     * the cost matrix source; none is emitted if a supplied matrix is missing.
     */
    @Override // weka.classifiers.RandomizableSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (this.m_MatrixSource != MATRIX_SUPPLIED) {
            options.add("-N");
            options.add("" + getOnDemandDirectory());
        } else if (this.m_CostFile != null) {
            options.add("-C");
            options.add("" + this.m_CostFile);
        } else if (getCostMatrix() != null) {
            options.add("-cost-matrix");
            options.add(getCostMatrix().toMatlab());
        }
        if (getMinimizeExpectedCost()) {
            options.add("-M");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    @Override // weka.classifiers.AbstractClassifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 15477 $");
    }

    public String globalInfo() {
        return "A metaclassifier that makes its base classifier cost sensitive. Two methods can be used to introduce cost-sensitivity: reweighting training instances according to the total cost assigned to each class; or predicting the class with minimum expected misclassification cost (rather than the most likely class). Performance can often be improved by using a bagged classifier to improve the probability estimates of the base classifier. If the base classifier cannot handle instance weights, and the instance weights are not uniform, the data will be resampled with replacement based on the weights before being passed to the base classifier.";
    }

    /** Delegates graphing to the base classifier when it is {@link Drawable}. */
    @Override // weka.core.Drawable
    public String graph() throws Exception {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable) this.m_Classifier).graph();
        }
        throw new Exception("Classifier: " + getClassifierSpec() + " cannot be graphed");
    }

    /** Returns the base classifier's graph type, or 0 (presumably "not drawable") otherwise. */
    @Override // weka.core.Drawable
    public int graphType() {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable) this.m_Classifier).graphType();
        }
        return 0;
    }

    /** Delegates to the base classifier when it is a {@link BatchPredictor}; false otherwise. */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public boolean implementsMoreEfficientBatchPrediction() {
        if (getClassifier() instanceof BatchPredictor) {
            return ((BatchPredictor) getClassifier()).implementsMoreEfficientBatchPrediction();
        }
        return false;
    }

    @Override // weka.classifiers.RandomizableSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>(4);
        options.addElement(new Option("\tMinimize expected misclassification cost. Default is to\n\treweight training instances according to costs per class", "M", 0, "-M"));
        options.addElement(new Option("\tFile name of a cost matrix to use. If this is not supplied,\n\ta cost matrix will be loaded on demand. The name of the\n\ton-demand file is the relation name of the training data\n\tplus \".cost\", and the path to the on-demand file is\n\tspecified with the -N option.", "C", 1, "-C <cost file name>"));
        options.addElement(new Option("\tName of a directory to search for cost files when loading\n\tcosts on demand (default current directory).", "N", 1, "-N <directory>"));
        options.addElement(new Option("\tThe cost matrix in Matlab single line format.", "cost-matrix", 1, "-cost-matrix <matrix>"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    public String minimizeExpectedCostTipText() {
        return "Sets whether the minimum expected cost criteria will be used. If this is false, the training data will be reweighted according to the costs assigned to each class. If true, the minimum expected cost criteria will be used.";
    }

    public String onDemandDirectoryTipText() {
        return "Sets the directory where cost files are loaded from. This option is used when the costMatrixSource is set to \"On Demand\".";
    }

    /** Delegates to the base classifier when it is a {@link BatchPredictor}. */
    @Override // weka.classifiers.AbstractClassifier, weka.core.BatchPredictor
    public void setBatchSize(String str) {
        if (getClassifier() instanceof BatchPredictor) {
            ((BatchPredictor) getClassifier()).setBatchSize(str);
        } else {
            super.setBatchSize(str);
        }
    }

    /**
     * Sets an explicit cost matrix and switches the source to {@link #MATRIX_SUPPLIED}.
     * Any previously remembered -C file name is forgotten, so that {@link #getOptions()}
     * describes this matrix rather than a stale file.
     *
     * @param costMatrix the cost matrix (may be null)
     */
    public void setCostMatrix(CostMatrix costMatrix) {
        this.m_CostMatrix = costMatrix;
        this.m_CostFile = null;
        this.m_MatrixSource = MATRIX_SUPPLIED;
    }

    /**
     * Sets the cost matrix source. Tags not taken from {@link #TAGS_MATRIX_SOURCE} are ignored.
     *
     * @param selectedTag the new source
     */
    public void setCostMatrixSource(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_MATRIX_SOURCE) {
            this.m_MatrixSource = selectedTag.getSelectedTag().getID();
        }
    }

    /** @param z whether to minimize expected misclassification cost */
    public void setMinimizeExpectedCost(boolean z) {
        this.m_MinimizeExpectedCost = z;
    }

    /**
     * Sets the on-demand directory and switches the source to {@link #MATRIX_ON_DEMAND}.
     * If {@code file} is not a directory, its parent directory is used. A bare file name (no
     * parent) resolves to the working directory instead of failing with a NullPointerException.
     *
     * @param file a directory, or a file inside the desired directory
     */
    public void setOnDemandDirectory(File file) {
        if (file.isDirectory()) {
            this.m_OnDemandDirectory = file;
        } else {
            String parent = file.getParent();
            this.m_OnDemandDirectory = parent != null ? new File(parent) : new File(System.getProperty("user.dir"));
        }
        this.m_MatrixSource = MATRIX_ON_DEMAND;
    }

    /**
     * Parses options: -M, -C &lt;cost file&gt;, -N &lt;directory&gt;, -cost-matrix &lt;matlab&gt;, then
     * the superclass options. When several are given, later ones override earlier ones in that
     * order.
     *
     * @param strArr the options; recognised entries are consumed
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override // weka.classifiers.RandomizableSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        setMinimizeExpectedCost(Utils.getFlag('M', strArr));

        // Forget any cost file from a previous call; it is re-set below if -C is present.
        this.m_CostFile = null;
        String costFile = Utils.getOption('C', strArr);
        if (costFile.length() != 0) {
            try {
                setCostMatrix(readCostMatrix(new File(costFile)));
            } catch (Exception ignored) {
                // Possibly an old-format cost file: leave the matrix unset so that
                // buildClassifier() loads it from m_CostFile with readOldFormat().
                setCostMatrix(null);
            }
            setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED, TAGS_MATRIX_SOURCE));
            this.m_CostFile = costFile;
        } else {
            setCostMatrixSource(new SelectedTag(MATRIX_ON_DEMAND, TAGS_MATRIX_SOURCE));
        }

        String directory = Utils.getOption('N', strArr);
        if (directory.length() != 0) {
            setOnDemandDirectory(new File(directory));
        }

        String matlab = Utils.getOption("cost-matrix", strArr);
        if (matlab.length() != 0) {
            // Round-trip through the textual format to obtain a regular CostMatrix instance.
            StringWriter writer = new StringWriter();
            CostMatrix.parseMatlab(matlab).write(writer);
            setCostMatrix(new CostMatrix(new StringReader(writer.toString())));
            setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED, TAGS_MATRIX_SOURCE));
            this.m_CostFile = null;
        }

        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /** Describes the method in use, the base classifier model, and the cost matrix. */
    public String toString() {
        if (this.m_Classifier == null) {
            return "CostSensitiveClassifier: No model built yet.";
        }
        String header = this.m_MinimizeExpectedCost
            ? "CostSensitiveClassifier using minimized expected misclasification cost\n"
            : "CostSensitiveClassifier using reweighted training instances\n";
        // The matrix is null while an old-format -C file has not yet been loaded.
        String costMatrixText = this.m_CostMatrix == null ? "No cost matrix loaded yet." : this.m_CostMatrix.toString();
        // NOTE(review): CSVWriter.DEFAULT_LINE_END is a decompiler-substituted constant,
        // presumably "\n".
        return header + CSVWriter.DEFAULT_LINE_END + getClassifierSpec() + "\n\nClassifier Model\n" + this.m_Classifier.toString() + "\n\nCost Matrix\n" + costMatrixText;
    }
}
