// Package opennlp.model
//
// Source Code of opennlp.model.TrainUtil

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package opennlp.model;

import java.io.IOException;
import java.util.Map;

import opennlp.maxent.quasinewton.QNTrainer;
import opennlp.perceptron.PerceptronTrainer;
import opennlp.perceptron.SimplePerceptronSequenceTrainer;

public class TrainUtil {

  public static final String ALGORITHM_PARAM = "Algorithm";
 
  public static final String MAXENT_VALUE = "MAXENT";
  public static final String MAXENT_QN_VALUE = "MAXENT_QN_EXPERIMENTAL";
  public static final String PERCEPTRON_VALUE = "PERCEPTRON";
  public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
 
 
  public static final String CUTOFF_PARAM = "Cutoff";
  private static final int CUTOFF_DEFAULT = 5;
 
  public static final String ITERATIONS_PARAM = "Iterations";
  private static final int ITERATIONS_DEFAULT = 100;
 
  public static final String DATA_INDEXER_PARAM = "DataIndexer";
  public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
  public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
 
 
  private static String getStringParam(Map<String, String> trainParams, String key,
      String defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);

    if (valueString == null)
      valueString = defaultValue;
   
    if (reportMap != null)
      reportMap.put(key, valueString);
   
    return valueString;
  }
 
  private static int getIntParam(Map<String, String> trainParams, String key,
      int defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);

    if (valueString != null)
      return Integer.parseInt(valueString);
    else
      return defaultValue;
  }
 
  private static double getDoubleParam(Map<String, String> trainParams, String key,
      double defaultValue, Map<String, String> reportMap) {
   
    String valueString = trainParams.get(key);
   
    if (valueString != null)
      return Double.parseDouble(valueString);
    else
      return defaultValue;
  }
 
  private static boolean getBooleanParam(Map<String, String> trainParams, String key,
      boolean defaultValue, Map<String, String> reportMap) {

    String valueString = trainParams.get(key);

    if (valueString != null)
      return Boolean.parseBoolean(valueString);
    else
      return defaultValue;
  }
 
  public static boolean isValid(Map<String, String> trainParams) {

    // TODO: Need to validate all parameters correctly ... error prone?!
   
    String algorithmName = trainParams.get(ALGORITHM_PARAM);

    if (algorithmName != null && !(MAXENT_VALUE.equals(algorithmName) ||
      MAXENT_QN_VALUE.equals(algorithmName) ||
        PERCEPTRON_VALUE.equals(algorithmName) ||
        PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))) {
      return false;
    }

    try {
      String cutoffString = trainParams.get(CUTOFF_PARAM);
      if (cutoffString != null) Integer.parseInt(cutoffString);
     
      String iterationsString = trainParams.get(ITERATIONS_PARAM);
      if (iterationsString != null) Integer.parseInt(iterationsString);
    }
    catch (NumberFormatException e) {
      return false;
    }
   
    String dataIndexer = trainParams.get(DATA_INDEXER_PARAM);
   
    if (dataIndexer != null) {
      if (!("OnePass".equals(dataIndexer) || "TwoPass".equals(dataIndexer))) {
        return false;
      }
    }
   
    // TODO: Check data indexing ...
    
    return true;
  }
 
 
 
  // TODO: Need a way to report results and settings back for inclusion in model ...
 
  public static AbstractModel train(EventStream events, Map<String, String> trainParams, Map<String, String> reportMap)
      throws IOException {
   
    if (!isValid(trainParams))
        throw new IllegalArgumentException("trainParams are not valid!");
   
    if(isSequenceTraining(trainParams))
      throw new IllegalArgumentException("sequence training is not supported by this method!");
   
    String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
   
    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
       
    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);

    boolean sortAndMerge;
   
    if (MAXENT_VALUE.equals(algorithmName) || MAXENT_QN_VALUE.equals(algorithmName))
      sortAndMerge = true;
    else if (PERCEPTRON_VALUE.equals(algorithmName))
      sortAndMerge = false;
    else
      throw new IllegalStateException("Unexpected algorithm name: " + algorithmName);

    HashSumEventStream hses = new HashSumEventStream(events);
   
    String dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM,
        DATA_INDEXER_TWO_PASS_VALUE, reportMap);

    DataIndexer indexer = null;
   
    if (DATA_INDEXER_ONE_PASS_VALUE.equals(dataIndexerName)) {
      indexer = new OnePassDataIndexer(hses, cutoff, sortAndMerge);
    }
    else if (DATA_INDEXER_TWO_PASS_VALUE.equals(dataIndexerName)) {
      indexer = new TwoPassDataIndexer(hses, cutoff, sortAndMerge);
    }
    else {
      throw new IllegalStateException("Unexpected data indexer name: " +  dataIndexerName);
    }
   
    AbstractModel model;
    if (MAXENT_VALUE.equals(algorithmName)) {
     
      int threads = getIntParam(trainParams, "Threads", 1, reportMap);
     
      model = opennlp.maxent.GIS.trainModel(iterations, indexer,
          true, false, null, 0, threads);
    }
    else if (MAXENT_QN_VALUE.equals(algorithmName)) {
      int m = getIntParam(trainParams, "numOfUpdates", QNTrainer.DEFAULT_M, reportMap);
      int maxFctEval = getIntParam(trainParams, "maxFctEval", QNTrainer.DEFAULT_MAX_FCT_EVAL, reportMap);
      model = new QNTrainer(m, maxFctEval, true).trainModel(indexer);
    }
    else if (PERCEPTRON_VALUE.equals(algorithmName)) {
      boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
     
      boolean useSkippedAveraging = getBooleanParam(trainParams, "UseSkippedAveraging", false, reportMap);
     
      // overwrite otherwise it might not work
      if (useSkippedAveraging)
        useAverage = true;
     
      double stepSizeDecrease = getDoubleParam(trainParams, "StepSizeDecrease", 0, reportMap);
     
      double tolerance = getDoubleParam(trainParams, "Tolerance", PerceptronTrainer.TOLERANCE_DEFAULT, reportMap);
     
      opennlp.perceptron.PerceptronTrainer perceptronTrainer = new opennlp.perceptron.PerceptronTrainer();
      perceptronTrainer.setSkippedAveraging(useSkippedAveraging);
     
      if (stepSizeDecrease > 0)
        perceptronTrainer.setStepSizeDecrease(stepSizeDecrease);
     
      perceptronTrainer.setTolerance(tolerance);
     
      model = perceptronTrainer.trainModel(
          iterations, indexer, cutoff, useAverage);
    }
    else {
      throw new IllegalStateException("Algorithm not supported: " + algorithmName);
    }
   
    if (reportMap != null)
        reportMap.put("Training-Eventhash", hses.calculateHashSum().toString(16));
   
    return model;
  }
 
  /**
   * Detects if the training algorithm requires sequence based feature generation
   * or not.
   */
  public static boolean isSequenceTraining(Map<String, String> trainParams) {
    return PERCEPTRON_SEQUENCE_VALUE.equals(trainParams.get(ALGORITHM_PARAM));
  }
 
  public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
      Map<String, String> reportMap) throws IOException {
   
    if (!isValid(trainParams))
      throw new IllegalArgumentException("trainParams are not valid!");
 
    if (!isSequenceTraining(trainParams))
      throw new IllegalArgumentException("Algorithm must be a sequence algorithm!");
   
    int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
    int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);
   
    boolean useAverage = getBooleanParam(trainParams, "UseAverage", true, reportMap);
   
    return new SimplePerceptronSequenceTrainer().trainModel(
        iterations, events, cutoff,useAverage);
  }
}
// TOP
//
// Related Classes of opennlp.model.TrainUtil
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.