Package edu.stanford.nlp.tagger.maxent

Source Code of edu.stanford.nlp.tagger.maxent.TestClassifier

package edu.stanford.nlp.tagger.maxent;

import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.io.PrintFile;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.maxent.CGRunner;
import edu.stanford.nlp.maxent.Problem;
import edu.stanford.nlp.trees.*;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import java.util.Arrays;


/** Tags data and can handle either data with gold-standard tags (computing
 *  performance statistics) or unlabeled data.
 *  (The constructor actually runs the tagger. The main entry points are the
 *  static methods at the bottom of the class.)
 *
 *  It can also train a model using the trainAndSaveModel method.  This class is
 *  really the entry point to all tagger operations, it seems.
 *
 *  @author Kristina Toutanova
 *  @version 1.0
 */
public class TestClassifier {

  /** Path of the data to process; set from {@code config.getFile()} in the constructor. */
  private String filename;
  /** Tagger instance created in the constructor and used to tag unlabeled input line by line. */
  private TestSentence ts;
  /** Running sum of {@code TestSentence.numRight} over all evaluated sentences. */
  private int numRight;
  /** Running sum of {@code TestSentence.numWrong} over all evaluated sentences. */
  private int numWrong;
  /** Running sum of {@code TestSentence.numUnknown} over all evaluated sentences. */
  private int unknownWords;
  /** Running sum of {@code TestSentence.numWrongUnknown} over all evaluated sentences. */
  private int numWrongUnknown;
  /** Number of evaluated sentences for which {@code TestSentence.numWrong} was zero. */
  private int numCorrectSentences;
  /** Number of sentences evaluated; reset to zero at the start of each gold-standard test run. */
  private int numSentences;
  /** When true, a {@code <saveRoot>.un.dict} file is opened and {@code printUnknown} is called per sentence. */
  static boolean writeUnknDict = false;
  /** When true, a {@code <saveRoot>.words} file is opened and receives the tagged output. */
  static boolean writeWords = false;
  /** When true, a {@code <saveRoot>.words.top} file is opened and {@code printTop} is called per sentence. */
  static boolean writeTopWords = false;
  /**
   * Dictionary handed to every {@code TestSentence} built during evaluation.
   * NOTE(review): presumably collects the mis-tagged words - confirm in TestSentence.
   */
  private Dictionary wrongWords = new Dictionary();
  // Dictionary unknownWordsDict = new Dictionary();


  /**
   * Runs the tagger over the file named by {@code config}, using the default format 1.
   *
   * <p>The format can be either 1 or 0:
   * <ul>
   *   <li>1 means the test file carries the correct tags: each line is one sentence made of
   *       whitespace-separated tokens of the form word, tag separator, tag.</li>
   *   <li>0 means the test file has no tags and is just tokenized, one sentence per line
   *       (not ending in eos).</li>
   * </ul>
   *
   * @param config tagger configuration (file, encoding, tag separator, debug prefix, ...)
   * @throws Exception propagated from the delegated-to constructor
   */
  TestClassifier(TaggerConfig config) throws Exception {
    // Default format is 1 (gold-standard tags are present).
    this(config, 1);
  }

  /** format can be either of 1 or 0
   *  1 means the test file has the correct tags and is in format one word tag per line
   *  0 means the test file does not have the correct tags and is just tokenized.
   *  In this case the file is in the format one sentence per line (not ending in eos)
   *  @throws Exception
   */
  private TestClassifier(TaggerConfig config, int format) throws Exception {
    this.filename = config.getFile();
    ts = new TestSentence(GlobalHolder.getLambdaSolve());

    String dPrefix = config.getDebugPrefix();
    if (dPrefix == null || dPrefix.equals("")) {
      dPrefix = config.getFile();
    }
    if (config.getInitFromTrees()) {
      test(config, dPrefix);
    } else {
      //throw new UnsupportedOperationException();
      test(format, dPrefix, config.getTagSeparator(), config.getEncoding());
    }
  }

  /**
   * Begin tagging. The format variable (one of 0,1) determines whether data are assumed
   * to have gold standard tags (1) or to be unlabeled (0).
   * @throws IOException
   */
  private void test(int format, String saveRoot, String tagSeparator, String encoding) throws IOException {
    if ((format == 1)) {
      test(saveRoot, tagSeparator, encoding); //the data is tagged
    } else {
      BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(filename), encoding));
      PrintFile pf = null;
      PrintFile pf1 = null;
      if (writeWords) pf = new PrintFile(saveRoot + ".words");
      if (writeUnknDict) pf1 = new PrintFile(saveRoot + ".un.dict");

