Package weka.classifiers

Examples of weka.classifiers.Evaluation


  public static double crossValidate(NaiveBayesUpdateable fullModel,
             Instances trainingSet,
             Random r) throws Exception {
    // make some copies for fast evaluation of 5-fold xval
    Classifier [] copies = AbstractClassifier.makeCopies(fullModel, 5);
    Evaluation eval = new Evaluation(trainingSet);
    // make some splits
    for (int j = 0; j < 5; j++) {
      Instances test = trainingSet.testCV(5, j);
      // unlearn these test instances
      for (int k = 0; k < test.numInstances(); k++) {
  test.instance(k).setWeight(-test.instance(k).weight());
  ((NaiveBayesUpdateable)copies[j]).updateClassifier(test.instance(k));
  // reset the weight back to its original value
  test.instance(k).setWeight(-test.instance(k).weight());
      }
      eval.evaluateModel(copies[j], test);
    }
    return eval.incorrect();
  }
View Full Code Here


   */
  public void acceptClassifier(final IncrementalClassifierEvent ce) {
    try {
      if (ce.getStatus() == IncrementalClassifierEvent.NEW_BATCH) {
  //  m_eval = new Evaluation(ce.getCurrentInstance().dataset());
  m_eval = new Evaluation(ce.getStructure());
  m_eval.useNoPriors();
 
  m_dataLegend = new Vector();
  m_reset = true;
  m_dataPoint = new double[0];
View Full Code Here

      throw new Exception("On-demand cost file doesn't exist: " + costFile);
    }
    CostMatrix costMatrix = new CostMatrix(new BufferedReader(
    new FileReader(costFile)));
   
    Evaluation eval = new Evaluation(train, costMatrix);   
    m_Classifier = AbstractClassifier.makeCopy(m_Template);
   
    trainTimeStart = System.currentTimeMillis();
    if(canMeasureCPUTime)
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    m_Classifier.buildClassifier(train);
    if(canMeasureCPUTime)
      trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    testTimeStart = System.currentTimeMillis();
    if(canMeasureCPUTime)
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    eval.evaluateModel(m_Classifier, test);
    if(canMeasureCPUTime)
      testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;
    thMonitor = null;
   
    m_result = eval.toSummaryString();
    // The results stored are all per instance -- can be multiplied by the
    // number of instances to get absolute numbers
    int current = 0;
    result[current++] = new Double(train.numInstances());
    result[current++] = new Double(eval.numInstances());
   
    result[current++] = new Double(eval.correct());
    result[current++] = new Double(eval.incorrect());
    result[current++] = new Double(eval.unclassified());
    result[current++] = new Double(eval.pctCorrect());
    result[current++] = new Double(eval.pctIncorrect());
    result[current++] = new Double(eval.pctUnclassified());
    result[current++] = new Double(eval.totalCost());
    result[current++] = new Double(eval.avgCost());
   
    result[current++] = new Double(eval.meanAbsoluteError());
    result[current++] = new Double(eval.rootMeanSquaredError());
    result[current++] = new Double(eval.relativeAbsoluteError());
    result[current++] = new Double(eval.rootRelativeSquaredError());
   
    result[current++] = new Double(eval.SFPriorEntropy());
    result[current++] = new Double(eval.SFSchemeEntropy());
    result[current++] = new Double(eval.SFEntropyGain());
    result[current++] = new Double(eval.SFMeanPriorEntropy());
    result[current++] = new Double(eval.SFMeanSchemeEntropy());
    result[current++] = new Double(eval.SFMeanEntropyGain());
   
    // K&B stats
    result[current++] = new Double(eval.KBInformation());
    result[current++] = new Double(eval.KBMeanInformation());
    result[current++] = new Double(eval.KBRelativeInformation());
   
    // Timing stats
    result[current++] = new Double(trainTimeElapsed / 1000.0);
    result[current++] = new Double(testTimeElapsed / 1000.0);
    if(canMeasureCPUTime) {
View Full Code Here

     * @param data the set of instances
     * @return the error rate
     * @throws Exception if something goes wrong
     */
    protected double getErrorRate(Instances data) throws Exception {
  Evaluation eval = new Evaluation(data);
  eval.evaluateModel(this,data);
  return eval.errorRate();
    }
View Full Code Here

     * @param data the set of instances
     * @return the error
     * @throws Exception if something goes wrong
     */
    protected double getMeanAbsoluteError(Instances data) throws Exception {
  Evaluation eval = new Evaluation(data);
  eval.evaluateModel(this,data);
  return eval.meanAbsoluteError();
    }
View Full Code Here

    int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0;
    Object [] result = new Object[RESULT_SIZE+addm];
    long thID = Thread.currentThread().getId();
    long CPUStartTime=-1, trainCPUTimeElapsed=-1, testCPUTimeElapsed=-1,
         trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed;   
    Evaluation eval = new Evaluation(train);
    m_Classifier = AbstractClassifier.makeCopy(m_Template);

    trainTimeStart = System.currentTimeMillis();
    if(canMeasureCPUTime)
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    m_Classifier.buildClassifier(train);
    if(canMeasureCPUTime)
      trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    testTimeStart = System.currentTimeMillis();
    if(canMeasureCPUTime)
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    eval.evaluateModel(m_Classifier, test);
    if(canMeasureCPUTime)
      testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;
    thMonitor = null;
   
    m_result = eval.toSummaryString();
    // The results stored are all per instance -- can be multiplied by the
    // number of instances to get absolute numbers
    int current = 0;
    result[current++] = new Double(train.numInstances());
    result[current++] = new Double(eval.numInstances());

    result[current++] = new Double(eval.meanAbsoluteError());
    result[current++] = new Double(eval.rootMeanSquaredError());
    result[current++] = new Double(eval.relativeAbsoluteError());
    result[current++] = new Double(eval.rootRelativeSquaredError());
    result[current++] = new Double(eval.correlationCoefficient());

    result[current++] = new Double(eval.SFPriorEntropy());
    result[current++] = new Double(eval.SFSchemeEntropy());
    result[current++] = new Double(eval.SFEntropyGain());
    result[current++] = new Double(eval.SFMeanPriorEntropy());
    result[current++] = new Double(eval.SFMeanSchemeEntropy());
    result[current++] = new Double(eval.SFMeanEntropyGain());
   
    // Timing stats
    result[current++] = new Double(trainTimeElapsed / 1000.0);
    result[current++] = new Double(testTimeElapsed / 1000.0);
    if(canMeasureCPUTime) {
      result[current++] = new Double((trainCPUTimeElapsed/1000000.0) / 1000.0);
      result[current++] = new Double((testCPUTimeElapsed /1000000.0) / 1000.0);
    }
    else {
      result[current++] = new Double(Utils.missingValue());
      result[current++] = new Double(Utils.missingValue());
    }
   
    // sizes
    ByteArrayOutputStream bastream = new ByteArrayOutputStream();
    ObjectOutputStream oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(m_Classifier);
    result[current++] = new Double(bastream.size());
    bastream = new ByteArrayOutputStream();
    oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(train);
    result[current++] = new Double(bastream.size());
    bastream = new ByteArrayOutputStream();
    oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(test);
    result[current++] = new Double(bastream.size());
   
    // Prediction interval statistics
    result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions());
    result[current++] = new Double(eval.sizeOfPredictedRegions());

    if (m_Classifier instanceof Summarizable) {
      result[current++] = ((Summarizable)m_Classifier).toSummaryString();
    } else {
      result[current++] = null;
View Full Code Here

        classificationOutput.setHeader(mappedClassifierHeader);
      }
     
      if (!onlySetPriors) {
        if (costMatrix != null) {
          eval = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
        } else {
          eval = new Evaluation(new Instances(mappedClassifierHeader, 0));
        }
      }
     
      if (!eval.getHeader().equalHeaders(inst)) {
        // When the InputMappedClassifier is loading a model,
View Full Code Here

      classificationOutput.setBuffer(outBuff);
    }
    String name = (new SimpleDateFormat("HH:mm:ss - ")).format(new Date());
    String cname = "";
          String cmd = "";
    Evaluation eval = null;
    try {
      if (m_CVBut.isSelected()) {
        testMode = 1;
        numFolds = Integer.parseInt(m_CVText.getText());
        if (numFolds <= 1) {
    throw new Exception("Number of folds must be greater than 1");
        }
      } else if (m_PercentBut.isSelected()) {
        testMode = 2;
        percent = Double.parseDouble(m_PercentText.getText());
        if ((percent <= 0) || (percent >= 100)) {
    throw new Exception("Percentage must be between 0 and 100");
        }
      } else if (m_TrainBut.isSelected()) {
        testMode = 3;
      } else if (m_TestSplitBut.isSelected()) {
        testMode = 4;
        // Check the test instance compatibility
        if (source == null) {
          throw new Exception("No user test set has been specified");
        }
       
        if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
          if (!inst.equalHeaders(userTestStructure)) {
            boolean wrapClassifier = false;
            if (!Utils.
                getDontShowDialog("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier")) {
              JCheckBox dontShow = new JCheckBox("Do not show this message again");
              Object[] stuff = new Object[2];
              stuff[0] = "Train and test set are not compatible.\n" +
              "Would you like to automatically wrap the classifier in\n" +
              "an \"InputMappedClassifier\" before proceeding?.\n";
              stuff[1] = dontShow;

              int result = JOptionPane.showConfirmDialog(ClassifierPanel.this, stuff,
                  "ClassifierPanel", JOptionPane.YES_OPTION);
             
              if (result == JOptionPane.YES_OPTION) {
                wrapClassifier = true;
              }
             
              if (dontShow.isSelected()) {
                String response = (wrapClassifier) ? "yes" : "no";
                Utils.
                  setDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier",
                      response);
              }

            } else {
              // What did the user say - do they want to autowrap or not?
              String response =
                Utils.getDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier");
              if (response != null && response.equalsIgnoreCase("yes")) {
                wrapClassifier = true;
              }
            }

            if (wrapClassifier) {
              weka.classifiers.misc.InputMappedClassifier temp =
                new weka.classifiers.misc.InputMappedClassifier();

              // pass on the known test structure so that we get the
              // correct mapping report from the toString() method
              // of InputMappedClassifier
              temp.setClassifier(classifier);
              temp.setTestStructure(userTestStructure);
              classifier = temp;
            } else {
              throw new Exception("Train and test set are not compatible\n" + inst.equalHeadersMsg(userTestStructure));
            }
          }
        }
             
      } else {
        throw new Exception("Unknown test mode");
      }

      cname = classifier.getClass().getName();
      if (cname.startsWith("weka.classifiers.")) {
        name += cname.substring("weka.classifiers.".length());
      } else {
        name += cname;
      }
      cmd = classifier.getClass().getName();
      if (classifier instanceof OptionHandler)
        cmd += " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
     
      // set up the structure of the plottable instances for
      // visualization
      plotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
      plotInstances.setInstances(inst);
      plotInstances.setClassifier(classifier);
      plotInstances.setClassIndex(inst.classIndex());
      plotInstances.setSaveForVisualization(saveVis);

      // Output some header information
      m_Log.logMessage("Started " + cname);
      m_Log.logMessage("Command: " + cmd);
      if (m_Log instanceof TaskLogger) {
        ((TaskLogger)m_Log).taskStarted();
      }
      outBuff.append("=== Run information ===\n\n");
      outBuff.append("Scheme:       " + cname);
      if (classifier instanceof OptionHandler) {
        String [] o = ((OptionHandler) classifier).getOptions();
        outBuff.append(" " + Utils.joinOptions(o));
      }
      outBuff.append("\n");
      outBuff.append("Relation:     " + inst.relationName() + '\n');
      outBuff.append("Instances:    " + inst.numInstances() + '\n');
      outBuff.append("Attributes:   " + inst.numAttributes() + '\n');
      if (inst.numAttributes() < 100) {
        for (int i = 0; i < inst.numAttributes(); i++) {
    outBuff.append("              " + inst.attribute(i).name()
             + '\n');
        }
      } else {
        outBuff.append("              [list of attributes omitted]\n");
      }

      outBuff.append("Test mode:    ");
      switch (testMode) {
        case 3: // Test on training
    outBuff.append("evaluate on training data\n");
    break;
        case 1: // CV mode
    outBuff.append("" + numFolds + "-fold cross-validation\n");
    break;
        case 2: // Percent split
    outBuff.append("split " + percent
        + "% train, remainder test\n");
    break;
        case 4: // Test on user split
    if (source.isIncremental())
      outBuff.append("user supplied test set: "
          + " size unknown (reading incrementally)\n");
    else
      outBuff.append("user supplied test set: "
          + source.getDataSet().numInstances() + " instances\n");
    break;
      }
            if (costMatrix != null) {
               outBuff.append("Evaluation cost matrix:\n")
               .append(costMatrix.toString()).append("\n");
            }
      outBuff.append("\n");
      m_History.addResult(name, outBuff);
      m_History.setSingle(name);
     
      // Build the model and output it.
      if (outputModel || (testMode == 3) || (testMode == 4)) {
        m_Log.statusMessage("Building model on training data...");

        trainTimeStart = System.currentTimeMillis();
        classifier.buildClassifier(inst);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
      }

      if (outputModel) {
        outBuff.append("=== Classifier model (full training set) ===\n\n");
        outBuff.append(classifier.toString() + "\n");
        outBuff.append("\nTime taken to build model: " +
           Utils.doubleToString(trainTimeElapsed / 1000.0,2)
           + " seconds\n\n");
        m_History.updateResult(name);
        if (classifier instanceof Drawable) {
    grph = null;
    try {
      grph = ((Drawable)classifier).graph();
    } catch (Exception ex) {
    }
        }
        // copy full model for output
        SerializedObject so = new SerializedObject(classifier);
        fullClassifier = (Classifier) so.getObject();
      }
     
      switch (testMode) {
        case 3: // Test on training
        m_Log.statusMessage("Evaluating on training data...");
        eval = new Evaluation(inst, costMatrix);
       
        // make adjustments if the classifier is an InputMappedClassifier
        eval = setupEval(eval, classifier, inst, costMatrix,
            plotInstances, classificationOutput, false);
       
        //plotInstances.setEvaluation(eval);
              plotInstances.setUp();
       
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "training set");
        }

        for (int jj=0;jj<inst.numInstances();jj++) {
    plotInstances.process(inst.instance(jj), classifier, eval);
   
    if (outputPredictionsText) {
      classificationOutput.printClassification(classifier, inst.instance(jj), jj);
    }
    if ((jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on training data. Processed "
              +jj+" instances...");
    }
        }
        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText && classificationOutput.generatesOutput()) {
    outBuff.append("\n");
        }
        outBuff.append("=== Evaluation on training set ===\n");
        break;

        case 1: // CV mode
        m_Log.statusMessage("Randomizing instances...");
        int rnd = 1;
        try {
    rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
    // System.err.println("Using random seed "+rnd);
        } catch (Exception ex) {
    m_Log.logMessage("Trouble parsing random seed value");
    rnd = 1;
        }
        Random random = new Random(rnd);
        inst.randomize(random);
        if (inst.attribute(classIndex).isNominal()) {
    m_Log.statusMessage("Stratifying instances...");
    inst.stratify(numFolds);
        }
        eval = new Evaluation(inst, costMatrix);
       
         // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifier, inst, costMatrix,
                  plotInstances, classificationOutput, false);
       
