Package opennlp.maxent

Source Code of opennlp.maxent.GISTrainer

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package opennlp.maxent;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import opennlp.model.DataIndexer;
import opennlp.model.EvalParameters;
import opennlp.model.EventStream;
import opennlp.model.MutableContext;
import opennlp.model.OnePassDataIndexer;
import opennlp.model.Prior;
import opennlp.model.UniformPrior;


/**
* An implementation of Generalized Iterative Scaling.  The reference paper
* for this implementation was Adwait Ratnaparkhi's tech report at the
* University of Pennsylvania's Institute for Research in Cognitive Science,
* and is available at <a href ="ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z"><code>ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z</code></a>.
*
* The slack parameter used in the above implementation has been removed by default
* from the computation and a method for updating with Gaussian smoothing has been
* added per Investigating GIS and Smoothing for Maximum Entropy Taggers, Clark and Curran (2002). 
* <a href="http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf"><code>http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf</code></a>
* The slack parameter can be used by setting <code>useSlackParameter</code> to true.
* Gaussian smoothing can be used by setting <code>useGaussianSmoothing</code> to true.
*
* A prior can be used to train models which converge to the distribution which minimizes the
* relative entropy between the distribution specified by the empirical constraints of the training
* data and the specified prior.  By default, the uniform distribution is used as the prior.
*/
class GISTrainer {

  /**
   * Specifies whether unseen context/outcome pairs should be estimated as occur very infrequently.
   */
  private boolean useSimpleSmoothing = false;
 
  /**
   * Specified whether parameter updates should prefer a distribution of parameters which
   * is gaussian.
   */
  private boolean useGaussianSmoothing = false;
 
  private double sigma = 2.0;

  // If we are using smoothing, this is used as the "number" of
  // times we want the trainer to imagine that it saw a feature that it
  // actually didn't see.  Defaulted to 0.1.
  private double _smoothingObservation = 0.1;

  private final boolean printMessages;

  /**
   * Number of unique events which occured in the event set.
   */
  private int numUniqueEvents;
 
  /**
   * Number of predicates.
   */
  private int numPreds;
 
  /**
   * Number of outcomes.
   */
  private int numOutcomes;

  /**
   * Records the array of predicates seen in each event.
   */
  private int[][] contexts;
 
  /**
   * The value associated with each context. If null then context values are assumes to be 1.
   */
  private float[][] values;
 
  /**
   * List of outcomes for each event i, in context[i].
   */
  private int[] outcomeList;

  /**
   * Records the num of times an event has been seen for each event i, in context[i].
   */
  private int[] numTimesEventsSeen;
 
  /**
   * The number of times a predicate occured in the training data.
   */
  private int[] predicateCounts;
 
  private int cutoff;

  /**
   * Stores the String names of the outcomes. The GIS only tracks outcomes as
   * ints, and so this array is needed to save the model to disk and thereby
   * allow users to know what the outcome was in human understandable terms.
   */
  private String[] outcomeLabels;

  /**
   * Stores the String names of the predicates. The GIS only tracks predicates
   * as ints, and so this array is needed to save the model to disk and thereby
   * allow users to know what the outcome was in human understandable terms.
   */
  private String[] predLabels;

  /**
   * Stores the observed expected values of the features based on training data.
   */
  private MutableContext[] observedExpects;

  /**
   * Stores the estimated parameter value of each predicate during iteration
   */
  private MutableContext[] params;

  /**
   * Stores the expected values of the features based on the current models
   */
  private MutableContext[][] modelExpects;

  /**
   * This is the prior distribution that the model uses for training.
   */
  private Prior prior;

  private static final double LLThreshold = 0.0001;

  /**
   * Initial probability for all outcomes.
   */
  private EvalParameters evalParams;

  /**
   * Creates a new <code>GISTrainer</code> instance which does not print
   * progress messages about training to STDOUT.
   *
   */
  GISTrainer() {
    printMessages = false;
  }