      for (String s; (s = in.readLine()) != null; ) {
        Sentence<Word> sent = Sentence.toSentence(Arrays.asList(s.split("\\s+")));
        ts.tagSentence(sent);
        if (pf != null) {
          pf.println(ts.getTaggedNice());
        }
      }

      in.close();
      if (pf != null) pf.close();
      if (pf1 != null) pf1.close();
    }
  }

  /**
   * Test on a file containing correct tags already when init'ing from trees
   * TODO: Add the ability to have a second transformer to transform output back; possibly combine this method
   * with method below
   */
  private void test(TaggerConfig config, String saveRoot) throws IOException {
    numSentences = 0;
    String eosTag = "EOS";
    String eosWord = "EOS";
    PrintFile pf = null;
    PrintFile pf1 = null;
    PrintFile pf3 = null;

    if(writeWords) pf = new PrintFile(saveRoot + ".words");
    if(writeUnknDict) pf1 = new PrintFile(saveRoot + ".un.dict");
    if(writeTopWords) pf3 = new PrintFile(saveRoot + ".words.top");
    TreeReaderFactory trf = new LabeledScoredTreeReaderFactory();
    DiskTreebank treebank = new DiskTreebank(trf,config.getEncoding());
    TreeTransformer transformer = config.getTreeTransformer();
    TreeNormalizer normalizer = config.getTreeNormalizer();

    if (config.getTreeRange() != null) {
      treebank.loadPath(filename, new NumberRangesFileFilter(config.getTreeRange(), true));
    } else {
      treebank.loadPath(filename);
    }
    for (Tree t : treebank) {
      if (normalizer != null) {
        t = normalizer.normalizeWholeTree(t, t.treeFactory());
      }
      if (transformer != null) {
        t = t.transform(transformer);
      }

      List<String> sentence = new ArrayList<String>();
      List<String> tagsArr = new ArrayList<String>();

      for (TaggedWord cur : t.taggedYield()) {
        tagsArr.add(cur.tag());
        sentence.add(cur.word());
      }
      //the sentence is read already, add eos
      sentence.add(eosWord);
      tagsArr.add(eosTag);
      numSentences++;

      int len = sentence.size();
      String[] testSent = new String[len];
      String[] correctTags = new String[len];
      for (int i = 0; i < len; i++) {
        testSent[i] = sentence.get(i);
        correctTags[i] = tagsArr.get(i);
      }

      TestSentence testS = new TestSentence(GlobalHolder.getLambdaSolve(), testSent, correctTags, pf, wrongWords);
      if (writeUnknDict) testS.printUnknown(numSentences, pf1);
      if (writeTopWords) testS.printTop(pf3);

      numWrong = numWrong + testS.numWrong;
      numRight = numRight + testS.numRight;
      unknownWords = unknownWords + testS.numUnknown;
      numWrongUnknown = numWrongUnknown + testS.numWrongUnknown;
      if (testS.numWrong == 0) {
        numCorrectSentences++;
      }
      System.out.println("Sentence number: " + numSentences + "; length " + (len-1) + "; correct: " + testS.numRight + "; wrong: " + testS.numWrong + "; unknown wrong: " + testS.numWrongUnknown);
      System.out.println("  Total tags correct: " + numRight + "; wrong: " + numWrong + "; unknown wrong: " + numWrongUnknown);
    }

    if(pf != null) pf.close();
    if(pf1 != null) pf1.close();
    if(pf3 != null) pf3.close();
  }

  /**
   * Test on a file containing correct tags already.
   */
  private void test(String saveRoot, String tagSeparator, String encoding) throws IOException {
    numSentences = 0;
    String eosTag = "EOS";
    String eosWord = "EOS";
    PrintFile pf = null;
    PrintFile pf1 = null;
    PrintFile pf3 = null;

    BufferedReader rf = new BufferedReader(new InputStreamReader(new FileInputStream(filename), encoding));
    if (writeWords) pf = new PrintFile(saveRoot + ".words");
    if (writeUnknDict) pf1 = new PrintFile(saveRoot + ".un.dict");
    if (writeTopWords) pf3 = new PrintFile(saveRoot + ".words.top");

    for (String s; (s = rf.readLine()) != null; ) {
      List<String> sentence = new ArrayList<String>();
      List<String> tagsArr = new ArrayList<String>();
      StringTokenizer st = new StringTokenizer(s);
      while (st.hasMoreTokens()) { // find the sentence there

        String token = st.nextToken();
        int index = token.lastIndexOf(tagSeparator);

        if (index == -1) {
          throw new RuntimeException("I was unable to find the delimiter '" + tagSeparator + "' in the token '" + token + "'. Consider using -delimiter.");
        }

        String w1 = token.substring(0, index);
        sentence.add(w1);
        String t1 = token.substring(index + 1);
        tagsArr.add(t1);
      }

      //the sentence is read already, add eos
      sentence.add(eosWord);
      tagsArr.add(eosTag);
      numSentences++;

      int len = sentence.size();
      String[] testSent = new String[len];
      String[] correctTags = new String[len];
      for (int i = 0; i < len; i++) {
        testSent[i] = sentence.get(i);
        correctTags[i] = tagsArr.get(i);
      }

      TestSentence testS = new TestSentence(GlobalHolder.getLambdaSolve(), testSent, correctTags, pf, wrongWords);
      if(writeUnknDict) testS.printUnknown(numSentences, pf1);
      if(writeTopWords) testS.printTop(pf3);

      numWrong = numWrong + testS.numWrong;
      numRight = numRight + testS.numRight;
      unknownWords = unknownWords + testS.numUnknown;
      numWrongUnknown = numWrongUnknown + testS.numWrongUnknown;
      if (testS.numWrong == 0) {
        numCorrectSentences++;
      }
      System.out.println("Sentence number: " + numSentences + "; length " + (len-1) + "; correct: " + testS.numRight + "; wrong: " + testS.numWrong + "; unknown wrong: " + testS.numWrongUnknown);
      System.out.println("  Total tags correct: " + numRight + "; wrong: " + numWrong + "; unknown wrong: " + numWrongUnknown);
    }

    rf.close();
    if (pf != null) pf.close();
    if (pf1 != null) pf1.close();
    if (pf3 != null) pf3.close();
  }

  /**
   * Warning: This method almost certainly no longer works.
   */
  /*
  @SuppressWarnings({"UnusedDeclaration"})
  private static void iterate(String filename) {
    try {
      GlobalHolder.readModelAndInit(filename);
      GlobalHolder.getLambdaSolve().improvedIterative();
      if (GlobalHolder.getLambdaSolve().checkCorrectness()) {
        System.out.println("model is correct");
      } else {
        System.out.println("model is not correct");
      }
      GlobalHolder.saveModel(filename, null);
    } catch (Exception e) {
      System.err.println("Exception while iterating.");
      e.printStackTrace();
    }
  }
  */

  /**
   * Reads in the training corpus from a filename and trains the tagger
   *
   * @param config Configuration parameters for training a model (filename, etc.
   * @throws IOException If IO problem
   */
  public static void trainAndSaveModel(TaggerConfig config) throws IOException {

    String modelName = config.getModel();
    GlobalHolder.init(config);

    // Allow clobbering.  You want it all the time when running experiments.

    TaggerExperiments samples = new TaggerExperiments(config);
    GlobalHolder.domain = samples;
    TaggerFeatures feats = samples.getTaggerFeatures();
    System.err.println("Samples from " + config.getFile());
    System.err.println("Number of features: " + feats.size());
    Problem p = new Problem(samples, feats);
    LambdaSolveTagger prob = new LambdaSolveTagger(p, 0.0001, 0.00001);
    GlobalHolder.prob = prob;

    if (config.getSearch().equals("owlqn")) {
      CGRunner runner = new CGRunner(prob, config.getModel(), config.getSigmaSquared());
      runner.solveL1(config.getRegL1());
    } else if (config.getSearch().equals("cg")) {
      CGRunner runner = new CGRunner(prob, config.getModel(), config.getSigmaSquared());
      runner.solveCG();
    } else if (config.getSearch().equals("qn")) {
      CGRunner runner = new CGRunner(prob, config.getModel(), config.getSigmaSquared());
      runner.solveQN();
    } else {
      prob.improvedIterative(config.getIterations());
    }

    if (prob.checkCorrectness()) {
      System.out.println("Model is correct [empirical expec = model expec]");
    } else {
      System.out.println("Model is not correct");
    }
    GlobalHolder.saveModel(modelName, config);
  }


   /**
   * This saves the parameters in a file like for the Improved Iterative.
   * This calculates the model from a filename, with the specified
   * parameters for the history and saves the result back to that filename.
   *
   * Warning: This method almost certainly no longer works.
   */
