package edu.harvard.seas.iis.abilities.classify;

import edu.harvard.seas.iis.abilities.analysis.IISMouseLogParser;
import edu.harvard.seas.iis.abilities.analysis.Movement;
import edu.harvard.seas.iis.abilities.analysis.MovementFilter;
import edu.harvard.seas.iis.abilities.analysis.Parser;
import edu.harvard.seas.iis.util.Logger;
import edu.harvard.seas.iis.util.collections.PrettyPrint;
import edu.harvard.seas.iis.util.io.FileManipulation;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.util.Collection;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Attribute;
import weka.core.Instance;

/* loaded from: input_file:edu/harvard/seas/iis/abilities/classify/MovementClassifier.class */
/**
 * Classifies mouse movements as deliberate or distracted using pre-trained
 * {@link PositiveAndUnlabeledClassifier} models.
 *
 * <p>Four classifiers ({@code c1}..{@code c4}) and the {@link NormalizationConstants}
 * used to normalize features are deserialized either from a caller-supplied directory
 * or, by default, from resources bundled on the classpath. Of the four classifiers,
 * only {@code c1} and {@code c3} are consulted by the methods of this class.
 *
 * <p>A probability of {@code -1} is used throughout as a sentinel meaning
 * "this movement/instance could not be classified".
 */
public class MovementClassifier {
    /** Metadata key under which the deliberate probability is stored on a {@link Movement}. */
    public static final String DELIBERATE_PROBABILITY_KEY = "deliberate probability";
    /** Metadata key under which the predicted class ({@code Boolean}) is stored on a {@link Movement}. */
    public static final String PREDICTED_CLASS_KEY = "predicted class";
    protected PositiveAndUnlabeledClassifier c1;
    protected PositiveAndUnlabeledClassifier c2;
    protected PositiveAndUnlabeledClassifier c3;
    protected PositiveAndUnlabeledClassifier c4;
    protected NormalizationConstants normalizationConstants;
    /** Tab-separated header row matching the report lines produced by {@link #parseAndClassify}. */
    public static final String REPORT_HEADER = "File name\ttotal number of movements\tdeliberate movements\tdistracted movements\tfraction of deliberate movements";

    /** Probability at or above which a movement is predicted to be deliberate. */
    private static final double DELIBERATE_THRESHOLD = 0.5d;

    /** Sentinel probability returned when a movement or instance cannot be classified. */
    private static final double UNCLASSIFIABLE = -1.0d;

    /** Classpath directory holding the bundled, default model files. */
    private static final String BUNDLED_RESOURCE_DIR = "edu/harvard/seas/iis/abilities/classify/resources/";

    /**
     * Creates a classifier backed by the default models bundled on the classpath.
     * Any loading failure is printed to stderr and leaves the models {@code null}.
     */
    public MovementClassifier() {
        this(null);
    }

