Package etc.aloe

Source Code of etc.aloe.AloeTrain

/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.

* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.

* You should have received a copy of the GNU General Public License
* along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe;

import etc.aloe.controllers.CrossValidationController;
import etc.aloe.controllers.TrainingController;
import etc.aloe.data.EvaluationReport;
import etc.aloe.data.FeatureSpecification;
import etc.aloe.data.MessageSet;
import etc.aloe.data.Model;
import etc.aloe.data.ROC;
import etc.aloe.data.SegmentSet;
import etc.aloe.options.ModeOptions;
import etc.aloe.options.TrainOptions;
import etc.aloe.processes.Segmentation;
import java.io.File;
import java.util.List;
import java.util.Map;
import weka.core.Instances;

/**
* Class that takes input training data, uses cross validation to evaluate the
* model, then trains a final model on the full training set.
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class AloeTrain extends Aloe {

    @Override
    public void run(ModeOptions modeOptions) {
        System.out.println("== Preparation ==");

        if (modeOptions instanceof TrainOptions) {
            TrainOptions options = (TrainOptions) modeOptions;

            saveCommand(options.outputCommandFile);
           
            //Get and preprocess the data
            MessageSet messages = this.loadMessages(options.inputCSVFile);
            Segmentation segmentation = factory.constructSegmentation();
            SegmentSet segments = segmentation.segment(messages);

            //Set up a cross validation controller.
            CrossValidationController crossValidationController = new CrossValidationController();
            //Configure controller
            factory.configureCrossValidation(crossValidationController);

            //Run cross validation
            crossValidationController.setSegmentSet(segments);
            crossValidationController.run();

            EvaluationReport evalReport = crossValidationController.getEvaluationReport();
            if (evalReport != null) {
                System.out.println("== Saving Results of Cross Validation ==");
               
                saveEvaluationReport(evalReport, options.outputEvaluationReportFile);
               
                System.out.println("Aggregated cross-validation report:");
                System.out.println(evalReport);
                System.out.println("---------");

                if (options.makeROC) {
                    options.outputROCDir.mkdirs();
                   
                    for (ROC roc : evalReport.getROCs()) {
                        String fileName = roc.getName() + FileNames.ROC_SUFFIX;
                        File outputFile = new File(options.outputROCDir, fileName);

                        saveROC(roc, outputFile);
                    }
                }
               
                if (options.outputTests) {
                    options.outputTestsDir.mkdirs();
                    List<SegmentSet> testSets = evalReport.getTestSets();
                    List<String> testSetNames = evalReport.getTestSetNames();
                   
                    SegmentSet combined = new SegmentSet();
                   
                    for (int i = 0; i < testSets.size(); i++) {
                        String fileName = testSetNames.get(i) + FileNames.TEST_DATA_SUFFIX;
                        SegmentSet testSet = testSets.get(i);
                        combined.addAll(testSet.getSegments());
                       
                        File outputFile = new File(options.outputTestsDir, fileName);
                       
                        saveMessages(testSet.getMessages(messages), outputFile);
                    }
                   
                    String fileName = FileNames.OUTPUT_TEST_DATA_COMBINED_NAME;
                    File outputFile = new File(options.outputTestsDir, fileName);
                    saveMessages(combined.getMessages(messages), outputFile);

                }
            }
           
           
            //Create a training controller for making the final model
            TrainingController trainingController = new TrainingController();
            //Configure the training controller
            factory.configureTraining(trainingController);

            //Run the full training
            trainingController.setSegmentSet(segments);
            trainingController.run();

            //Get the fruits of our labors
            System.out.println("== Saving Output ==");

            FeatureSpecification spec = trainingController.getFeatureSpecification();
            Model model = trainingController.getModel();
            List<String> topFeatures = trainingController.getTopFeatures();
            List<Map.Entry<String, Double>> featureWeights = trainingController.getFeatureWeights();

            saveFeatureSpecification(spec, options.outputFeatureSpecFile);
            saveModel(model, options.outputModelFile);
            saveTopFeatures(topFeatures, options.outputTopFeaturesFile);
            saveFeatureWeights(featureWeights, options.outputFeatureWeightsFile);
           
            if (options.outputFeatureValues) {
                Instances featureValues = trainingController.getFeatureValues();
                saveInstances(featureValues, options.outputFeatureValuesFile);
            }
        } else {
            throw new IllegalArgumentException("Options must be for Training");
        }
    }
}
TOP

Related Classes of etc.aloe.AloeTrain

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.