//        plotInstances.setEvaluation(eval);
              plotInstances.setUp();
     
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "test data");
        }

        // Make some splits and do a CV
        for (int fold = 0; fold < numFolds; fold++) {
    m_Log.statusMessage("Creating splits for fold "
            + (fold + 1) + "...");
    Instances train = inst.trainCV(numFolds, fold, random);
   
    // make adjustments if the classifier is an InputMappedClassifier
          eval = setupEval(eval, classifier, train, costMatrix,
              plotInstances, classificationOutput, true);
         
//    eval.setPriors(train);
    m_Log.statusMessage("Building model for fold "
            + (fold + 1) + "...");
    Classifier current = null;
    try {
      current = AbstractClassifier.makeCopy(template);
    } catch (Exception ex) {
      m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
    }
    current.buildClassifier(train);
    Instances test = inst.testCV(numFolds, fold);
    m_Log.statusMessage("Evaluating model for fold "
            + (fold + 1) + "...");
    for (int jj=0;jj<test.numInstances();jj++) {
      plotInstances.process(test.instance(jj), current, eval);
      if (outputPredictionsText) {
        classificationOutput.printClassification(current, test.instance(jj), jj);
      }
    }
        }
        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        if (inst.attribute(classIndex).isNominal()) {
    outBuff.append("=== Stratified cross-validation ===\n");
        } else {
    outBuff.append("=== Cross-validation ===\n");
        }
        break;
   
        case 2: // Percent split
        if (!m_PreserveOrderBut.isSelected()) {
    m_Log.statusMessage("Randomizing instances...");
    try {
      rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
    } catch (Exception ex) {
      m_Log.logMessage("Trouble parsing random seed value");
      rnd = 1;
    }
    inst.randomize(new Random(rnd));
        }
        int trainSize = (int) Math.round(inst.numInstances() * percent / 100);
        int testSize = inst.numInstances() - trainSize;
        Instances train = new Instances(inst, 0, trainSize);
        Instances test = new Instances(inst, trainSize, testSize);
        m_Log.statusMessage("Building model on training split ("+trainSize+" instances)...");
        Classifier current = null;
        try {
    current = AbstractClassifier.makeCopy(template);
        } catch (Exception ex) {
    m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
        }
        current.buildClassifier(train);
        eval = new Evaluation(train, costMatrix);
       
        // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifier, train, costMatrix,
                  plotInstances, classificationOutput, false);
                     