  /**
   * Creates a new <code>GISTrainer</code> instance.
   *
   * @param printMessages sends progress messages about training to
   *                      STDOUT when true; trains silently otherwise.
   */
  GISTrainer(boolean printMessages) {
    this.printMessages = printMessages;
  }

  /**
   * Sets whether this trainer will use smoothing while training the model.
   * This can improve model accuracy, though training will potentially take
   * longer and use more memory.  Model size will also be larger.
   *
   * @param smooth true if smoothing is desired, false if not
   */
  public void setSmoothing(boolean smooth) {
    useSimpleSmoothing = smooth;
  }

  /**
   * Sets whether this trainer will use smoothing while training the model.
   * This can improve model accuracy, though training will potentially take
   * longer and use more memory.  Model size will also be larger.
   *
   * @param timesSeen the "number" of times we want the trainer to imagine
   *                  it saw a feature that it actually didn't see
   */
  public void setSmoothingObservation(double timesSeen) {
    _smoothingObservation = timesSeen;
  }
 
  /**
   * Sets whether this trainer will use smoothing while training the model.
   * This can improve model accuracy, though training will potentially take
   * longer and use more memory.  Model size will also be larger.
   *
   * @param smooth true if smoothing is desired, false if not
   */
  public void setGaussianSigma(double sigmaValue) {
    useGaussianSmoothing = true;
    sigma = sigmaValue;
  }

  /**
   * Trains a GIS model on the event in the specified event stream, using the specified number
   * of iterations and the specified count cutoff.
   * @param eventStream A stream of all events.
   * @param iterations The number of iterations to use for GIS.
   * @param cutoff The number of times a feature must occur to be included.
   * @return A GIS model trained with specified
   */
  public GISModel trainModel(EventStream eventStream, int iterations, int cutoff) throws IOException {
    return trainModel(iterations, new OnePassDataIndexer(eventStream,cutoff),cutoff);
  }
 
  /**
   * Train a model using the GIS algorithm.
   *
   * @param iterations  The number of GIS iterations to perform.
   * @param di The data indexer used to compress events in memory.
   * @return The newly trained model, which can be used immediately or saved
   *         to disk using an opennlp.maxent.io.GISModelWriter object.
   */
  public GISModel trainModel(int iterations, DataIndexer di, int cutoff) {
    return trainModel(iterations,di,new UniformPrior(),cutoff,1);
  }

