package edu.harvard.seas.iis.abilities.classify;

import edu.harvard.seas.iis.util.collections.PrettyPrint;
import edu.harvard.seas.iis.util.io.FileManipulation;
import java.io.File;
import java.io.IOException;
import java.util.Enumeration;
import weka.classifiers.Classifier;
import weka.classifiers.functions.Logistic;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:edu/harvard/seas/iis/abilities/classify/PositiveAndUnlabeledClassifier.class */
/**
 * Positive-and-unlabeled (PU) wrapper around a Weka {@link Classifier}.
 *
 * <p>Training instances whose class value is {@code 1} are treated as labeled positives; all other
 * instances are treated as unlabeled. The wrapped classifier only sees the attributes named in
 * {@link #allowedFeatures} (plus the class attribute). Its positive-class score is divided by a
 * constant {@code c}, estimated with {@value #FOLDS}-fold cross-validation by one of three methods
 * selected via {@link #setMethodForEvaluatingC(int)}:
 * <ol>
 *   <li>mean held-out positive-class score over labeled positives;</li>
 *   <li>sum of held-out scores over labeled positives divided by the sum over all held-out
 *       instances (the default);</li>
 *   <li>maximum held-out score over instances that are not labeled positive.</li>
 * </ol>
 * After {@code c} is estimated, the wrapped classifier is retrained on the full training set.
 *
 * <p>The class attribute is assumed to be binary, with distribution index {@code 1} being the
 * positive class.
 */
public class PositiveAndUnlabeledClassifier extends Classifier {
    private static final long serialVersionUID = 1L;

    /** Number of cross-validation folds used to estimate {@link #c}. */
    protected static final int FOLDS = 10;

    /** Class value that marks a labeled positive training instance. */
    private static final double POSITIVE_CLASS_VALUE = 1.0d;

    /** Wrapped classifier restricted to {@link #allowedFeatures}; created by {@link #buildClassifier}. */
    protected Classifier filteredClassifier;
    /** Learner wrapped by {@link #filteredClassifier}. */
    protected Classifier baseClassifier;
    /** Estimated normalizing constant; raw positive-class scores are divided by it. */
    protected double c;
    /** Estimator used for {@link #c}: 1, 2 or 3 (see the class documentation). */
    protected int methodForEvaluatingC;
    /** Names of the attributes the wrapped classifier may use. */
    protected String[] allowedFeatures;

    /**
     * Creates a classifier wrapping a new {@link Logistic} model, restricted to
     * {@code Settings.ALLOWED_FEATURES}.
     */
    public PositiveAndUnlabeledClassifier() throws Exception {
        this(new Logistic());
    }

    /**
     * Creates a classifier wrapping {@code classifier}, restricted to
     * {@code Settings.ALLOWED_FEATURES}.
     *
     * @param classifier the learner that separates labeled positives from unlabeled instances
     */
    public PositiveAndUnlabeledClassifier(Classifier classifier) {
        this(classifier, Settings.ALLOWED_FEATURES);
    }

    /**
     * Creates a classifier wrapping {@code classifier}, restricted to the given attributes. The
     * estimator for {@code c} defaults to method 2.
     *
     * @param classifier the learner that separates labeled positives from unlabeled instances
     * @param allowedFeatures names of the attributes the learner may use
     */
    public PositiveAndUnlabeledClassifier(Classifier classifier, String[] allowedFeatures) {
        this.c = 0.0d;
        this.methodForEvaluatingC = 2;
        this.baseClassifier = classifier;
        this.allowedFeatures = allowedFeatures;
    }

    /**
     * Creates {@link #filteredClassifier}: {@link #baseClassifier} behind a {@link Remove} filter
     * that keeps only {@link #allowedFeatures} and the class attribute of {@code instances}.
     *
     * @param instances data whose header is used to resolve attribute names to indices
     * @throws IllegalArgumentException if an allowed feature is not an attribute of
     *     {@code instances}
     */
    protected void setUnderlyingClassifier(Instances instances) {
        int[] keep = new int[this.allowedFeatures.length];
        for (int i = 0; i < this.allowedFeatures.length; i++) {
            String feature = this.allowedFeatures[i];
            if (instances.attribute(feature) == null) {
                // Report the missing name instead of failing with a NullPointerException.
                throw new IllegalArgumentException("Can't find attribute for " + feature);
            }
            keep[i] = instances.attribute(feature).index();
        }

        // With the selection inverted, Remove keeps only the listed indices, so the class
        // attribute must be listed too or the wrapped classifier has no class to learn.
        int classIndex = instances.classIndex();
        if (classIndex >= 0 && !containsIndex(keep, classIndex)) {
            int[] withClass = new int[keep.length + 1];
            System.arraycopy(keep, 0, withClass, 0, keep.length);
            withClass[keep.length] = classIndex;
            keep = withClass;
        }

        Remove remove = new Remove();
        remove.setAttributeIndicesArray(keep);
        remove.setInvertSelection(true);
        FilteredClassifier filtered = new FilteredClassifier();
        filtered.setFilter(remove);
        filtered.setClassifier(this.baseClassifier);
        this.filteredClassifier = filtered;
    }

