Package weka.classifiers.xml

Examples of weka.classifiers.xml.XMLClassifier


        } catch (IllegalArgumentException ex) {
          success = false;
        }
        if (!success) {
          // load options from serialized data  ('-l' is automatically erased!)
          XMLClassifier xmlserial = new XMLClassifier();
          OptionHandler cl = (OptionHandler) xmlserial.read(Utils.getOption('l', options));

          // merge options
          optionsTmp = new String[options.length + cl.getOptions().length];
          System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
          System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
          options = optionsTmp;
        }
      }

      noCrossValidation = Utils.getFlag("no-cv", options);
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
        if (classIndexString.equals("first"))
          classIndex = 1;
        else if (classIndexString.equals("last"))
          classIndex = -1;
        else
          classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
        folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
        seed = Integer.parseInt(seedString);
      }
      if (trainFileName.length() == 0) {
        if (objectInputFileName.length() == 0) {
          throw new Exception("No training file and no object input file given.");
        }
        if (testFileName.length() == 0) {
          throw new Exception("No training file and no test file given.");
        }
      } else if ((objectInputFileName.length() != 0) &&
          ((!(classifier instanceof UpdateableClassifier)) ||
           (testFileName.length() == 0))) {
        throw new Exception("Classifier not incremental, or no " +
            "test file provided: can't "+
            "use both train and model file.");
      }
      try {
        if (trainFileName.length() != 0) {
          trainSetPresent = true;
          trainSource = new DataSource(trainFileName);
        }
        if (testFileName.length() != 0) {
          testSetPresent = true;
          testSource = new DataSource(testFileName);
        }
        if (objectInputFileName.length() != 0) {
          if (objectInputFileName.endsWith(".xml")) {
            // if this is the case then it means that a PMML classifier was
            // successfully loaded earlier in the code
            objectInputStream = null;
            xmlInputStream = null;
          } else {
            InputStream is = new FileInputStream(objectInputFileName);
            if (objectInputFileName.endsWith(".gz")) {
              is = new GZIPInputStream(is);
            }
            // load from KOML?
            if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
              objectInputStream = new ObjectInputStream(is);
              xmlInputStream    = null;
            }
            else {
              objectInputStream = null;
              xmlInputStream    = new BufferedInputStream(is);
            }
          }
        }
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      if (testSetPresent) {
        template = test = testSource.getStructure();
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
            test.setClassIndex(test.numAttributes() - 1);
        }
        actualClassIndex = test.classIndex();
      }
      else {
        // percentage split
        splitPercentageString = Utils.getOption("split-percentage", options);
        if (splitPercentageString.length() != 0) {
          if (foldsString.length() != 0)
            throw new Exception(
                "Percentage split cannot be used in conjunction with "
                + "cross-validation ('-x').");
          splitPercentage = Double.parseDouble(splitPercentageString);
          if ((splitPercentage <= 0) || (splitPercentage >= 100))
            throw new Exception("Percentage split value needs be >0 and <100.");
        }
        else {
          splitPercentage = -1;
        }
        preserveOrder = Utils.getFlag("preserve-order", options);
        if (preserveOrder) {
          if (splitPercentage == -1)
            throw new Exception("Percentage split ('-percentage-split') is missing.");
        }
        // create new train/test sources
        if (splitPercentage > 0) {
          testSetPresent = true;
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
          if (!preserveOrder)
            tmpInst.randomize(new Random(seed));
          int trainSize =
            (int) Math.round(tmpInst.numInstances() * splitPercentage / 100);
          int testSize  = tmpInst.numInstances() - trainSize;
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
          Instances testInst  = new Instances(tmpInst, trainSize, testSize);
          trainSource = new DataSource(trainInst);
          testSource  = new DataSource(testInst);
          template = test = testSource.getStructure();
          if (classIndex != -1) {
            test.setClassIndex(classIndex - 1);
          } else {
            if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
              test.setClassIndex(test.numAttributes() - 1);
          }
          actualClassIndex = test.classIndex();
        }
      }
      if (trainSetPresent) {
        template = train = trainSource.getStructure();
        if (classIndex != -1) {
          train.setClassIndex(classIndex - 1);
        } else {
          if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
            train.setClassIndex(train.numAttributes() - 1);
        }
        actualClassIndex = train.classIndex();
        if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
          if ((testSetPresent) && !test.equalHeaders(train)) {
            throw new IllegalArgumentException("Train and test file not compatible!\n" + test.equalHeadersMsg(train));
          }
        }
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
          Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      String classifications = Utils.getOption("classifications", options);
      String classificationsOld = Utils.getOption("p", options);
      if (classifications.length() > 0) {
        noOutput = true;
        classificationOutput = AbstractOutput.fromCommandline(classifications);
        classificationOutput.setHeader(template);
      }
      // backwards compatible with old "-p range" and "-distribution" options
      else if (classificationsOld.length() > 0) {
        noOutput = true;
        classificationOutput = new PlainText();
        classificationOutput.setHeader(template);
        if (!classificationsOld.equals("0"))
          classificationOutput.setAttributes(classificationsOld);
        classificationOutput.setOutputDistribution(Utils.getFlag("distribution", options));
      }
      // -distribution flag needs -p option
      else {
        if (Utils.getFlag("distribution", options))
          throw new Exception("Cannot print distribution without '-p' option!");
      }

      // if no training file given, we don't have any priors
      if ( (!trainSetPresent) && (printComplexityStatistics) )
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
        Utils.checkForRemainingOptions(options);
      } else {

        // Set options for classifier
        if (classifier instanceof OptionHandler) {
          for (int i = 0; i < options.length; i++) {
            if (options[i].length() != 0) {
              if (schemeOptionsText == null) {
                schemeOptionsText = new StringBuffer();
              }
              if (options[i].indexOf(' ') != -1) {
                schemeOptionsText.append('"' + options[i] + "\" ");
              } else {
                schemeOptionsText.append(options[i] + " ");
              }
            }
          }
          ((OptionHandler)classifier).setOptions(options);
        }
      }

      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
          + makeOptionString(classifier, false));
    }

    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      if (objectInputStream != null) {
        classifier = (Classifier) objectInputStream.readObject();
        // try and read a header (if present)
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // don't make a fuss
        }
        if (savedStructure != null) {
          // test for compatibility with template
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible\n" + template.equalHeadersMsg(savedStructure));
          }
        }
        objectInputStream.close();
      }
      else if (xmlInputStream != null) {
        // whether KOML is available has already been checked (objectInputStream would be null otherwise)!
        classifier = (Classifier) KOML.read(xmlInputStream);
        xmlInputStream.close();
      }
    }
   
    // Set up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
      Instances mappedClassifierHeader =
        ((weka.classifiers.misc.InputMappedClassifier)classifier).
          getModelHeader(new Instances(template, 0));
           
      trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
      testingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
    }

    // disable use of priors if no training file given
    if (!trainSetPresent)
      testingEvaluation.useNoPriors();

    // backup of fully setup classifier for cross-validation
    classifierBackup = AbstractClassifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
        (testSetPresent || noCrossValidation) &&
        (costMatrix == null) &&
        (trainSetPresent)) {
      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
        classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
        trainInst = trainSource.nextElement(train);
        trainingEvaluation.updatePriors(trainInst);
        testingEvaluation.updatePriors(trainInst);
        ((UpdateableClassifier)classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Build classifier in one go
      tempTrain = trainSource.getDataSet(actualClassIndex);
     
      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier &&
          !trainingEvaluation.getHeader().equalHeaders(tempTrain)) {
        // we need to make a new dataset that maps the training instances to
        // the structure expected by the mapped classifier - this is only
        // to ensure that the structure and priors computed by the *testing*
        // evaluation object is correct with respect to the mapped classifier
        Instances mappedClassifierDataset =
          ((weka.classifiers.misc.InputMappedClassifier)classifier).
            getModelHeader(new Instances(template, 0));
        for (int zz = 0; zz < tempTrain.numInstances(); zz++) {
          Instance mapped = ((weka.classifiers.misc.InputMappedClassifier)classifier).
            constructMappedInstance(tempTrain.instance(zz));
          mappedClassifierDataset.add(mapped);
        }
        tempTrain = mappedClassifierDataset;
      }
     
      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (classificationOutput != null) {
      classifierClassifications = AbstractClassifier.makeCopy(classifier);
      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
        classificationOutput.setHeader(trainingEvaluation.getHeader());
      }
    }

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      // binary
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
        if (objectOutputFileName.endsWith(".gz")) {
          os = new GZIPOutputStream(os);
        }
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
        objectOutputStream.writeObject(classifier);
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
        objectOutputStream.flush();
        objectOutputStream.close();
      }
      // KOML/XML
      else {
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
        if (objectOutputFileName.endsWith(".xml")) {
          XMLSerialization xmlSerial = new XMLClassifier();
          xmlSerial.write(xmlOutputStream, classifier);
        }
        else
          // whether KOML is present has already been checked
          // if not present -> ".koml" is interpreted as binary - see above
          if (objectOutputFileName.endsWith(".koml")) {
View Full Code Here


  } catch (IllegalArgumentException ex) {
    success = false;
  }
  if (!success) {
    // load options from serialized data  ('-l' is automatically erased!)
    XMLClassifier xmlserial = new XMLClassifier();
    Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));
   
    // merge options
    optionsTmp = new String[options.length + cl.getOptions().length];
    System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
    System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
    options = optionsTmp;
  }
      }

      noCrossValidation = Utils.getFlag("no-cv", options);
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
  if (classIndexString.equals("first"))
    classIndex = 1;
  else if (classIndexString.equals("last"))
    classIndex = -1;
  else
    classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
  folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
  seed = Integer.parseInt(seedString);
      }
      if (trainFileName.length() == 0) {
  if (objectInputFileName.length() == 0) {
    throw new Exception("No training file and no object "+
    "input file given.");
  }
  if (testFileName.length() == 0) {
    throw new Exception("No training file and no test "+
    "file given.");
  }
      } else if ((objectInputFileName.length() != 0) &&
    ((!(classifier instanceof UpdateableClassifier)) ||
        (testFileName.length() == 0))) {
  throw new Exception("Classifier not incremental, or no " +
      "test file provided: can't "+
  "use both train and model file.");
      }
      try {
  if (trainFileName.length() != 0) {
    trainSetPresent = true;
    trainSource = new DataSource(trainFileName);
  }
  if (testFileName.length() != 0) {
    testSetPresent = true;
    testSource = new DataSource(testFileName);
  }
  if (objectInputFileName.length() != 0) {
    if (objectInputFileName.endsWith(".xml")) {
      // if this is the case then it means that a PMML classifier was
      // successfully loaded earlier in the code
      objectInputStream = null;
      xmlInputStream = null;
    } else {
      InputStream is = new FileInputStream(objectInputFileName);
      if (objectInputFileName.endsWith(".gz")) {
        is = new GZIPInputStream(is);
      }
      // load from KOML?
      if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
        objectInputStream = new ObjectInputStream(is);
        xmlInputStream    = null;
      }
      else {
        objectInputStream = null;
        xmlInputStream    = new BufferedInputStream(is);
      }
    }
  }
      } catch (Exception e) {
  throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      if (testSetPresent) {
  template = test = testSource.getStructure();
  if (classIndex != -1) {
    test.setClassIndex(classIndex - 1);
  } else {
    if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
      test.setClassIndex(test.numAttributes() - 1);
  }
  actualClassIndex = test.classIndex();
      }
      else {
  // percentage split
  splitPercentageString = Utils.getOption("split-percentage", options);
  if (splitPercentageString.length() != 0) {
    if (foldsString.length() != 0)
      throw new Exception(
    "Percentage split cannot be used in conjunction with "
    + "cross-validation ('-x').");
    splitPercentage = Double.parseDouble(splitPercentageString);
    if ((splitPercentage <= 0) || (splitPercentage >= 100))
      throw new Exception("Percentage split value needs be >0 and <100.");
  }
  else {
    splitPercentage = -1;
  }
  preserveOrder = Utils.getFlag("preserve-order", options);
  if (preserveOrder) {
    if (splitPercentage == -1)
      throw new Exception("Percentage split ('-percentage-split') is missing.");
  }
  // create new train/test sources
  if (splitPercentage > 0) {
    testSetPresent = true;
    Instances tmpInst = trainSource.getDataSet(actualClassIndex);
    if (!preserveOrder)
      tmpInst.randomize(new Random(seed));
    int trainSize =
            (int) Math.round(tmpInst.numInstances() * splitPercentage / 100);
    int testSize  = tmpInst.numInstances() - trainSize;
    Instances trainInst = new Instances(tmpInst, 0, trainSize);
    Instances testInst  = new Instances(tmpInst, trainSize, testSize);
    trainSource = new DataSource(trainInst);
    testSource  = new DataSource(testInst);
    template = test = testSource.getStructure();
    if (classIndex != -1) {
      test.setClassIndex(classIndex - 1);
    } else {
      if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
        test.setClassIndex(test.numAttributes() - 1);
    }
    actualClassIndex = test.classIndex();
  }
      }
      if (trainSetPresent) {
  template = train = trainSource.getStructure();
  if (classIndex != -1) {
    train.setClassIndex(classIndex - 1);
  } else {
    if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
      train.setClassIndex(train.numAttributes() - 1);
  }
  actualClassIndex = train.classIndex();
  if ((testSetPresent) && !test.equalHeaders(train)) {
    throw new IllegalArgumentException("Train and test file not compatible!");
  }
      }
      if (template == null) {
  throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
    Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      printDistribution = Utils.getFlag("distribution", options);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      // Check -p option
      try {
  attributeRangeString = Utils.getOption('p', options);
      }
      catch (Exception e) {
  throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
      "It now expects a parameter specifying a range of attributes " +
  "to list with the predictions. Use '-p 0' for none.");
      }
      if (attributeRangeString.length() != 0) {
  printClassifications = true;
  noOutput = true;
  if (!attributeRangeString.equals("0"))
    attributesToOutput = new Range(attributeRangeString);
      }

      if (!printClassifications && printDistribution)
  throw new Exception("Cannot print distribution without '-p' option!");

      // if no training file given, we don't have any priors
      if ( (!trainSetPresent) && (printComplexityStatistics) )
  throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
  Utils.checkForRemainingOptions(options);
      } else {

  // Set options for classifier
  if (classifier instanceof OptionHandler) {
    for (int i = 0; i < options.length; i++) {
      if (options[i].length() != 0) {
        if (schemeOptionsText == null) {
    schemeOptionsText = new StringBuffer();
        }
        if (options[i].indexOf(' ') != -1) {
    schemeOptionsText.append('"' + options[i] + "\" ");
        } else {
    schemeOptionsText.append(options[i] + " ");
        }
      }
    }
    ((OptionHandler)classifier).setOptions(options);
  }
      }
      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
    + makeOptionString(classifier, false));
    }

    // Set up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);

    // disable use of priors if no training file given
    if (!trainSetPresent)
      testingEvaluation.useNoPriors();

    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      if (objectInputStream != null) {
  classifier = (Classifier) objectInputStream.readObject();
        // try and read a header (if present)
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // don't make a fuss
        }
        if (savedStructure != null) {
          // test for compatibility with template
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible");
          }
        }
  objectInputStream.close();
      }
      else if (xmlInputStream != null) {
  // whether KOML is available has already been checked (objectInputStream would be null otherwise)!
  classifier = (Classifier) KOML.read(xmlInputStream);
  xmlInputStream.close();
      }
    }

    // backup of fully setup classifier for cross-validation
    classifierBackup = Classifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
  (testSetPresent || noCrossValidation) &&
  (costMatrix == null) &&
  (trainSetPresent)) {
      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
  classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
  trainInst = trainSource.nextElement(train);
  trainingEvaluation.updatePriors(trainInst);
  testingEvaluation.updatePriors(trainInst);
  ((UpdateableClassifier)classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Build classifier in one go
      tempTrain = trainSource.getDataSet(actualClassIndex);
      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (printClassifications)
      classifierClassifications = Classifier.makeCopy(classifier);

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      // binary
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
  if (objectOutputFileName.endsWith(".gz")) {
    os = new GZIPOutputStream(os);
  }
  ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
  objectOutputStream.writeObject(classifier);
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
  objectOutputStream.flush();
  objectOutputStream.close();
      }
      // KOML/XML
      else {
  BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
  if (objectOutputFileName.endsWith(".xml")) {
    XMLSerialization xmlSerial = new XMLClassifier();
    xmlSerial.write(xmlOutputStream, classifier);
  }
  else
    // whether KOML is present has already been checked
    // if not present -> ".koml" is interpreted as binary - see above
    if (objectOutputFileName.endsWith(".koml")) {
View Full Code Here

        if (returnVal == JFileChooser.APPROVE_OPTION) {
          try {
            File file = m_FileChooser.getSelectedFile();
            if (!file.getAbsolutePath().toLowerCase().endsWith(".xml"))
              file = new File(file.getAbsolutePath() + ".xml");
            XMLClassifier xmlcls = new XMLClassifier();
            Classifier c = (Classifier) xmlcls.read(file);
            m_AlgorithmListModel.setElementAt(c, m_List.getSelectedIndex());
            updateExperiment();
          }
          catch (Exception ex) {
            ex.printStackTrace();
          }
        }
      }
   } else if (e.getSource() == m_SaveOptionsBut) {
      if (m_List.getSelectedValue() != null) {
        int returnVal = m_FileChooser.showSaveDialog(this);
        if (returnVal == JFileChooser.APPROVE_OPTION) {
          try {
            File file = m_FileChooser.getSelectedFile();
            if (!file.getAbsolutePath().toLowerCase().endsWith(".xml"))
              file = new File(file.getAbsolutePath() + ".xml");
            XMLClassifier xmlcls = new XMLClassifier();
            xmlcls.write(file, m_List.getSelectedValue());
          }
          catch (Exception ex) {
            ex.printStackTrace();
          }
        }
View Full Code Here

  // KOML/XML
  else {
    BufferedOutputStream xmlOutputStream = new BufferedOutputStream(
        os);
    if (objectOutputFileName.endsWith(".xml")) {
      XMLSerialization xmlSerial = new XMLClassifier();
      xmlSerial.write(xmlOutputStream, classifier);
    } else
      // whether KOML is present has already been checked
      // if not present -> ".koml" is interpreted as binary - see
      // above
      if (objectOutputFileName.endsWith(".koml")) {
View Full Code Here

        optionsTmp[i] = options[i];
      }

      if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
        // load options from serialized data ('-l' is automatically erased!)
        XMLClassifier xmlserial = new XMLClassifier();
        Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));
        // merge options
        optionsTmp = new String[options.length + cl.getOptions().length];
        System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
        System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
        options = optionsTmp;
      }

      noCrossValidation = Utils.getFlag("no-cv", options);
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
        if (classIndexString.equals("first")) {
          classIndex = 1;
        } else if (classIndexString.equals("last")) {
          classIndex = -1;
        } else {
          classIndex = Integer.parseInt(classIndexString);
        }
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
        folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
        seed = Integer.parseInt(seedString);
      }
      if (trainFileName.length() == 0) {
        if (objectInputFileName.length() == 0) {
          throw new Exception("No training file and no object " +
              "input file given.");
        }
        if (testFileName.length() == 0) {
          throw new Exception("No training file and no test " +
              "file given.");
        }
      } else if ((objectInputFileName.length() != 0) &&
          ((!(classifier instanceof UpdateableClassifier)) ||
          (testFileName.length() == 0))) {
        throw new Exception("Classifier not incremental, or no " +
            "test file provided: can't " +
            "use both train and model file.");
      }
      try {
        if (trainFileName.length() != 0) {
          trainSetPresent = true;
          trainSource = new DataSource(trainFileName);
        }
        if (testFileName.length() != 0) {
          testSetPresent = true;
          testSource = new DataSource(testFileName);
        }
        if (objectInputFileName.length() != 0) {
          InputStream is = new FileInputStream(objectInputFileName);
          if (objectInputFileName.endsWith(".gz")) {
            is = new GZIPInputStream(is);
          }
          // load from KOML?
          if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) {
            objectInputStream = new ObjectInputStream(is);
            xmlInputStream = null;
          } else {
            objectInputStream = null;
            xmlInputStream = new BufferedInputStream(is);
          }
        }
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      if (testSetPresent) {
        template = test = testSource.getStructure();
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          if ((test.classIndex() == -1) || (classIndexString.length() != 0)) {
            test.setClassIndex(test.numAttributes() - 1);
          }
        }
        actualClassIndex = test.classIndex();
      } else {
        // percentage split
        splitPercentageString = Utils.getOption("split-percentage", options);
        if (splitPercentageString.length() != 0) {
          if (foldsString.length() != 0) {
            throw new Exception(
                "Percentage split cannot be used in conjunction with " + "cross-validation ('-x').");
          }
          splitPercentage = Integer.parseInt(splitPercentageString);
          if ((splitPercentage <= 0) || (splitPercentage >= 100)) {
            throw new Exception("Percentage split value needs be >0 and <100.");
          }
        } else {
          splitPercentage = -1;
        }
        preserveOrder = Utils.getFlag("preserve-order", options);
        if (preserveOrder) {
          if (splitPercentage == -1) {
            throw new Exception("Percentage split ('-percentage-split') is missing.");
          }
        }
        // create new train/test sources
        if (splitPercentage > 0) {
          testSetPresent = true;
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
          if (!preserveOrder) {
            tmpInst.randomize(new Random(seed));
          }
          int trainSize = tmpInst.numInstances() * splitPercentage / 100;
          int testSize = tmpInst.numInstances() - trainSize;
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
          Instances testInst = new Instances(tmpInst, trainSize, testSize);
          trainSource = new DataSource(trainInst);
          testSource = new DataSource(testInst);
          template = test = testSource.getStructure();
          if (classIndex != -1) {
            test.setClassIndex(classIndex - 1);
          } else {
            if ((test.classIndex() == -1) || (classIndexString.length() != 0)) {
              test.setClassIndex(test.numAttributes() - 1);
            }
          }
          actualClassIndex = test.classIndex();
        }
      }
      if (trainSetPresent) {
        template = train = trainSource.getStructure();
        if (classIndex != -1) {
          train.setClassIndex(classIndex - 1);
        } else {
          if ((train.classIndex() == -1) || (classIndexString.length() != 0)) {
            train.setClassIndex(train.numAttributes() - 1);
          }
        }
        actualClassIndex = train.classIndex();
        if ((testSetPresent) && !test.equalHeaders(train)) {
          throw new IllegalArgumentException("Train and test file not compatible!");
        }
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
          Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      printDistribution = Utils.getFlag("distribution", options);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      // Check -p option
      try {
        attributeRangeString = Utils.getOption('p', options);
      } catch (Exception e) {
        throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
            "It now expects a parameter specifying a range of attributes " +
            "to list with the predictions. Use '-p 0' for none.");
      }
      if (attributeRangeString.length() != 0) {
        printClassifications = true;
        if (!attributeRangeString.equals("0")) {
          attributesToOutput = new Range(attributeRangeString);
        }
      }

      if (!printClassifications && printDistribution) {
        throw new Exception("Cannot print distribution without '-p' option!");
      }

      // if no training file given, we don't have any priors
      if ((!trainSetPresent) && (printComplexityStatistics)) {
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");
      }

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
        Utils.checkForRemainingOptions(options);
      } else {

        // Set options for classifier
        if (classifier instanceof OptionHandler) {
          for (int i = 0; i < options.length; i++) {
            if (options[i].length() != 0) {
              if (schemeOptionsText == null) {
                schemeOptionsText = new StringBuffer();
              }
              if (options[i].indexOf(' ') != -1) {
                schemeOptionsText.append('"' + options[i] + "\" ");
              } else {
                schemeOptionsText.append(options[i] + " ");
              }
            }
          }
          ((OptionHandler) classifier).setOptions(options);
        }
      }
      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier));
    }

    // Set up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);

    // disable use of priors if no training file given
    if (!trainSetPresent) {
      testingEvaluation.useNoPriors();
    }

    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      if (objectInputStream != null) {
        classifier = (Classifier) objectInputStream.readObject();
        // try and read a header (if present)
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // don't make a fuss
        }
        if (savedStructure != null) {
          // test for compatibility with template
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible");
          }
        }
        objectInputStream.close();
      } else {
        // KOML availability was already checked when the input stream was opened (objectInputStream would be non-null otherwise)!
        classifier = (Classifier) KOML.read(xmlInputStream);
        xmlInputStream.close();
      }
    }

    // backup of fully setup classifier for cross-validation
    classifierBackup = Classifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
        (testSetPresent || noCrossValidation) &&
        (costMatrix == null) &&
        (trainSetPresent)) {
      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
        classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
        trainInst = trainSource.nextElement(train);
        trainingEvaluation.updatePriors(trainInst);
        testingEvaluation.updatePriors(trainInst);
        ((UpdateableClassifier) classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Build classifier in one go
      tempTrain = trainSource.getDataSet(actualClassIndex);
      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

/*  FOR LARGE DATA SETS
    // backup of fully trained classifier for printing the classifications
    if (printClassifications) {
      classifierClassifications = Classifier.makeCopy(classifier);
    }
*/
    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      // binary
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
        if (objectOutputFileName.endsWith(".gz")) {
          os = new GZIPOutputStream(os);
        }
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
        objectOutputStream.writeObject(classifier);
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
        objectOutputStream.flush();
        objectOutputStream.close();
      } // KOML/XML
      else {
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
        if (objectOutputFileName.endsWith(".xml")) {
          XMLSerialization xmlSerial = new XMLClassifier();
          xmlSerial.write(xmlOutputStream, classifier);
        } else // whether KOML is present has already been checked
        // if not present -> ".koml" is interpreted as binary - see above
        if (objectOutputFileName.endsWith(".koml")) {
          KOML.write(xmlOutputStream, classifier);
        }
View Full Code Here

        if (returnVal == JFileChooser.APPROVE_OPTION) {
          try {
            File file = m_FileChooser.getSelectedFile();
            if (!file.getAbsolutePath().toLowerCase().endsWith(".xml"))
              file = new File(file.getAbsolutePath() + ".xml");
            XMLClassifier xmlcls = new XMLClassifier();
            Classifier c = (Classifier) xmlcls.read(file);
            m_AlgorithmListModel.setElementAt(c, m_List.getSelectedIndex());
            updateExperiment();
          }
          catch (Exception ex) {
            ex.printStackTrace();
          }
        }
      }
   } else if (e.getSource() == m_SaveOptionsBut) {
      if (m_List.getSelectedValue() != null) {
        int returnVal = m_FileChooser.showSaveDialog(this);
        if (returnVal == JFileChooser.APPROVE_OPTION) {
          try {
            File file = m_FileChooser.getSelectedFile();
            if (!file.getAbsolutePath().toLowerCase().endsWith(".xml"))
              file = new File(file.getAbsolutePath() + ".xml");
            XMLClassifier xmlcls = new XMLClassifier();
            xmlcls.write(file, m_List.getSelectedValue());
          }
          catch (Exception ex) {
            ex.printStackTrace();
          }
        }
View Full Code Here

        if (returnVal == JFileChooser.APPROVE_OPTION) {
          try {
            File file = m_FileChooser.getSelectedFile();
            if (!file.getAbsolutePath().toLowerCase().endsWith(".xml"))
              file = new File(file.getAbsolutePath() + ".xml");
            XMLClassifier xmlcls = new XMLClassifier();
            Classifier c = (Classifier) xmlcls.read(file);
            m_AlgorithmListModel.setElementAt(c, m_List.getSelectedIndex());
            updateExperiment();
          }
          catch (Exception ex) {
            ex.printStackTrace();
          }
        }
      }
   } else if (e.getSource() == m_SaveOptionsBut) {
      if (m_List.getSelectedValue() != null) {
        int returnVal = m_FileChooser.showSaveDialog(this);
        if (returnVal == JFileChooser.APPROVE_OPTION) {
          try {
            File file = m_FileChooser.getSelectedFile();
            if (!file.getAbsolutePath().toLowerCase().endsWith(".xml"))
              file = new File(file.getAbsolutePath() + ".xml");
            XMLClassifier xmlcls = new XMLClassifier();
            xmlcls.write(file, m_List.getSelectedValue());
          }
          catch (Exception ex) {
            ex.printStackTrace();
          }
        }
View Full Code Here

        } catch (IllegalArgumentException ex) {
          success = false;
        }
        if (!success) {
          // load options from serialized data  ('-l' is automatically erased!)
          XMLClassifier xmlserial = new XMLClassifier();
          OptionHandler cl = (OptionHandler) xmlserial.read(Utils.getOption('l', options));

          // merge options
          optionsTmp = new String[options.length + cl.getOptions().length];
          System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
          System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
          options = optionsTmp;
        }
      }

      noCrossValidation = Utils.getFlag("no-cv", options);
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
        if (classIndexString.equals("first"))
          classIndex = 1;
        else if (classIndexString.equals("last"))
          classIndex = -1;
        else
          classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
        folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
        seed = Integer.parseInt(seedString);
      }
      if (trainFileName.length() == 0) {
        if (objectInputFileName.length() == 0) {
          throw new Exception("No training file and no object input file given.");
        }
        if (testFileName.length() == 0) {
          throw new Exception("No training file and no test file given.");
        }
      } else if ((objectInputFileName.length() != 0) &&
          ((!(classifier instanceof UpdateableClassifier)) ||
           (testFileName.length() == 0))) {
        throw new Exception("Classifier not incremental, or no " +
            "test file provided: can't "+
            "use both train and model file.");
      }
      try {
        if (trainFileName.length() != 0) {
          trainSetPresent = true;
          trainSource = new DataSource(trainFileName);
        }
        if (testFileName.length() != 0) {
          testSetPresent = true;
          testSource = new DataSource(testFileName);
        }
        if (objectInputFileName.length() != 0) {
          if (objectInputFileName.endsWith(".xml")) {
            // if this is the case then it means that a PMML classifier was
            // successfully loaded earlier in the code
            objectInputStream = null;
            xmlInputStream = null;
          } else {
            InputStream is = new FileInputStream(objectInputFileName);
            if (objectInputFileName.endsWith(".gz")) {
              is = new GZIPInputStream(is);
            }
            // load from KOML?
            if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
              objectInputStream = new ObjectInputStream(is);
              xmlInputStream    = null;
            }
            else {
              objectInputStream = null;
              xmlInputStream    = new BufferedInputStream(is);
            }
          }
        }
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      if (testSetPresent) {
        template = test = testSource.getStructure();
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
            test.setClassIndex(test.numAttributes() - 1);
        }
        actualClassIndex = test.classIndex();
      }
      else {
        // percentage split
        splitPercentageString = Utils.getOption("split-percentage", options);
        if (splitPercentageString.length() != 0) {
          if (foldsString.length() != 0)
            throw new Exception(
                "Percentage split cannot be used in conjunction with "
                + "cross-validation ('-x').");
          splitPercentage = Double.parseDouble(splitPercentageString);
          if ((splitPercentage <= 0) || (splitPercentage >= 100))
            throw new Exception("Percentage split value needs be >0 and <100.");
        }
        else {
          splitPercentage = -1;
        }
        preserveOrder = Utils.getFlag("preserve-order", options);
        if (preserveOrder) {
          if (splitPercentage == -1)
            throw new Exception("Percentage split ('-percentage-split') is missing.");
        }
        // create new train/test sources
        if (splitPercentage > 0) {
          testSetPresent = true;
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
          if (!preserveOrder)
            tmpInst.randomize(new Random(seed));
          int trainSize =
            (int) Math.round(tmpInst.numInstances() * splitPercentage / 100);
          int testSize  = tmpInst.numInstances() - trainSize;
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
          Instances testInst  = new Instances(tmpInst, trainSize, testSize);
          trainSource = new DataSource(trainInst);
          testSource  = new DataSource(testInst);
          template = test = testSource.getStructure();
          if (classIndex != -1) {
            test.setClassIndex(classIndex - 1);
          } else {
            if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
              test.setClassIndex(test.numAttributes() - 1);
          }
          actualClassIndex = test.classIndex();
        }
      }
      if (trainSetPresent) {
        template = train = trainSource.getStructure();
        if (classIndex != -1) {
          train.setClassIndex(classIndex - 1);
        } else {
          if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
            train.setClassIndex(train.numAttributes() - 1);
        }
        actualClassIndex = train.classIndex();
        if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
          if ((testSetPresent) && !test.equalHeaders(train)) {
            throw new IllegalArgumentException("Train and test file not compatible!\n" + test.equalHeadersMsg(train));
          }
        }
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
          Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      String classifications = Utils.getOption("classifications", options);
      String classificationsOld = Utils.getOption("p", options);
      if (classifications.length() > 0) {
        noOutput = true;
        classificationOutput = AbstractOutput.fromCommandline(classifications);
        if (classificationOutput == null)
          throw new Exception("Failed to instantiate class for classification output: " + classifications);
        classificationOutput.setHeader(template);
      }
      // backwards compatible with old "-p range" and "-distribution" options
      else if (classificationsOld.length() > 0) {
        noOutput = true;
        classificationOutput = new PlainText();
        classificationOutput.setHeader(template);
        if (!classificationsOld.equals("0"))
          classificationOutput.setAttributes(classificationsOld);
        classificationOutput.setOutputDistribution(Utils.getFlag("distribution", options));
      }
      // -distribution flag needs -p option
      else {
        if (Utils.getFlag("distribution", options))
          throw new Exception("Cannot print distribution without '-p' option!");
      }
      discardPredictions = Utils.getFlag("no-predictions", options);
      if (discardPredictions && (classificationOutput != null))
  throw new Exception("Cannot discard predictions ('-no-predictions') and output predictions at the same time ('-classifications/-p')!");

      // if no training file given, we don't have any priors
      if ( (!trainSetPresent) && (printComplexityStatistics) )
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
        Utils.checkForRemainingOptions(options);
      } else {

        // Set options for classifier
        if (classifier instanceof OptionHandler) {
          for (int i = 0; i < options.length; i++) {
            if (options[i].length() != 0) {
              if (schemeOptionsText == null) {
                schemeOptionsText = new StringBuffer();
              }
              if (options[i].indexOf(' ') != -1) {
                schemeOptionsText.append('"' + options[i] + "\" ");
              } else {
                schemeOptionsText.append(options[i] + " ");
              }
            }
          }
          ((OptionHandler)classifier).setOptions(options);
        }
      }

      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
          + makeOptionString(classifier, false));
    }

    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      if (objectInputStream != null) {
        classifier = (Classifier) objectInputStream.readObject();
        // try and read a header (if present)
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // don't make a fuss
        }
        if (savedStructure != null) {
          // test for compatibility with template
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible\n" + template.equalHeadersMsg(savedStructure));
          }
        }
        objectInputStream.close();
      }
      else if (xmlInputStream != null) {
        // KOML availability was already checked when the input stream was opened (objectInputStream would be non-null otherwise)!
        classifier = (Classifier) KOML.read(xmlInputStream);
        xmlInputStream.close();
      }
    }

    // Set up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
      Instances mappedClassifierHeader =
        ((weka.classifiers.misc.InputMappedClassifier)classifier).
          getModelHeader(new Instances(template, 0));

      trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
      testingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
    }
    trainingEvaluation.setDiscardPredictions(discardPredictions);
    testingEvaluation.setDiscardPredictions(discardPredictions);

    // disable use of priors if no training file given
    if (!trainSetPresent)
      testingEvaluation.useNoPriors();

    // backup of fully setup classifier for cross-validation
    classifierBackup = AbstractClassifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
        (testSetPresent || noCrossValidation) &&
        (costMatrix == null) &&
        (trainSetPresent)) {
      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
        classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
        trainInst = trainSource.nextElement(train);
        trainingEvaluation.updatePriors(trainInst);
        testingEvaluation.updatePriors(trainInst);
        ((UpdateableClassifier)classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Build classifier in one go
      tempTrain = trainSource.getDataSet(actualClassIndex);

      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier &&
          !trainingEvaluation.getHeader().equalHeaders(tempTrain)) {
        // we need to make a new dataset that maps the training instances to
        // the structure expected by the mapped classifier - this is only
        // to ensure that the structure and priors computed by the *testing*
        // evaluation object is correct with respect to the mapped classifier
        Instances mappedClassifierDataset =
          ((weka.classifiers.misc.InputMappedClassifier)classifier).
            getModelHeader(new Instances(template, 0));
        for (int zz = 0; zz < tempTrain.numInstances(); zz++) {
          Instance mapped = ((weka.classifiers.misc.InputMappedClassifier)classifier).
            constructMappedInstance(tempTrain.instance(zz));
          mappedClassifierDataset.add(mapped);
        }
        tempTrain = mappedClassifierDataset;
      }

      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (classificationOutput != null) {
      classifierClassifications = AbstractClassifier.makeCopy(classifier);
      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
        classificationOutput.setHeader(trainingEvaluation.getHeader());
      }
    }

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      // binary
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
        if (objectOutputFileName.endsWith(".gz")) {
          os = new GZIPOutputStream(os);
        }
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
        objectOutputStream.writeObject(classifier);
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
        objectOutputStream.flush();
        objectOutputStream.close();
      }
      // KOML/XML
      else {
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
        if (objectOutputFileName.endsWith(".xml")) {
          XMLSerialization xmlSerial = new XMLClassifier();
          xmlSerial.write(xmlOutputStream, classifier);
        }
        else
          // whether KOML is present has already been checked
          // if not present -> ".koml" is interpreted as binary - see above
          if (objectOutputFileName.endsWith(".koml")) {
View Full Code Here

TOP

Related Classes of weka.classifiers.xml.XMLClassifier

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.