package edu.harvard.seas.iis.abilities.classify;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Vector;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.core.converters.CSVSaver;

/* loaded from: input_file:edu/harvard/seas/iis/abilities/classify/DataSet.class */
/**
 * A Weka {@link Instances} collection whose class attribute is always the attribute named
 * {@code "Class"}, with helpers for filtering, counting, merging and saving instances.
 *
 * <p>Class labels referenced by this code are {@code "explicit"} and {@code "implicit"}; the
 * per-user filter reads an attribute named {@code "User"}.
 */
public class DataSet extends Instances {

    /** Name of the attribute every DataSet uses as its class attribute. */
    private static final String CLASS_ATTRIBUTE = "Class";

    /**
     * Creates an empty data set with the given relation name and attribute header, and makes
     * the attribute named "Class" the class attribute.
     *
     * <p>NOTE(review): if the header has no "Class" attribute, {@code attribute("Class")} is
     * null and {@code setClass} presumably throws a NullPointerException — confirm.
     *
     * @param str relation name
     * @param fastVector attribute definitions of the new data set
     * @param i initial capacity
     */
    public DataSet(String str, FastVector fastVector, int i) {
        super(str, fastVector, i);
        setClass(attribute(CLASS_ATTRIBUTE));
    }

    /**
     * Creates a data set containing the header and all instances of {@code instances}, and
     * makes the attribute named "Class" the class attribute.
     *
     * @param instances source data set to copy
     */
    public DataSet(Instances instances) {
        super(instances);
        setClass(attribute(CLASS_ATTRIBUTE));
    }

    /**
     * Creates an empty data set sharing the header of {@code header}, with room for
     * {@code capacity} instances. Used by the filter methods to build results in one pass.
     */
    private DataSet(Instances header, int capacity) {
        super(header, capacity);
        setClass(attribute(CLASS_ATTRIBUTE));
    }

    /** Returns an empty DataSet with this set's header, sized to hold all of its instances. */
    private DataSet emptyCopy() {
        return new DataSet(this, numInstances());
    }

    /**
     * Appends every instance of {@code instances} to this data set.
     *
     * <p>String attribute values are re-set by text after each add, because the copied
     * instance still holds the index into the source data set's string table. The code
     * assumes {@code instances} has the same attribute layout as this set.
     *
     * @param instances instances to append; {@code null} is ignored
     */
    public void addInstances(Instances instances) {
        if (instances == null) {
            return;
        }
        int numAttributes = numAttributes();
        for (int i = 0; i < instances.numInstances(); i++) {
            add(instances.instance(i));
            for (int att = 0; att < numAttributes; att++) {
                // Missing values are already copied as missing by add(); calling stringValue()
                // on them would write a bogus string back, so leave them alone.
                if (attribute(att).isString() && !instances.instance(i).isMissing(att)) {
                    lastInstance().setValue(att, instances.instance(i).stringValue(att));
                }
            }
        }
    }

    /**
     * Returns a new DataSet holding the instances whose value of {@code attribute} (as text)
     * equals {@code str}, in their original order.
     *
     * @param attribute a nominal or string attribute of this data set
     * @param str value to match
     * @return the matching instances
     */
    public DataSet getInstancesWithAttributeValueEqual(Attribute attribute, String str) {
        DataSet result = emptyCopy();
        for (int i = 0; i < numInstances(); i++) {
            if (str.equals(instance(i).stringValue(attribute))) {
                result.add(instance(i));
            }
        }
        return result;
    }

    /**
     * Returns a new DataSet holding the instances whose value of {@code attribute} (as text)
     * differs from {@code str}, in their original order.
     *
     * @param attribute a nominal or string attribute of this data set
     * @param str value to exclude
     * @return the non-matching instances
     */
    public DataSet getInstancesWithAttributeValueNotEqual(Attribute attribute, String str) {
        DataSet result = emptyCopy();
        for (int i = 0; i < numInstances(); i++) {
            if (!str.equals(instance(i).stringValue(attribute))) {
                result.add(instance(i));
            }
        }
        return result;
    }

    /**
     * Returns a new DataSet holding the instances whose value of {@code attribute} (as text)
     * is contained in {@code collection}, in their original order.
     *
     * @param attribute a nominal or string attribute of this data set
     * @param collection accepted values
     * @return the matching instances
     */
    public DataSet getInstancesWithAttributeValues(Attribute attribute, Collection<String> collection) {
        DataSet result = emptyCopy();
        for (int i = 0; i < numInstances(); i++) {
            if (collection.contains(instance(i).stringValue(attribute))) {
                result.add(instance(i));
            }
        }
        return result;
    }

    /**
     * Returns a new DataSet holding the instances whose numeric value of {@code attribute} is
     * strictly greater than {@code d}, in their original order. Instances with a missing value
     * are excluded.
     *
     * @param attribute attribute to compare
     * @param d exclusive lower bound
     * @return the matching instances
     */
    public DataSet getInstancesWithAttributeValueGreaterThan(Attribute attribute, double d) {
        DataSet result = emptyCopy();
        for (int i = 0; i < numInstances(); i++) {
            // Missing values are NaN and NaN > d is false, so they are not kept.
            if (instance(i).value(attribute) > d) {
                result.add(instance(i));
            }
        }
        return result;
    }

    /** Returns the instances whose "Class" value is "explicit". */
    public DataSet getExplicitInstances() {
        return getInstancesWithAttributeValueEqual(attribute(CLASS_ATTRIBUTE), "explicit");
    }