//        plotInstances.setEvaluation(eval);
              plotInstances.setUp();
        m_Log.statusMessage("Evaluating on test split...");
      
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "test split");
        }
    
        for (int jj=0;jj<test.numInstances();jj++) {
    plotInstances.process(test.instance(jj), current, eval);
    if (outputPredictionsText) {
      classificationOutput.printClassification(current, test.instance(jj), jj);
    }
    if ((jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on test split. Processed "
              +jj+" instances...");
    }
        }
        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        outBuff.append("=== Evaluation on test split ===\n");
        break;
   
        case 4: // Test on user split
        m_Log.statusMessage("Evaluating on test data...");
        eval = new Evaluation(inst, costMatrix);
        // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifier, inst, costMatrix,
                  plotInstances, classificationOutput, false);
             
//        plotInstances.setEvaluation(eval);
              plotInstances.setUp();
       
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "test set");
        }

        Instance instance;
        int jj = 0;
        while (source.hasMoreElements(userTestStructure)) {
    instance = source.nextElement(userTestStructure);
    plotInstances.process(instance, classifier, eval);
    if (outputPredictionsText) {
      classificationOutput.printClassification(classifier, instance, jj);
    }
    if ((++jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on test data. Processed "
          +jj+" instances...");
    }
        }

        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        outBuff.append("=== Evaluation on test set ===\n");
        break;

        default:
        throw new Exception("Test mode not implemented");
      }
     
      if (outputSummary) {
        outBuff.append(eval.toSummaryString(outputEntropy) + "\n");
      }

      if (inst.attribute(classIndex).isNominal()) {

        if (outputPerClass) {
    outBuff.append(eval.toClassDetailsString() + "\n");
        }

        if (outputConfusion) {
    outBuff.append(eval.toMatrixString() + "\n");
        }
      }

            if (   (fullClassifier instanceof Sourcable)
                 && m_OutputSourceCode.isSelected()) {
              outBuff.append("=== Source code ===\n\n");
              outBuff.append(
                Evaluation.wekaStaticWrapper(
                    ((Sourcable) fullClassifier),
                    m_SourceCodeClass.getText()));
            }

      m_History.updateResult(name);
      m_Log.logMessage("Finished " + cname);
      m_Log.statusMessage("OK");
    } catch (Exception ex) {
      ex.printStackTrace();
      m_Log.logMessage(ex.getMessage());
      JOptionPane.showMessageDialog(ClassifierPanel.this,
            "Problem evaluating classifier:\n"
            + ex.getMessage(),
            "Evaluate classifier",
            JOptionPane.ERROR_MESSAGE);
      m_Log.statusMessage("Problem evaluating classifier");
    } finally {
      try {
              if (!saveVis && outputModel) {
      FastVector vv = new FastVector();
      vv.addElement(fullClassifier);
      Instances trainHeader = new Instances(m_Instances, 0);
      trainHeader.setClassIndex(classIndex);
      vv.addElement(trainHeader);
                  if (grph != null) {
        vv.addElement(grph);
      }
      m_History.addObject(name, vv);
              } else if (saveVis && plotInstances != null && plotInstances.getPlotInstances().numInstances() > 0) {
    m_CurrentVis = new VisualizePanel();
    m_CurrentVis.setName(name+" ("+inst.relationName()+")");
    m_CurrentVis.setLog(m_Log);
    m_CurrentVis.addPlot(plotInstances.getPlotData(cname));
    //m_CurrentVis.setColourIndex(plotInstances.getPlotInstances().classIndex()+1);
          m_CurrentVis.setColourIndex(plotInstances.getPlotInstances().classIndex());
    plotInstances.cleanUp();
     
                FastVector vv = new FastVector();
                if (outputModel) {
                  vv.addElement(fullClassifier);
                  Instances trainHeader = new Instances(m_Instances, 0);
                  trainHeader.setClassIndex(classIndex);
                  vv.addElement(trainHeader);
                  if (grph != null) {
                    vv.addElement(grph);
                  }
                }
                vv.addElement(m_CurrentVis);
               
                if ((eval != null) && (eval.predictions() != null)) {
                  vv.addElement(eval.predictions());
                  vv.addElement(inst.classAttribute());
                }
                m_History.addObject(name, vv);
        }
      } catch (Exception ex) {
View Full Code Here

            boolean outputSummary = true;
            boolean outputEntropy = m_OutputEntropyBut.isSelected();
            boolean saveVis = m_StorePredictionsBut.isSelected();
            boolean outputPredictionsText = (m_ClassificationOutputEditor.getValue().getClass() != Null.class);
            String grph = null;   
            Evaluation eval = null;

            try {

              boolean incrementalLoader = (m_TestLoader instanceof IncrementalConverter);
              if (m_TestLoader != null && m_TestLoader.getStructure() != null) {
                m_TestLoader.reset();
                source = new DataSource(m_TestLoader);
                userTestStructure = source.getStructure();
                userTestStructure.setClassIndex(m_TestClassIndex);
              }
              // Check the test instance compatibility
              if (source == null) {
                throw new Exception("No user test set has been specified");
              }
              if (trainHeader != null) {
                boolean compatibilityProblem = false;
                if (trainHeader.classIndex() >
                    userTestStructure.numAttributes()-1) {
                  compatibilityProblem = true;
                  //throw new Exception("Train and test set are not compatible");
                }
                userTestStructure.setClassIndex(trainHeader.classIndex());
                if (!trainHeader.equalHeaders(userTestStructure)) {
                  compatibilityProblem = true;
                  // throw new Exception("Train and test set are not compatible:\n" + trainHeader.equalHeadersMsg(userTestStructure));
                 
                  if (compatibilityProblem &&
                      !(classifierToUse instanceof weka.classifiers.misc.InputMappedClassifier)) {

                    boolean wrapClassifier = false;
                    if (!Utils.
                        getDontShowDialog("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier")) {
                      JCheckBox dontShow = new JCheckBox("Do not show this message again");
                      Object[] stuff = new Object[2];
                      stuff[0] = "Data used to train model and test set are not compatible.\n" +
                      "Would you like to automatically wrap the classifier in\n" +
                      "an \"InputMappedClassifier\" before proceeding?.\n";
                      stuff[1] = dontShow;

                      int result = JOptionPane.showConfirmDialog(ClassifierPanel.this, stuff,
                          "ClassifierPanel", JOptionPane.YES_OPTION);
                     
                      if (result == JOptionPane.YES_OPTION) {
                        wrapClassifier = true;
                      }
                     
                      if (dontShow.isSelected()) {
                        String response = (wrapClassifier) ? "yes" : "no";
                        Utils.
                          setDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier",
                              response);
                      }

                    } else {
                      // What did the user say - do they want to autowrap or not?
                      String response =
                        Utils.getDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier");
                      if (response != null && response.equalsIgnoreCase("yes")) {
                        wrapClassifier = true;
                      }
                    }

                    if (wrapClassifier) {
                      weka.classifiers.misc.InputMappedClassifier temp =
                        new weka.classifiers.misc.InputMappedClassifier();

                      temp.setClassifier(classifierToUse);
                      temp.setModelHeader(trainHeader);
                      classifierToUse = temp;
                    } else {
                      throw new Exception("Train and test set are not compatible\n" +
                          trainHeader.equalHeadersMsg(userTestStructure));
                    }
                  }
                }
              } else {
          if (classifierToUse instanceof PMMLClassifier) {
            // set the class based on information in the mining schema
            Instances miningSchemaStructure =
              ((PMMLClassifier)classifierToUse).getMiningSchema().getMiningSchemaAsInstances();
            String className = miningSchemaStructure.classAttribute().name();
            Attribute classMatch = userTestStructure.attribute(className);
            if (classMatch == null) {
              throw new Exception("Can't find a match for the PMML target field "
            + className + " in the "
            + "test instances!");
            }
            userTestStructure.setClass(classMatch);
          } else {
            userTestStructure.
              setClassIndex(userTestStructure.numAttributes()-1);
          }
              }
              if (m_Log instanceof TaskLogger) {
                ((TaskLogger)m_Log).taskStarted();
              }
              m_Log.statusMessage("Evaluating on test data...");
              m_Log.logMessage("Re-evaluating classifier (" + name
                               + ") on test set");
              eval = new Evaluation(userTestStructure, costMatrix);
     
              // set up the structure of the plottable instances for
              // visualization if selected
              if (saveVis) {
          plotInstances = new ClassifierErrorsPlotInstances();
          plotInstances.setInstances(userTestStructure);
          plotInstances.setClassifier(classifierToUse);
          plotInstances.setClassIndex(userTestStructure.classIndex());
          plotInstances.setEvaluation(eval);
          plotInstances.setUp();
              }
             
     
              outBuff.append("\n=== Re-evaluation on test set ===\n\n");
              outBuff.append("User supplied test set\n")
              outBuff.append("Relation:     "
                             + userTestStructure.relationName() + '\n');
              if (incrementalLoader)
          outBuff.append("Instances:     unknown (yet). Reading incrementally\n");
              else
          outBuff.append("Instances:    " + source.getDataSet().numInstances() + "\n");
              outBuff.append("Attributes:   "
            + userTestStructure.numAttributes()
            + "\n\n");
              if (trainHeader == null &&
                  !(classifierToUse instanceof
                      weka.classifiers.pmml.consumer.PMMLClassifier)) {
                outBuff.append("NOTE - if test set is not compatible then results are "
                               + "unpredictable\n\n");
              }

              AbstractOutput classificationOutput = null;
              if (outputPredictionsText) {
          classificationOutput = (AbstractOutput) m_ClassificationOutputEditor.getValue();
          classificationOutput.setHeader(userTestStructure);
          classificationOutput.setBuffer(outBuff);
/*          classificationOutput.setAttributes("");
          classificationOutput.setOutputDistribution(false);*/
//          classificationOutput.printHeader();         
              }
             
              // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifierToUse, userTestStructure, costMatrix,
                  plotInstances, classificationOutput, false);
              eval.useNoPriors();
             
              if (outputPredictionsText) {
                printPredictionsHeader(outBuff, classificationOutput, "user test set");
              }

        Instance instance;
        int jj = 0;
        while (source.hasMoreElements(userTestStructure)) {
    instance = source.nextElement(userTestStructure);
    plotInstances.process(instance, classifierToUse, eval);
    if (outputPredictionsText) {
      classificationOutput.printClassification(classifierToUse, instance, jj);
    }
    if ((++jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on test data. Processed "
          +jj+" instances...");
    }
        }

        if (outputPredictionsText)
    classificationOutput.printFooter();
              if (outputPredictionsText && classificationOutput.generatesOutput()) {
                outBuff.append("\n");
              }
     
              if (outputSummary) {
                outBuff.append(eval.toSummaryString(outputEntropy) + "\n");
              }
     
              if (userTestStructure.classAttribute().isNominal()) {
 
                if (outputPerClass) {
                  outBuff.append(eval.toClassDetailsString() + "\n");
                }
 
                if (outputConfusion) {
                  outBuff.append(eval.toMatrixString() + "\n");
                }
              }
     
              m_History.updateResult(name);
              m_Log.logMessage("Finished re-evaluation");
              m_Log.statusMessage("OK");
            } catch (Exception ex) {
              ex.printStackTrace();
              m_Log.logMessage(ex.getMessage());
              m_Log.statusMessage("See error log");

              ex.printStackTrace();
              m_Log.logMessage(ex.getMessage());
              JOptionPane.showMessageDialog(ClassifierPanel.this,
                                            "Problem evaluating classifier:\n"
                                            + ex.getMessage(),
                                            "Evaluate classifier",
                                            JOptionPane.ERROR_MESSAGE);
              m_Log.statusMessage("Problem evaluating classifier");
            } finally {
              try {
          if (classifierToUse instanceof PMMLClassifier) {
            // signal the end of the scoring run so
            // that the initialized state can be reset
            // (forces the field mapping to be recomputed
            // for the next scoring run).
            ((PMMLClassifier)classifierToUse).done();
          }
         
                if (plotInstances != null && plotInstances.getPlotInstances().numInstances() > 0) {
                  m_CurrentVis = new VisualizePanel();
                  m_CurrentVis.setName(name + " (" + userTestStructure.relationName() + ")");
                  m_CurrentVis.setLog(m_Log);
                  m_CurrentVis.addPlot(plotInstances.getPlotData(name));
                  //m_CurrentVis.setColourIndex(plotInstances.getPlotInstances().classIndex()+1);
                  m_CurrentVis.setColourIndex(plotInstances.getPlotInstances().classIndex());
                  plotInstances.cleanUp();
   
                  if (classifierToUse instanceof Drawable) {
                    try {
                      grph = ((Drawable)classifierToUse).graph();
                    } catch (Exception ex) {
                    }
                  }

                  if (saveVis) {
                    FastVector vv = new FastVector();
                    vv.addElement(classifier);
                    if (trainHeader != null) vv.addElement(trainHeader);
                    vv.addElement(m_CurrentVis);
                    if (grph != null) {
                      vv.addElement(grph);
                    }
                    if ((eval != null) && (eval.predictions() != null)) {
                      vv.addElement(eval.predictions());
                      vv.addElement(userTestStructure.classAttribute());
                    }
                    m_History.addObject(name, vv);
                  } else {
                    FastVector vv = new FastVector();
View Full Code Here

    boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported();
    if(canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled())
      thMonitor.setThreadCpuTimeEnabled(true);
   
    Object [] result = new Object[overall_length];
    Evaluation eval = new Evaluation(train);
    m_Classifier = AbstractClassifier.makeCopy(m_Template);
    double [] predictions;
    long thID = Thread.currentThread().getId();
    long CPUStartTime=-1, trainCPUTimeElapsed=-1, testCPUTimeElapsed=-1,
         trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed;   

    //training classifier
    trainTimeStart = System.currentTimeMillis();
    if(canMeasureCPUTime)
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    m_Classifier.buildClassifier(train);   
    if(canMeasureCPUTime)
      trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
   
    //testing classifier
    testTimeStart = System.currentTimeMillis();
    if(canMeasureCPUTime)
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    predictions = eval.evaluateModel(m_Classifier, test);
    if(canMeasureCPUTime)
      testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;
    thMonitor = null;
   
    m_result = eval.toSummaryString();
    // The results stored are all per instance -- can be multiplied by the
    // number of instances to get absolute numbers
    int current = 0;
    result[current++] = new Double(train.numInstances());
    result[current++] = new Double(eval.numInstances());
    result[current++] = new Double(eval.correct());
    result[current++] = new Double(eval.incorrect());
    result[current++] = new Double(eval.unclassified());
    result[current++] = new Double(eval.pctCorrect());
    result[current++] = new Double(eval.pctIncorrect());
    result[current++] = new Double(eval.pctUnclassified());
    result[current++] = new Double(eval.kappa());
   
    result[current++] = new Double(eval.meanAbsoluteError());
    result[current++] = new Double(eval.rootMeanSquaredError());
    result[current++] = new Double(eval.relativeAbsoluteError());
    result[current++] = new Double(eval.rootRelativeSquaredError());
   
    result[current++] = new Double(eval.SFPriorEntropy());
    result[current++] = new Double(eval.SFSchemeEntropy());
    result[current++] = new Double(eval.SFEntropyGain());
    result[current++] = new Double(eval.SFMeanPriorEntropy());
    result[current++] = new Double(eval.SFMeanSchemeEntropy());
    result[current++] = new Double(eval.SFMeanEntropyGain());
   
    // K&B stats
    result[current++] = new Double(eval.KBInformation());
    result[current++] = new Double(eval.KBMeanInformation());
    result[current++] = new Double(eval.KBRelativeInformation());
   
    // IR stats
    result[current++] = new Double(eval.truePositiveRate(m_IRclass));
    result[current++] = new Double(eval.numTruePositives(m_IRclass));
    result[current++] = new Double(eval.falsePositiveRate(m_IRclass));
    result[current++] = new Double(eval.numFalsePositives(m_IRclass));
    result[current++] = new Double(eval.trueNegativeRate(m_IRclass));
    result[current++] = new Double(eval.numTrueNegatives(m_IRclass));
    result[current++] = new Double(eval.falseNegativeRate(m_IRclass));
    result[current++] = new Double(eval.numFalseNegatives(m_IRclass));
    result[current++] = new Double(eval.precision(m_IRclass));
    result[current++] = new Double(eval.recall(m_IRclass));
    result[current++] = new Double(eval.fMeasure(m_IRclass));
    result[current++] = new Double(eval.areaUnderROC(m_IRclass));
   
    // Weighted IR stats
    result[current++] = new Double(eval.weightedTruePositiveRate());
    result[current++] = new Double(eval.weightedFalsePositiveRate());
    result[current++] = new Double(eval.weightedTrueNegativeRate());
    result[current++] = new Double(eval.weightedFalseNegativeRate());
    result[current++] = new Double(eval.weightedPrecision());
    result[current++] = new Double(eval.weightedRecall());
    result[current++] = new Double(eval.weightedFMeasure());
    result[current++] = new Double(eval.weightedAreaUnderROC());
   
    // Unweighted IR stats
    result[current++] = new Double(eval.unweightedMacroFmeasure());
    result[current++] = new Double(eval.unweightedMicroFmeasure());
   
    // Timing stats
    result[current++] = new Double(trainTimeElapsed / 1000.0);
    result[current++] = new Double(testTimeElapsed / 1000.0);
    if(canMeasureCPUTime) {
      result[current++] = new Double((trainCPUTimeElapsed/1000000.0) / 1000.0);
      result[current++] = new Double((testCPUTimeElapsed /1000000.0) / 1000.0);
    }
    else {
      result[current++] = new Double(Utils.missingValue());
      result[current++] = new Double(Utils.missingValue());
    }

    // sizes
    ByteArrayOutputStream bastream = new ByteArrayOutputStream();
    ObjectOutputStream oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(m_Classifier);
    result[current++] = new Double(bastream.size());
    bastream = new ByteArrayOutputStream();
    oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(train);
    result[current++] = new Double(bastream.size());
    bastream = new ByteArrayOutputStream();
    oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(test);
    result[current++] = new Double(bastream.size());
   
    // Prediction interval statistics
    result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions());
    result[current++] = new Double(eval.sizeOfPredictedRegions());

    // IDs
    if (getAttributeID() >= 0){
      String idsString = "";
      if (test.attribute(m_attID).isNumeric()){
View Full Code Here

TOP

Related Classes of weka.classifiers.Evaluation

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.