Package org.sf.mustru.train

Source Code of org.sf.mustru.train.TrainMClassifier

package org.sf.mustru.train;

import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;

import com.aliasi.util.Files;
import com.aliasi.classify.DynamicLMClassifier;

import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.sf.mustru.utils.Constants;

import java.io.IOException;

/**
* Test classification using the Lingpipe classifier with a collection
* of Reuters documents. Three categories with 10-15 documents each
* are used for training.
*/
public class TrainMClassifier
//*-- define the training and testing directories and list of categories
private static String MUSTRU_HOME =  Constants.MUSTRU_HOME;
private static File TRAINING_DIR = new File(
   MUSTRU_HOME + File.separator + "data" + File.separator + "training" + File.separator + "tcat");
private static String[] CATEGORIES = {"sugar", "coffee", "cocoa", "misc"};
private static int NGRAM_SIZE = 6;
private static boolean BOUNDED = false;

public static void main(String[] args) throws ClassNotFoundException, IOException
{
  PropertyConfigurator.configure (Constants.LOG4J_FILE);
  Logger logger = Logger.getLogger(TrainMClassifier.class.getName());
  logger.debug("Started TrainMClassifier");
  DynamicLMClassifier classifier = new DynamicLMClassifier(CATEGORIES, NGRAM_SIZE, BOUNDED);

  //*-- Start of training
  //*-- loop through the list of categories and verify that a directory exists for each one.
  for (int i = 0; i < CATEGORIES.length; ++i)
  {
   File classDir = new File(TRAINING_DIR, CATEGORIES[i]);
   if (!classDir.isDirectory())
   { logger.fatal("Could not find training directory=" + classDir); }

   //*-- get the list of training files for the category and train the classifier on each of the files
   String[] trainingFiles = classDir.list();
   for (int j=0; j<trainingFiles.length; ++j)
   {
    String text = Files.readFromFile(new File(classDir,trainingFiles[j]));
    logger.debug("Training on " + CATEGORIES[i] + File.separator + trainingFiles[j]);
    classifier.train(CATEGORIES[i], text);
   } //*-- end of inner for
  } //*-- end of outer for
  //*-- end of training

  //*-- dump the classification model to a file
  logger.info("Start compiling classifier");
  String modelFile = MUSTRU_HOME + File.separator + "data" + File.separator + "training" + File.separator + "tcat" + File.separator + "tcat_classifier";
  ObjectOutputStream os = new ObjectOutputStream( new FileOutputStream(modelFile) );
  classifier.compileTo(os);
  os.close();
  logger.info("End compiling classifier");
  logger.debug("Ended TrainMClassifier");
}
}
TOP

Related Classes of org.sf.mustru.train.TrainMClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by ORACLE Inc. Contact coftware#gmail.com.