    /** Returns the instances whose "Class" value is "implicit". */
    public DataSet getImplicitInstances() {
        return getInstancesWithAttributeValueEqual(attribute(CLASS_ATTRIBUTE), "implicit");
    }

    /**
     * Returns the instances whose "User" attribute equals {@code str}.
     *
     * @param str user identifier to match
     */
    public DataSet getInstancesForUser(String str) {
        return getInstancesWithAttributeValueEqual(attribute("User"), str);
    }

    /** Returns how many instances have class "implicit" (0 if that label is not declared). */
    public int getNumImplicitInstances() {
        return countClassValue("implicit");
    }

    /** Returns how many instances have class "explicit" (0 if that label is not declared). */
    public int getNumExplicitInstances() {
        return countClassValue("explicit");
    }

    /** Counts instances whose nominal "Class" value is {@code label}; 0 if undeclared. */
    private int countClassValue(String label) {
        Attribute classAttribute = attribute(CLASS_ATTRIBUTE);
        int labelIndex = classAttribute.indexOfValue(label);
        if (labelIndex < 0) {
            return 0;
        }
        return attributeStats(classAttribute.index()).nominalCounts[labelIndex];
    }

    /**
     * Lists the declared values of a nominal or string attribute as strings.
     *
     * @param attribute attribute whose values to list
     * @return the values in declaration order; empty if the attribute has no value list
     */
    public Vector<String> getValuesOfStringOrNominalAttribute(Attribute attribute) {
        Vector<String> values = new Vector<>();
        Enumeration<?> enumeration = attribute.enumerateValues();
        // Weka returns null for attributes without a value list (e.g. numeric).
        if (enumeration == null) {
            return values;
        }
        while (enumeration.hasMoreElements()) {
            values.add(enumeration.nextElement().toString());
        }
        return values;
    }

    /**
     * Sets {@code attribute} to {@code str} on every instance accepted by
     * {@code instanceFilter}, or on all instances when the filter is {@code null}.
     *
     * @param attribute attribute to overwrite
     * @param str new value
     * @param instanceFilter selects the instances to modify; may be null
     */
    public void setValue(Attribute attribute, String str, InstanceFilter instanceFilter) {
        for (int i = 0; i < numInstances(); i++) {
            if (instanceFilter == null || instanceFilter.evaluateInstance(instance(i), this)) {
                instance(i).setValue(attribute, str);
            }
        }
    }

    /**
     * Returns the values of the attribute named {@code str} as a double array.
     *
     * <p>NOTE(review): an unknown name makes {@code attribute(str)} null, which presumably
     * surfaces as a NullPointerException.
     *
     * @param str attribute name
     */
    public double[] attributeToDoubleArray(String str) {
        return attributeToDoubleArray(attribute(str).index());
    }

    /**
     * Writes this data set to {@code str + Instances.FILE_EXTENSION} (ARFF) and
     * {@code str + ".csv"}.
     *
     * @param str output path without extension
     * @throws IOException if either file cannot be written
     */
    public void saveAsBothARFFandCSV(String str) throws IOException {
        saveAsARFF(str + Instances.FILE_EXTENSION);
        saveAsCSV(str + ".csv");
    }

    /**
     * Writes the ARFF text of this data set ({@link #toString()}) to the file {@code str}.
     *
     * @param str output file path
     * @throws IOException if the file cannot be opened or written
     */
    public void saveAsARFF(String str) throws IOException {
        try (PrintWriter printWriter = new PrintWriter(str)) {
            printWriter.write(toString());
            // PrintWriter swallows write errors; surface them instead of losing data silently.
            if (printWriter.checkError()) {
                throw new IOException("Error writing ARFF file " + str);
            }
        }
    }

    /**
     * Writes this data set to the file {@code str} using Weka's {@link CSVSaver}.
     *
     * @param str output file path
     * @throws IOException if saving fails
     */
    public void saveAsCSV(String str) throws IOException {
        CSVSaver saver = new CSVSaver();
        saver.setInstances(this);
        saver.setFile(new File(str));
        saver.writeBatch();
    }

    /**
     * Loads a DataSet from a single ARFF file.
     *
     * @param file ARFF file to read
     * @return the loaded data set
     * @throws IOException if the file cannot be read or parsed
     */
    public static DataSet fromArffFile(File file) throws IOException {
        ArffLoader arffLoader = new ArffLoader();
        arffLoader.setSource(file);
        return new DataSet(arffLoader.getDataSet());
    }

    /**
     * Loads and concatenates several ARFF files into one DataSet. The first file's header is
     * used, and the remaining files are appended via {@link #addInstances(Instances)}.
     *
     * @param fileArr ARFF files to read, in order
     * @return the merged data set, or {@code null} if {@code fileArr} is empty
     * @throws IOException if any file cannot be read or parsed
     */
    public static DataSet fromArffFiles(File[] fileArr) throws IOException {
        if (fileArr.length == 0) {
            return null;
        }
        DataSet merged = fromArffFile(fileArr[0]);
        for (int i = 1; i < fileArr.length; i++) {
            merged.addInstances(fromArffFile(fileArr[i]));
        }
        // NOTE(review): appears to be a sanity check that every merged instance can be
        // rendered (e.g. string values resolve); failures are reported but not fatal.
        for (int i = 0; i < merged.numInstances(); i++) {
            try {
                merged.instance(i).toString();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        return merged;
    }
}
