Package etc.aloe.factories

Source Code of etc.aloe.factories.CSCW2013$SingleOptionsImpl

/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.

* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.

* You should have received a copy of the GNU General Public License
* along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.factories;

import etc.aloe.controllers.CrossValidationController;
import etc.aloe.controllers.LabelingController;
import etc.aloe.controllers.TrainingController;
import etc.aloe.cscw2013.DownsampleBalancing;
import etc.aloe.cscw2013.FeatureExtractionImpl;
import etc.aloe.cscw2013.FeatureGenerationImpl;
import etc.aloe.cscw2013.LabelMappingImpl;
import etc.aloe.cscw2013.NullSegmentation;
import etc.aloe.cscw2013.ResolutionImpl;
import etc.aloe.cscw2013.SMOFeatureWeighting;
import etc.aloe.cscw2013.ThresholdSegmentation;
import etc.aloe.cscw2013.TrainingImpl;
import etc.aloe.cscw2013.UpsampleBalancing;
import etc.aloe.cscw2013.WekaModel;
import etc.aloe.data.Model;
import etc.aloe.filters.StringToDictionaryVector;
import etc.aloe.options.InteractiveOptions;
import etc.aloe.options.LabelOptions;
import etc.aloe.options.ModeOptions;
import etc.aloe.options.SingleOptions;
import etc.aloe.options.TrainOptions;
import etc.aloe.processes.Balancing;
import etc.aloe.processes.FeatureExtraction;
import etc.aloe.processes.FeatureGeneration;
import etc.aloe.processes.FeatureWeighting;
import etc.aloe.processes.LabelMapping;
import etc.aloe.processes.SegmentResolution;
import etc.aloe.processes.Segmentation;
import etc.aloe.processes.Training;
import java.io.File;
import java.io.FileNotFoundException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.List;
import org.kohsuke.args4j.Option;