//  public static void save_param(String filename, String delimiter, String encoding) throws Exception {
//    GlobalHolder.init();
//    TaggerExperiments samples = new TaggerExperiments(filename, null, delimiter, encoding);
//    GlobalHolder.domain = samples;
//    TaggerFeatures feats = TaggerExperiments.feats;
//    System.out.println("Before" + feats.size());
//    //feats.print_by_numbers();
//    Problem p = new Problem(samples, feats);
//    System.out.println(" Entering lambda solve ");
//    LambdaSolveTagger prob = new LambdaSolveTagger(p, 0.0001, 0.00001);
//    GlobalHolder.prob = prob;
//    OutDataStreamFile rf = new OutDataStreamFile(filename);
//    GlobalHolder.save_prev(null, rf);
//    Runtime rt = Runtime.getRuntime();
//    System.out.println(" before " + rt.freeMemory());
//    GlobalHolder.release_mem();
//    rt.gc();
//    System.out.println(" after " + rt.freeMemory());
//    //prob.improvedIterative();
//    //prob.save_problem(filename+".math");
//    GlobalHolder.prob = prob;
//    if (prob.checkCorrectness()) {
//      System.out.println("model is correct");
//    } else {
//      System.out.println("model is not correct");
//    }
//    GlobalHolder.save_after(rf);
//    rf.close();
//
//  }


  /**
   * Warning: This method almost certainly no longer works.
   */