    /** Returns the estimator used for {@code c} (1, 2 or 3). */
    public int getMethodForEvaluatingC() {
        return this.methodForEvaluatingC;
    }

    /**
     * Selects the estimator used for {@code c} on the next {@link #buildClassifier} call.
     *
     * @param i 1, 2 or 3; other values make {@link #buildClassifier} throw
     */
    public void setMethodForEvaluatingC(int i) {
        this.methodForEvaluatingC = i;
    }

    /** Returns the attribute names the wrapped classifier may use (the internal array, not a copy). */
    public String[] getAllowedFeatures() {
        return this.allowedFeatures;
    }

    /**
     * Sets the attribute names the wrapped classifier may use; takes effect on the next
     * {@link #buildClassifier} call.
     */
    public void setAllowedFeatures(String[] strArr) {
        this.allowedFeatures = strArr;
    }

    /** Returns {@code 1} when the {@code c}-adjusted positive score exceeds 0.5, else {@code 0}. */
    @Override // weka.classifiers.Classifier
    public double classifyInstance(Instance instance) throws Exception {
        return getDeliberateProbability(instance) > 0.5d ? 1.0d : 0.0d;
    }

    /**
     * Returns the wrapped classifier's positive-class score divided by {@code c}. Unlike
     * {@link #distributionForInstance}, this value is not clamped and may exceed 1.
     */
    public double getDeliberateProbability(Instance instance) throws Exception {
        return positiveScore(instance) / this.c;
    }