/**
* Provides implementations for the CSCW 2013 pipeline.
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class CSCW2013 implements PipelineFactory {

    public ModeOptions options;

    @Override
    public void initialize() {
        double falseNegativeCost = 1;
        double falsePositiveCost = 1;

        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
            falseNegativeCost = trainOpts.falseNegativeCost;
            falsePositiveCost = trainOpts.falsePositiveCost;
        } else if (options instanceof LabelOptionsImpl) {
            LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
            falseNegativeCost = labelOpts.falseNegativeCost;
            falsePositiveCost = labelOpts.falsePositiveCost;
        }

        //Normalize the cost factors (sum to 2)
        double costNormFactor = 0.5 * (falseNegativeCost + falsePositiveCost);
        falseNegativeCost /= costNormFactor;
        falsePositiveCost /= costNormFactor;
        System.out.println("Costs normalized to " + falseNegativeCost + " (FN) " + falsePositiveCost + " (FP).");

        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
            trainOpts.falseNegativeCost = falseNegativeCost;
            trainOpts.falsePositiveCost = falsePositiveCost;
        } else if (options instanceof LabelOptionsImpl) {
            LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
            labelOpts.falseNegativeCost = falseNegativeCost;
            labelOpts.falsePositiveCost = falsePositiveCost;
        }
    }

    protected List<String> loadTermList(File emoticonFile) {
        try {
            return StringToDictionaryVector.readDictionaryFile(emoticonFile);
        } catch (FileNotFoundException ex) {
            System.err.println("Unable to read emoticon dictionary file " + emoticonFile);
            System.err.println("\t" + ex.getMessage());
            System.exit(1);
        }
        return null;
    }

    @Override
    public Model constructModel() {
        return new WekaModel();
    }

    @Override
    public FeatureExtraction constructFeatureExtraction() {
        return new FeatureExtractionImpl();
    }

    @Override
    public FeatureGeneration constructFeatureGeneration() {
        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
            //Read the emoticons
            List<String> termList = loadTermList(trainOpts.emoticonFile);
            FeatureGenerationImpl featureGen = new FeatureGenerationImpl(termList);
            featureGen.setParticipantFeatureCount(trainOpts.participantFeatures);
            return featureGen;
        } else {
            throw new IllegalArgumentException("Options not for Training");
        }
    }

    @Override
    public LabelMapping constructLabelMapping() {
        return new LabelMappingImpl();
    }

    @Override
    public SegmentResolution constructSegmentResolution() {
        return new ResolutionImpl();
    }

    @Override
    public FeatureWeighting constructFeatureWeighting() {
        return new SMOFeatureWeighting();
    }

    @Override
    public Segmentation constructSegmentation() {
        boolean disableSegmentation = false;
        int segmentationThresholdSeconds = 30;
        boolean ignoreParticipants = false;

        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
            disableSegmentation = trainOpts.disableSegmentation;
            segmentationThresholdSeconds = trainOpts.segmentationThresholdSeconds;
            ignoreParticipants = trainOpts.ignoreParticipants;
        } else if (options instanceof LabelOptionsImpl) {
            LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
            disableSegmentation = labelOpts.disableSegmentation;
            segmentationThresholdSeconds = labelOpts.segmentationThresholdSeconds;
            ignoreParticipants = labelOpts.ignoreParticipants;
        } else {
            throw new IllegalArgumentException("Options should be for Training or Labeling");
        }

        if (disableSegmentation) {
            return new NullSegmentation();
        } else {
            Segmentation segmentation = new ThresholdSegmentation(segmentationThresholdSeconds,
                    !ignoreParticipants);
            segmentation.setSegmentResolution(new ResolutionImpl());
            return segmentation;
        }
    }

    @Override
    public Training constructTraining() {
        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;

            TrainingImpl trainingImpl = new TrainingImpl();
           
            trainingImpl.setBuildLogisticModel(true);
           
            if (trainOpts.useMinCost || trainOpts.useReweighting) {
                trainingImpl.setUseCostTraining(true);
                trainingImpl.setFalsePositiveCost(trainOpts.falsePositiveCost);
                trainingImpl.setFalseNegativeCost(trainOpts.falseNegativeCost);
                trainingImpl.setUseReweighting(trainOpts.useReweighting);
            }

            return trainingImpl;
        } else {
            throw new IllegalArgumentException("Options must be for Training");
        }
    }

    @Override
    public Balancing constructBalancing() {
        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
            if (trainOpts.useDownsampling) {
                return new DownsampleBalancing(trainOpts.falsePositiveCost, trainOpts.falseNegativeCost);
            } else if (trainOpts.useUpsampling) {
                return new UpsampleBalancing(trainOpts.falsePositiveCost, trainOpts.falseNegativeCost);
            } else {
                return null;
            }
        } else {
            throw new IllegalArgumentException("Options must be for Training");
        }
    }

    @Override
    public void configureLabeling(LabelingController labelingController) {
        if (options instanceof LabelOptions) {
            LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;

            //Options
            labelingController.setCosts(labelOpts.falsePositiveCost, labelOpts.falseNegativeCost);

            //Implementations
            labelingController.setFeatureExtractionImpl(constructFeatureExtraction());
            labelingController.setMappingImpl(constructLabelMapping());

        }
    }

    @Override
    public void configureCrossValidation(CrossValidationController crossValidationController) {
        if (options instanceof TrainOptionsImpl) {
            TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
            //Implementations
            crossValidationController.setFeatureGenerationImpl(this.constructFeatureGeneration());
            crossValidationController.setFeatureExtractionImpl(this.constructFeatureExtraction());
            crossValidationController.setTrainingImpl(this.constructTraining());
            crossValidationController.setBalancingImpl(this.constructBalancing());
            crossValidationController.setMappingImpl(this.constructLabelMapping());
           
            //Options
            crossValidationController.setFolds(trainOpts.crossValidationFolds);
            crossValidationController.setCosts(trainOpts.falsePositiveCost, trainOpts.falseNegativeCost);
            crossValidationController.setBalanceTestSet(trainOpts.balanceTestSet);
        } else {
            throw new IllegalArgumentException("Options must be for Training");
        }
    }

    @Override
    public void configureTraining(TrainingController trainingController) {
        trainingController.setFeatureGenerationImpl(this.constructFeatureGeneration());
        trainingController.setFeatureExtractionImpl(this.constructFeatureExtraction());
        trainingController.setTrainingImpl(this.constructTraining());
        trainingController.setFeatureWeightingImpl(this.constructFeatureWeighting());
        trainingController.setBalancingImpl(this.constructBalancing());
    }

    @Override
    public DateFormat constructDateFormat() {
        return new SimpleDateFormat(options.dateFormatString);
    }

    @Override
    public InteractiveOptions constructInteractiveOptions() {
        return new InteractiveOptionsImpl();
    }

    @Override
    public LabelOptions constructLabelOptions() {
        return new LabelOptionsImpl();
    }

    @Override
    public TrainOptions constructTrainOptions() {
        return new TrainOptionsImpl();
    }

    @Override
    public SingleOptions constructSingleOptions() {
        return new SingleOptionsImpl();
    }

    @Override
    public void setOptions(ModeOptions options) {
        this.options = options;
    }

    static class InteractiveOptionsImpl extends InteractiveOptions {
    }

    static class SingleOptionsImpl extends SingleOptions {
    }

    static class LabelOptionsImpl extends LabelOptions {

        @Option(name = "--fp-cost", usage = "the cost of a false positive (default 1)", metaVar = "COST")
        public double falsePositiveCost = 1;
        @Option(name = "--fn-cost", usage = "the cost of a false negative (default 1)", metaVar = "COST")
        public double falseNegativeCost = 1;
        @Option(name = "--ignore-participants", usage = "ignore participants during segmentation")
        public boolean ignoreParticipants = false;
        @Option(name = "--threshold", aliases = {"-t"}, usage = "segmentation threshold in seconds (default 30)", metaVar = "SECONDS")
        public int segmentationThresholdSeconds = 30;
        @Option(name = "--no-segmentation", usage = "disable segmentation (each message is in its own segment)")
        public boolean disableSegmentation = false;
    }

    static class TrainOptionsImpl extends TrainOptions {
        @Option(name="--participant-features", usage="use up to this many participant names as features")
        public int participantFeatures = 0;
        @Option(name = "--fp-cost", usage = "the cost of a false positive (default 1)", metaVar = "COST")
        public double falsePositiveCost = 1;
        @Option(name = "--fn-cost", usage = "the cost of a false negative (default 1)", metaVar = "COST")
        public double falseNegativeCost = 1;
        @Option(name = "--ignore-participants", usage = "ignore participants during segmentation")
        public boolean ignoreParticipants = false;
        @Option(name = "--threshold", aliases = {"-t"}, usage = "segmentation threshold in seconds (default 30)", metaVar = "SECONDS")
        public int segmentationThresholdSeconds = 30;
        @Option(name = "--no-segmentation", usage = "disable segmentation (each message is in its own segment)")
        public boolean disableSegmentation = false;
        @Option(name = "--upsample", aliases = {"-us"}, usage = "upsample the minority class in training sets to match the cost ratio")
        public boolean useUpsampling = false;
        @Option(name = "--reweight", aliases = {"-rw"}, usage = "reweight the training data")
        public boolean useReweighting = false;
        @Option(name = "--min-cost", usage = "train a classifier that uses the min-cost criterion")
        public boolean useMinCost = false;
        @Option(name = "--downsample", aliases = {"-ds"}, usage = "downsample the majority class in training sets to match the cost ratio")
        public boolean useDownsampling = false;
        @Option(name = "--folds", aliases = {"-k"}, usage = "number of cross-validation folds (default 10, 0 to disable cross validation)", metaVar = "FOLDS")
        public int crossValidationFolds = 10;
        @Option(name = "--balance-test-set", usage = "apply balancing to the test set as well as the training set")
        public boolean balanceTestSet = false;
        @Option(name = "--emoticons", aliases = {"-e"}, usage = "emoticon dictionary file (default emoticons.txt)")
        public File emoticonFile = new File("emoticons.txt");
    }
}
TOP

Related Classes of etc.aloe.factories.CSCW2013$SingleOptionsImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.