//  public static void expandModel(String filename, String oldModelFile, int iters, String delimiter, String encoding) throws Exception {
//    GlobalHolder.init();
//    TaggerExperiments samples = new TaggerExperiments(filename, arg_outputs, delimiter, encoding);
//    GlobalHolder.domain = samples;
//    TaggerFeatures feats = TaggerExperiments.feats;
//    System.out.println("Before" + feats.size());
//    //feats.print_by_numbers();
//    Problem p = new Problem(samples, feats);
//    LambdaSolveTagger prob = new LambdaSolveTagger(p, 0.0001, 0.00001);
//    GlobalHolder.prob = prob;
//    OutDataStreamFile rf = new OutDataStreamFile(filename);
//    GlobalHolder.save_prev(null, rf);
//    Runtime rt = Runtime.getRuntime();
//    System.out.println(" before " + rt.freeMemory());
//    GlobalHolder.release_mem();
//    rt.gc();
//    System.out.println(" after " + rt.freeMemory());
//    prob.readOldLambdas(filename, oldModelFile);
//    GlobalHolder.getLambdaSolve().improvedIterative(iters);
//    if (prob.checkCorrectness()) {
//      System.out.println("model is correct");
//    } else {
//      System.out.println("model is not correct");
//    }
//    GlobalHolder.save_after(rf);
//    rf.close();
//  }


  void printModelAndAccuracy(TaggerConfig config) {
    System.out.println("Model " + config.getModel() + " has xSize=" + GlobalHolder.xSize + ", ySize=" + GlobalHolder.ySize + ", and numFeatures=" + GlobalHolder.prob.lambda.length + '.');
    System.out.println("Results on " + numSentences + " sentences and " + (numRight + numWrong) + " words, of which " + unknownWords + " were unknown.");
    System.out.printf("Total sentences right: %d (%f%%); wrong: %d (%f%%).\n", numCorrectSentences, numCorrectSentences * 100.0 / numSentences, numSentences - numCorrectSentences, (numSentences - numCorrectSentences) * 100.0 / (numSentences));
    System.out.printf("Total tags right: %d (%f%%); wrong: %d (%f%%).\n", numRight, numRight * 100.0 / (numRight + numWrong), numWrong, numWrong * 100.0 / (numRight + numWrong));
    if (unknownWords > 0) { System.out.printf("Unknown words right: %d (%f%%); wrong: %d (%f%%).\n", (unknownWords - numWrongUnknown), 100.0 - (numWrongUnknown * 100.0 / unknownWords), numWrongUnknown, numWrongUnknown * 100.0 / unknownWords); }
  }


  int getNumWords() {
    return numRight + numWrong;
  }

  static void setDebug(boolean status) {
    writeUnknDict = status;
    writeWords = status;
    writeTopWords = status;
  }


}
TOP

Related Classes of edu.stanford.nlp.tagger.maxent.TestClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.