    /**
     * Returns the {@code c}-adjusted class distribution. The positive entry (index 1) is clamped to
     * [0, 1] and index 0 is set to its complement, so the result is a valid two-class distribution.
     */
    @Override // weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = this.filteredClassifier.distributionForInstance(instance);
        for (int i = 0; i < distribution.length; i++) {
            distribution[i] /= this.c;
        }
        // Dividing by c < 1 can push the score above 1; keep it a probability.
        distribution[1] = Math.min(1.0d, Math.max(0.0d, distribution[1]));
        distribution[0] = 1.0d - distribution[1];
        return distribution;
    }

    /**
     * Creates the feature-restricted classifier, estimates {@code c} with the configured method
     * and trains the classifier on all of {@code instances}.
     *
     * @throws IllegalArgumentException if the configured method is not 1, 2 or 3, or if the
     *     estimated {@code c} is not positive
     */
    @Override // weka.classifiers.Classifier
    public void buildClassifier(Instances instances) throws Exception {
        setUnderlyingClassifier(instances);
        switch (this.methodForEvaluatingC) {
            case 1:
                buildClassifier1(instances);
                break;
            case 2:
                buildClassifier2(instances);
                break;
            case 3:
                buildClassifier3(instances);
                break;
            default:
                throw new IllegalArgumentException("Illegal method for evaluating C: " + this.methodForEvaluatingC);
        }
    }

    /**
     * Method 1: {@code c} is the mean held-out positive-class score over labeled positives. The
     * classifier is then retrained on all of {@code instances}.
     */
    public void buildClassifier1(Instances instances) throws Exception {
        this.filteredClassifier.setDebug(true);
        double positiveScoreSum = 0.0d;
        int positiveCount = 0;
        for (int fold = 0; fold < FOLDS; fold++) {
            this.filteredClassifier.buildClassifier(instances.trainCV(FOLDS, fold));
            Enumeration heldOut = instances.testCV(FOLDS, fold).enumerateInstances();
            while (heldOut.hasMoreElements()) {
                Instance instance = (Instance) heldOut.nextElement();
                if (isLabeledPositive(instance)) {
                    positiveScoreSum += positiveScore(instance);
                    positiveCount++;
                }
            }
        }
        // Computed from locals so a previous build's c cannot leak into this estimate.
        this.c = requireUsableC(positiveScoreSum / positiveCount, 1);
        this.filteredClassifier.buildClassifier(instances);
    }

    /**
     * Method 2: {@code c} is the sum of held-out scores over labeled positives divided by the sum
     * over all held-out instances. The classifier is then retrained on all of {@code instances}.
     */
    public void buildClassifier2(Instances instances) throws Exception {
        this.filteredClassifier.setDebug(true);
        double positiveScoreSum = 0.0d;
        double otherScoreSum = 0.0d;
        for (int fold = 0; fold < FOLDS; fold++) {
            this.filteredClassifier.buildClassifier(instances.trainCV(FOLDS, fold));
            Enumeration heldOut = instances.testCV(FOLDS, fold).enumerateInstances();
            while (heldOut.hasMoreElements()) {
                Instance instance = (Instance) heldOut.nextElement();
                double score = positiveScore(instance);
                if (isLabeledPositive(instance)) {
                    positiveScoreSum += score;
                } else {
                    otherScoreSum += score;
                }
            }
        }
        this.c = requireUsableC(positiveScoreSum / (positiveScoreSum + otherScoreSum), 2);
        this.filteredClassifier.buildClassifier(instances);
    }

    /**
     * Method 3: {@code c} is the maximum held-out positive-class score over instances not labeled
     * positive. The classifier is then retrained on all of {@code instances}, as in methods 1 and 2.
     */
    public void buildClassifier3(Instances instances) throws Exception {
        this.filteredClassifier.setDebug(true);
        double maxOtherScore = 0.0d;
        for (int fold = 0; fold < FOLDS; fold++) {
            this.filteredClassifier.buildClassifier(instances.trainCV(FOLDS, fold));
            Enumeration heldOut = instances.testCV(FOLDS, fold).enumerateInstances();
            while (heldOut.hasMoreElements()) {
                Instance instance = (Instance) heldOut.nextElement();
                if (!isLabeledPositive(instance)) {
                    double score = positiveScore(instance);
                    if (score > maxOtherScore) {
                        maxOtherScore = score;
                    }
                }
            }
        }
        this.c = requireUsableC(maxOtherScore, 3);
        // Without this, predictions would come from the model trained on the last fold only.
        this.filteredClassifier.buildClassifier(instances);
    }

    /**
     * Reads a classifier previously written to {@code file}.
     *
     * <p>NOTE(review): {@code FileManipulation.readObjectFromFile} presumably uses Java native
     * deserialization (ClassNotFoundException is propagated); only load trusted files.
     */
    public static PositiveAndUnlabeledClassifier deserializeFromFile(File file) throws IOException, ClassNotFoundException {
        return (PositiveAndUnlabeledClassifier) FileManipulation.readObjectFromFile(file);
    }

    /** Returns the base classifier class, the c-estimation method and the allowed features, tab-separated. */
    public String toString() {
        return this.baseClassifier.getClass() + "\t" + this.methodForEvaluatingC + "\t" + PrettyPrint.toPrettyLine(this.allowedFeatures, ",");
    }

    /** Returns whether {@code instance} carries the positive label (class value 1). */
    private static boolean isLabeledPositive(Instance instance) {
        return instance.classValue() == POSITIVE_CLASS_VALUE;
    }

    /** Returns the wrapped classifier's raw score for the positive class (distribution index 1). */
    private double positiveScore(Instance instance) throws Exception {
        return this.filteredClassifier.distributionForInstance(instance)[1];
    }

    /**
     * Returns {@code estimate} if it can serve as the divisor {@code c}.
     *
     * @throws IllegalArgumentException if {@code estimate} is NaN or not positive, e.g. when the
     *     data has no instances with class value 1 or every relevant held-out score is 0
     */
    private static double requireUsableC(double estimate, int method) {
        if (!(estimate > 0.0d)) {
            throw new IllegalArgumentException("Cannot use estimated c = " + estimate
                    + " (method " + method
                    + "); check that the training data contains instances with class value 1");
        }
        return estimate;
    }

    /** Returns whether {@code values} contains {@code target}. */
    private static boolean containsIndex(int[] values, int target) {
        for (int value : values) {
            if (value == target) {
                return true;
            }
        }
        return false;
    }
}
