package org.sf.mustru.train;
import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import com.aliasi.util.Files;
import com.aliasi.classify.DynamicLMClassifier;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.sf.mustru.utils.Constants;
import java.io.IOException;
/**
* Train the text category classifier
*/
/**
 * Train the text category classifier.
 *
 * <p>For every name in {@code CATEGORIES}, each regular file in
 * {@code TRAINING_DIR/<category>} is read and passed to
 * {@link DynamicLMClassifier#train}. The trained classifier is then compiled
 * to the file named by {@code Constants.TTYPE_CLASS_MODEL}.
 */
public class TrainTtypeClassifier
{
  //*-- define the training directory and list of categories
  private static final File TRAINING_DIR = new File(Constants.TRAININGDIR + File.separator + "categories" + File.separator + "news");
  private static final String[] CATEGORIES = {"business", "health", "politics", "sports", "technology", "science"};
  //*-- n-gram size and bounded flag passed unchanged to the DynamicLMClassifier constructor
  private static final int NGRAM_SIZE = 6;
  private static final boolean BOUNDED = false;

  /**
   * Trains the classifier on every category and writes the compiled model.
   *
   * @param args ignored
   * @throws ClassNotFoundException kept from the original signature for compatibility
   * @throws IOException if a category directory is missing or cannot be listed,
   *         a training file cannot be read, or the model file cannot be written
   */
  public static void main(String[] args) throws ClassNotFoundException, IOException
  {
    PropertyConfigurator.configure(Constants.LOG4J_FILE);
    Logger logger = Logger.getLogger(TrainTtypeClassifier.class.getName());
    logger.debug("Started TrainTtypeClassifier");

    DynamicLMClassifier classifier = new DynamicLMClassifier(CATEGORIES, NGRAM_SIZE, BOUNDED);

    //*-- Start of training: one pass per category
    for (int i = 0; i < CATEGORIES.length; ++i)
    {
      trainCategory(classifier, CATEGORIES[i], logger);
    }
    //*-- end of training

    //*-- dump the classification model to a file
    logger.info("Start compiling TrainTtypeClassifier");
    writeModel(classifier, Constants.TTYPE_CLASS_MODEL);
    logger.info("End compiling classifier");
    logger.debug("Ended TrainTtypeClassifier");
  }

  /**
   * Trains {@code classifier} on every regular file in the directory for {@code category}.
   *
   * @param classifier the classifier being trained
   * @param category   category name; also the sub-directory name under {@code TRAINING_DIR}
   * @param logger     logger used for progress and error messages
   * @throws IOException if the category directory is missing or cannot be listed,
   *         or a training file cannot be read
   */
  private static void trainCategory(DynamicLMClassifier classifier, String category, Logger logger)
      throws IOException
  {
    File classDir = new File(TRAINING_DIR, category);

    //*-- A missing directory would leave the category untrained (and list() would return null),
    //*-- so stop with a descriptive error instead of continuing.
    if (!classDir.isDirectory())
    {
      logger.fatal("Could not find training directory for category: " + classDir);
      throw new IOException("Could not find training directory for category: " + classDir);
    }

    //*-- list() returns null if an I/O error occurs, even for an existing directory
    String[] trainingFiles = classDir.list();
    if (trainingFiles == null)
    {
      logger.fatal("Could not list training directory for category: " + classDir);
      throw new IOException("Could not list training directory for category: " + classDir);
    }

    for (int j = 0; j < trainingFiles.length; ++j)
    {
      File trainingFile = new File(classDir, trainingFiles[j]);
      //*-- skip sub-directories and other non-regular entries; they cannot be read as text
      if (!trainingFile.isFile())
      {
        logger.debug("Skipping non-file entry " + trainingFile);
        continue;
      }
      String text = Files.readFromFile(trainingFile);
      logger.debug("Training on " + category + File.separator + trainingFiles[j]);
      classifier.train(category, text);
    }
  }

  /**
   * Compiles {@code classifier} into {@code modelFile}, closing the file even on failure.
   *
   * @param classifier trained classifier to compile
   * @param modelFile  path of the model file to (over)write
   * @throws ClassNotFoundException kept for compatibility with the original call site
   * @throws IOException if the model file cannot be opened or written
   */
  private static void writeModel(DynamicLMClassifier classifier, String modelFile)
      throws ClassNotFoundException, IOException
  {
    FileOutputStream fileOut = new FileOutputStream(modelFile);
    try
    {
      ObjectOutputStream os = new ObjectOutputStream(fileOut);
      classifier.compileTo(os);
      //*-- closing the object stream flushes its buffered data and closes fileOut
      os.close();
    }
    finally
    {
      //*-- no-op if already closed above; releases the file handle if anything failed
      fileOut.close();
    }
  }
}