  /**
   * Train a model using the GIS algorithm.
   *
   * @param iterations  The number of GIS iterations to perform.
   * @param di The data indexer used to compress events in memory.
   * @param modelPrior The prior distribution used to train this model.
   * @return The newly trained model, which can be used immediately or saved
   *         to disk using an opennlp.maxent.io.GISModelWriter object.
   */
  public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) {
   
    if (threads <= 0)
      throw new IllegalArgumentException("threads must be at leat one or greater!");
   
    modelExpects = new MutableContext[threads][];
   
    /************** Incorporate all of the needed info ******************/
    display("Incorporating indexed data for training...  \n");
    contexts = di.getContexts();
    values = di.getValues();
    this.cutoff = cutoff;
    predicateCounts = di.getPredCounts();
    numTimesEventsSeen = di.getNumTimesEventsSeen();
    numUniqueEvents = contexts.length;
    this.prior = modelPrior;
    //printTable(contexts);

    // determine the correction constant and its inverse
    double correctionConstant = 0;
    for (int ci = 0; ci < contexts.length; ci++) {
      if (values == null || values[ci] == null) {
        if (contexts[ci].length > correctionConstant) {
          correctionConstant = contexts[ci].length;
        }
      }
      else {
        float cl = values[ci][0];
        for (int vi=1;vi<values[ci].length;vi++) {
          cl+=values[ci][vi];
        }
       
        if (cl > correctionConstant) {
          correctionConstant = cl;
        }
      }
    }
    display("done.\n");

    outcomeLabels = di.getOutcomeLabels();
    outcomeList = di.getOutcomeList();
    numOutcomes = outcomeLabels.length;

    predLabels = di.getPredLabels();
    prior.setLabels(outcomeLabels,predLabels);
    numPreds = predLabels.length;

    display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
    display("\t    Number of Outcomes: " + numOutcomes + "\n");
    display("\t  Number of Predicates: " + numPreds + "\n");

    // set up feature arrays
    float[][] predCount = new float[numPreds][numOutcomes];
    for (int ti = 0; ti < numUniqueEvents; ti++) {
      for (int j = 0; j < contexts[ti].length; j++) {
        if (values != null && values[ti] != null) {
          predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti]*values[ti][j];
        }
        else {         
          predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti];
        }
      }
    }

    //printTable(predCount);
    di = null; // don't need it anymore

    // A fake "observation" to cover features which are not detected in
    // the data.  The default is to assume that we observed "1/10th" of a
    // feature during training.
    final double smoothingObservation = _smoothingObservation;

    // Get the observed expectations of the features. Strictly speaking,
    // we should divide the counts by the number of Tokens, but because of
    // the way the model's expectations are approximated in the
    // implementation, this is cancelled out when we compute the next
    // iteration of a parameter, making the extra divisions wasteful.
    params = new MutableContext[numPreds];
    for (int i = 0; i< modelExpects.length; i++)
      modelExpects[i] = new MutableContext[numPreds];
    observedExpects = new MutableContext[numPreds];
   
    // The model does need the correction constant and the correction feature. The correction constant
    // is only needed during training, and the correction feature is not necessary.
    // For compatibility reasons the model contains form now on a correction constant of 1,
    // and a correction param 0.
    evalParams = new EvalParameters(params,0,1,numOutcomes);
    int[] activeOutcomes = new int[numOutcomes];
    int[] outcomePattern;
    int[] allOutcomesPattern= new int[numOutcomes];
    for (int oi = 0; oi < numOutcomes; oi++) {
      allOutcomesPattern[oi] = oi;
    }
    int numActiveOutcomes = 0;
    for (int pi = 0; pi < numPreds; pi++) {
      numActiveOutcomes = 0;
      if (useSimpleSmoothing) {
        numActiveOutcomes = numOutcomes;
        outcomePattern = allOutcomesPattern;
      }
      else { //determine active outcomes
        for (int oi = 0; oi < numOutcomes; oi++) {
          if (predCount[pi][oi] > 0 && predicateCounts[pi] >= cutoff) {
            activeOutcomes[numActiveOutcomes] = oi;
            numActiveOutcomes++;
          }
        }
        if (numActiveOutcomes == numOutcomes) {
          outcomePattern = allOutcomesPattern;
        }
        else {
          outcomePattern = new int[numActiveOutcomes];
          for (int aoi=0;aoi<numActiveOutcomes;aoi++) {
            outcomePattern[aoi] = activeOutcomes[aoi];
          }
        }
      }
      params[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
      for (int i = 0; i< modelExpects.length; i++)
        modelExpects[i][pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
      observedExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]);
      for (int aoi=0;aoi<numActiveOutcomes;aoi++) {
        int oi = outcomePattern[aoi];
        params[pi].setParameter(aoi, 0.0);
        for (int i = 0; i< modelExpects.length; i++)
          modelExpects[i][pi].setParameter(aoi, 0.0);
        if (predCount[pi][oi] > 0) {
            observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
        }
        else if (useSimpleSmoothing) {
          observedExpects[pi].setParameter(aoi,smoothingObservation);
        }
      }
    }

    predCount = null; // don't need it anymore

    display("...done.\n");

    /***************** Find the parameters ************************/
    if (threads == 1)
      display("Computing model parameters ...\n");
    else
      display("Computing model parameters in " + threads +" threads...\n");
   
    findParameters(iterations, correctionConstant);

    /*************** Create and return the model ******************/
    // To be compatible with old models the correction constant is always 1
    return new GISModel(params, predLabels, outcomeLabels, 1, evalParams.getCorrectionParam());

  }

  /* Estimate and return the model parameters. */
  private void findParameters(int iterations, double correctionConstant) {
    double prevLL = 0.0;
    double currLL = 0.0;
    display("Performing " + iterations + " iterations.\n");
    for (int i = 1; i <= iterations; i++) {
      if (i < 10)
        display("  " + i + ":  ");
      else if (i < 100)
        display(" " + i + ":  ");
      else
        display(i + ":  ");
      currLL = nextIteration(correctionConstant);
      if (i > 1) {
        if (prevLL > currLL) {
          System.err.println("Model Diverging: loglikelihood decreased");
          break;
        }
        if (currLL - prevLL < LLThreshold) {
          break;
        }
      }
      prevLL = currLL;
    }

    // kill a bunch of these big objects now that we don't need them
    observedExpects = null;
    modelExpects = null;
    numTimesEventsSeen = null;
    contexts = null;
  }
 
  //modeled on implementation in  Zhang Le's maxent kit
  private double gaussianUpdate(int predicate, int oid, int n, double correctionConstant) {
    double param = params[predicate].getParameters()[oid];
    double x0 = 0.0;
    double modelValue = modelExpects[0][predicate].getParameters()[oid];
    double observedValue = observedExpects[predicate].getParameters()[oid];
    for (int i = 0; i < 50; i++) {
      double tmp = modelValue * Math.exp(correctionConstant * x0);
      double f = tmp + (param + x0) / sigma - observedValue;
      double fp = tmp * correctionConstant + 1 / sigma;
      if (fp == 0) {
        break;
      }
      double x = x0 - f / fp;
      if (Math.abs(x - x0) < 0.000001) {
        x0 = x;
        break;
      }
      x0 = x;
    }
    return x0;
  }
 
  private class ModelExpactationComputeTask implements Callable<ModelExpactationComputeTask> {

    private final int startIndex;
    private final int length;
   
    private double loglikelihood = 0;
   
    private int numEvents = 0;
    private int numCorrect = 0;
   
    final private int threadIndex;

    // startIndex to compute, number of events to compute
    ModelExpactationComputeTask(int threadIndex, int startIndex, int length) {
      this.startIndex = startIndex;
      this.length = length;
      this.threadIndex = threadIndex;
    }
   
    public ModelExpactationComputeTask call() {
     
      final double[] modelDistribution = new double[numOutcomes];
     
     
      for (int ei = startIndex; ei < startIndex + length; ei++) {
       
        // TODO: check interruption status here, if interrupted set a poisoned flag and return
       
        if (values != null) {
          prior.logPrior(modelDistribution, contexts[ei], values[ei]);
          GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams);
        }
        else {
          prior.logPrior(modelDistribution,contexts[ei]);
          GISModel.eval(contexts[ei], modelDistribution, evalParams);
        }
        for (int j = 0; j < contexts[ei].length; j++) {
          int pi = contexts[ei][j];
          if (predicateCounts[pi] >= cutoff) {
            int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes();
            for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
              int oi = activeOutcomes[aoi];
             
              // numTimesEventsSeen must also be thread safe
              if (values != null && values[ei] != null) {
                modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * values[ei][j] * numTimesEventsSeen[ei]);
              }
              else {
                modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] * numTimesEventsSeen[ei]);
              }
            }
          }
        }
       
        loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei];
       
        numEvents += numTimesEventsSeen[ei];
        if (printMessages) {
          int max = 0;
          for (int oi = 1; oi < numOutcomes; oi++) {
            if (modelDistribution[oi] > modelDistribution[max]) {
              max = oi;
            }
          }
          if (max == outcomeList[ei]) {
            numCorrect += numTimesEventsSeen[ei];
          }
        }
       
      }
     
      return this;
    }
   
    synchronized int getNumEvents() {
      return numEvents;
    }
   
    synchronized int getNumCorrect() {
      return numCorrect;
    }
   
    synchronized double getLoglikelihood() {
      return loglikelihood;
    }
  }
 
  /* Compute one iteration of GIS and retutn log-likelihood.*/
  private double nextIteration(double correctionConstant) {
    // compute contribution of p(a|b_i) for each feature and the new
    // correction parameter
    double loglikelihood = 0.0;
    int numEvents = 0;
    int numCorrect = 0;
   
    int numberOfThreads = modelExpects.length;
   
    ExecutorService executor = Executors.newFixedThreadPool(numberOfThreads);
   
    int taskSize = numUniqueEvents / numberOfThreads;
   
    int leftOver = numUniqueEvents % numberOfThreads;
   
    List<Future<?>> futures = new ArrayList<Future<?>>();
   
    for (int i = 0; i < numberOfThreads; i++) {
      if (i != numberOfThreads - 1)
        futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize)));
      else
        futures.add(executor.submit(new ModelExpactationComputeTask(i, i*taskSize, taskSize + leftOver)));
    }
   
    for (Future<?> future : futures) {
      ModelExpactationComputeTask finishedTask = null;
      try {
        finishedTask = (ModelExpactationComputeTask) future.get();
      } catch (InterruptedException e) {
        // TODO: We got interrupted, but that is currently not really supported!
        // For now we just print the exception and fail hard. We hopefully soon
        // handle this case properly!
        e.printStackTrace();
        throw new IllegalStateException("Interruption is not supported!", e);
      } catch (ExecutionException e) {
        // Only runtime exception can be thrown during training, if one was thrown
        // it should be re-thrown. That could for example be a NullPointerException
        // which is caused through a bug in our implementation.
        throw new RuntimeException(e.getCause());
      }
     
      // When they are done, retrieve the results ...
      numEvents += finishedTask.getNumEvents();
      numCorrect += finishedTask.getNumCorrect();
      loglikelihood += finishedTask.getLoglikelihood();
    }

    executor.shutdown();
   
    display(".");

    // merge the results of the two computations
    for (int pi = 0; pi < numPreds; pi++) {
      int[] activeOutcomes = params[pi].getOutcomes();
     
      for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
        for (int i = 1; i < modelExpects.length; i++) {
          modelExpects[0][pi].updateParameter(aoi, modelExpects[i][pi].getParameters()[aoi]);
        }
      }
    }
   
    display(".");
   
    // compute the new parameter values
    for (int pi = 0; pi < numPreds; pi++) {
      double[] observed = observedExpects[pi].getParameters();
      double[] model = modelExpects[0][pi].getParameters();
      int[] activeOutcomes = params[pi].getOutcomes();
      for (int aoi=0;aoi<activeOutcomes.length;aoi++) {
        if (useGaussianSmoothing) {
          params[pi].updateParameter(aoi,gaussianUpdate(pi,aoi,numEvents,correctionConstant));
        }
        else {
          if (model[aoi] == 0) {
            System.err.println("Model expects == 0 for "+predLabels[pi]+" "+outcomeLabels[aoi]);
          }
          //params[pi].updateParameter(aoi,(Math.log(observed[aoi]) - Math.log(model[aoi])));
          params[pi].updateParameter(aoi,((Math.log(observed[aoi]) - Math.log(model[aoi]))/correctionConstant));
        }
       
        for (int i = 0; i< modelExpects.length; i++)
          modelExpects[i][pi].setParameter(aoi,0.0); // re-initialize to 0.0's

      }
    }

    display(". loglikelihood=" + loglikelihood + "\t" + ((double) numCorrect / numEvents) + "\n");
   
    return loglikelihood;
  }

  private void display(String s) {
    if (printMessages)
      System.out.print(s);
  }
}
TOP

Related Classes of opennlp.maxent.GISTrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.