    /**
     * Creates a classifier whose models are loaded from {@code str}, or from the bundled
     * classpath resources when {@code str} is {@code null} or blank. Any loading failure
     * is printed to stderr and leaves the affected models {@code null}.
     *
     * @param str directory containing {@code c1.classifier}..{@code c4.classifier} and
     *            {@code normalizationConstants}, or {@code null}/blank for the defaults
     */
    public MovementClassifier(String str) {
        try {
            loadClassifiers(str);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads the four classifiers and the normalization constants.
     *
     * <p>When {@code str} names a (non-blank) directory, the models are read from files in
     * that directory only; deserialization problems there are reported on stderr and the
     * method returns without falling back to the bundled models. Otherwise the bundled
     * classpath resources are used.
     *
     * @param str model directory, or {@code null}/blank to use the bundled resources
     * @throws IOException if a bundled resource is missing or cannot be deserialized
     */
    protected void loadClassifiers(String str) throws IOException {
        if (str != null && !"".equals(str.trim())) {
            try {
                this.c1 = PositiveAndUnlabeledClassifier.deserializeFromFile(new File(str, "c1.classifier"));
                this.c2 = PositiveAndUnlabeledClassifier.deserializeFromFile(new File(str, "c2.classifier"));
                this.c3 = PositiveAndUnlabeledClassifier.deserializeFromFile(new File(str, "c3.classifier"));
                this.c4 = PositiveAndUnlabeledClassifier.deserializeFromFile(new File(str, "c4.classifier"));
                this.normalizationConstants = (NormalizationConstants) FileManipulation.readObjectFromFile(new File(str, "normalizationConstants"));
            } catch (IOException e) {
                System.err.println("Problems deserializing the classifiers");
                e.printStackTrace();
            } catch (ClassNotFoundException e2) {
                e2.printStackTrace();
            }
            // A custom directory was requested: never overwrite those models with the
            // bundled defaults (the previous code fell through and did exactly that).
            return;
        }
        this.c1 = (PositiveAndUnlabeledClassifier) readBundledResource("c1.classifier");
        this.c2 = (PositiveAndUnlabeledClassifier) readBundledResource("c2.classifier");
        this.c3 = (PositiveAndUnlabeledClassifier) readBundledResource("c3.classifier");
        this.c4 = (PositiveAndUnlabeledClassifier) readBundledResource("c4.classifier");
        this.normalizationConstants = (NormalizationConstants) readBundledResource("normalizationConstants");
    }

    /**
     * Deserializes one Java-serialized object from the bundled resource directory,
     * always closing the underlying stream.
     *
     * @param name file name inside {@link #BUNDLED_RESOURCE_DIR}
     * @return the deserialized object
     * @throws IOException if the resource is absent, unreadable, or its class is unknown
     */
    private static Object readBundledResource(String name) throws IOException {
        String path = BUNDLED_RESOURCE_DIR + name;
        try (InputStream raw = MovementClassifier.class.getClassLoader().getResourceAsStream(path)) {
            if (raw == null) {
                throw new IOException("Classifier resource not found on classpath: " + path);
            }
            try (ObjectInputStream in = new ObjectInputStream(raw)) {
                return in.readObject();
            } catch (ClassNotFoundException e) {
                throw new IOException("Could not deserialize classpath resource " + path, e);
            }
        }
    }

    /**
     * Splits {@code collection} into deliberate and distracted movements.
     *
     * <p>Movements with probability at or above {@code 0.5} go to {@code vector};
     * other classifiable movements (probability {@code >= 0}) go to {@code vector2};
     * movements that could not be classified (probability {@code -1}) go to neither.
     *
     * @param collection movements to classify; each one gets its classification metadata set
     * @param vector receives deliberate movements
     * @param vector2 receives distracted movements
     * @param z selects classifier {@code c1} when true, {@code c3} when false
     * @throws Exception propagated from feature computation/normalization
     */
    public void classifyMovements(Collection<Movement> collection, Vector<Movement> vector, Vector<Movement> vector2, boolean z) throws Exception {
        for (Movement movement : collection) {
            // Classify once: the feature pipeline is not free and the result is reused.
            double probability = getDeliberateProbability(movement, z);
            if (probability >= DELIBERATE_THRESHOLD) {
                vector.add(movement);
            } else if (probability >= 0.0d) {
                vector2.add(movement);
            }
            // probability < 0 is the UNCLASSIFIABLE sentinel: skip it.
        }
    }

    /**
     * Computes the probability that a single movement is deliberate and records the
     * result on the movement under {@link #DELIBERATE_PROBABILITY_KEY} and
     * {@link #PREDICTED_CLASS_KEY}.
     *
     * @param movement the movement to classify
     * @param z selects classifier {@code c1} when true, {@code c3} when false
     * @return the deliberate probability, or {@code -1} if the movement could not be
     *         turned into an instance (metadata is not set in that case) or classification failed
     * @throws Exception propagated from feature computation/normalization
     */
    public double getDeliberateProbability(Movement movement, boolean z) throws Exception {
        UserDataSet userDataSet = new UserDataSet(1);
        userDataSet.addMovement(movement, "implicit", "test");
        if (userDataSet.numInstances() <= 0) {
            Logger.log(3, "Could not process this movement (perhaps an NaN or an Infinity somewhere?)");
            return UNCLASSIFIABLE;
        }
        // Add derived features, normalize with the loaded constants, then classify the
        // single resulting instance.
        double deliberateProbability = getDeliberateProbability(Transform.normalizeUsingNormalizationConstants(Transform.computeAdditonalFeatures(userDataSet), Settings.FEATURES_TO_NORMALIZE, this.normalizationConstants).firstInstance(), z);
        movement.setAdditionalMetaData(DELIBERATE_PROBABILITY_KEY, Double.valueOf(deliberateProbability));
        movement.setAdditionalMetaData(PREDICTED_CLASS_KEY, Boolean.valueOf(deliberateProbability >= DELIBERATE_THRESHOLD));
        return deliberateProbability;
    }

    /**
     * Classifies an already-prepared (feature-complete, normalized) instance.
     * Classification errors are printed to stderr and reported as {@code -1}.
     *
     * @param instance the instance to classify
     * @param z selects classifier {@code c1} when true, {@code c3} when false
     * @return the deliberate probability, or {@code -1} on failure
     * @throws Exception declared for compatibility; failures are caught internally
     */
    public double getDeliberateProbability(Instance instance, boolean z) throws Exception {
        double d = UNCLASSIFIABLE;
        try {
            d = (z ? this.c1 : this.c3).getDeliberateProbability(instance);
        } catch (Exception e) {
            System.err.println("Trouble when trying to classify " + instance);
            e.printStackTrace();
        }
        return d;
    }

    /**
     * Classifies every instance of {@code dataSet} in place, writing the probability into
     * the "Prediction probability" attribute and 1/0 into the "Predicted class" attribute.
     * The deliberate threshold ({@code >= 0.5}) matches the one used for movements.
     *
     * @param dataSet data set containing both output attributes
     * @param z selects classifier {@code c1} when true, {@code c3} when false
     * @return {@code dataSet}, modified in place
     * @throws IllegalArgumentException if either output attribute is missing
     * @throws Exception declared for compatibility
     */
    public DataSet classifyMovements(DataSet dataSet, boolean z) throws Exception {
        Attribute attribute = dataSet.attribute("Prediction probability");
        Attribute attribute2 = dataSet.attribute("Predicted class");
        if (attribute == null || attribute2 == null) {
            throw new IllegalArgumentException("Data set must contain \"Prediction probability\" and \"Predicted class\" attributes");
        }
        for (int i = 0; i < dataSet.numInstances(); i++) {
            Instance instance = dataSet.instance(i);
            double deliberateProbability = getDeliberateProbability(instance, z);
            instance.setValue(attribute, deliberateProbability);
            instance.setValue(attribute2, deliberateProbability >= DELIBERATE_THRESHOLD ? 1 : 0);
        }
        return dataSet;
    }

    /**
     * Parses each log file, classifies the movements accepted by {@code movementFilter},
     * and optionally appends one tab-separated summary line per file to {@code vector}
     * (columns as in {@link #REPORT_HEADER}).
     *
     * @param fileArr log files to process, one at a time
     * @param movementFilter optional filter; {@code null} accepts every movement
     * @param parser parser used to read each file
     * @param z selects classifier {@code c1} when true, {@code c3} when false
     * @param z2 when true, all parsed movements (filtered or not) are returned
     * @param vector optional sink for per-file report lines; may be {@code null}
     * @return the parsed movements if {@code z2}, otherwise an empty vector
     * @throws Exception propagated from parsing or classification
     */
    public Vector<Movement> parseAndClassify(File[] fileArr, MovementFilter movementFilter, Parser parser, boolean z, boolean z2, Vector<String> vector) throws Exception {
        Vector<Movement> result = new Vector<>();
        for (File file : fileArr) {
            Logger.log("Working on " + file);
            Vector<Movement> parsed = parser.parseMovementLog(new File[]{file});
            int total = 0;
            int deliberate = 0;
            int distracted = 0;
            for (Movement movement : parsed) {
                if (movementFilter == null || movementFilter.evaluateMovement(movement)) {
                    // NOTE(review): unclassifiable movements (-1) are counted as distracted here,
                    // unlike classifyMovements(Collection, ...), which drops them.
                    if (getDeliberateProbability(movement, z) >= DELIBERATE_THRESHOLD) {
                        deliberate++;
                    } else {
                        distracted++;
                    }
                    total++;
                }
            }
            if (vector != null) {
                // Floating-point division: integer division truncated the fraction to 0/1 and
                // threw ArithmeticException for files with no matching movements (now NaN).
                double fraction = (double) deliberate / total;
                vector.add(file.getName() + " \t" + total + "\t" + deliberate + "\t" + distracted + "\t" + fraction);
            }
            if (z2) {
                result.addAll(parsed);
            }
        }
        return result;
    }

    /**
     * Command-line entry point: asks the user to pick log files, classifies their
     * movements with the bundled models, and prints a tab-separated summary report.
     *
     * @param strArr optional {@code strArr[0]}: only movements whose target type equals it are counted
     * @throws Exception propagated from file selection, parsing, or classification
     */
    public static void main(String[] strArr) throws Exception {
        final String str = strArr.length > 0 ? strArr[0] : null;
        MovementClassifier movementClassifier = new MovementClassifier();
        Vector<String> vector = new Vector<>();
        vector.add(REPORT_HEADER);
        System.out.println("\nSelect log files containing movement traces to be analyzed\n");
        movementClassifier.parseAndClassify(FileManipulation.getUserSpecifiedFilesForReading(new File(Settings.DATA_DIRECTORY)), new MovementFilter() {
            @Override
            public boolean evaluateMovement(Movement movement) {
                // No target type given: accept everything.
                return str == null || str.equals(movement.getTargetType());
            }
        }, new IISMouseLogParser(), true, false, vector);
        if (str != null) {
            System.out.println("Only including movements where target type = " + str);
        }
        System.out.println("\n\n=============================\n\nPaste the output that follows into a spreadsheet for an easy to read summary of your data\n\n");
        System.out.println(PrettyPrint.toPrettyString(vector, ""));
    }
}
