Package edu.stanford.nlp.util

Examples of edu.stanford.nlp.util.Timing
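
Every snippet on this page uses the same small API: constructing a Timing starts the clock, start() restarts it, stop() returns the elapsed milliseconds, report() reads the elapsed time without resetting the clock, and doing(msg)/done() bracket a step with progress messages (done(msg) prints a message together with the elapsed time). The sketch below exercises those methods as they appear in the examples; it is a minimal illustration in which the sleeps stand in for real work, and the exact text printed by doing()/done() depends on Timing's implementation.

  import edu.stanford.nlp.util.Timing;

  public class TimingDemo {
    public static void main(String[] args) throws InterruptedException {
      Timing t = new Timing();     // construction starts the clock
      t.doing("Loading model");    // prints a progress message
      Thread.sleep(250);           // stand-in for real work
      t.done();                    // prints the elapsed time for the step

      t.start();                   // restart the clock for a new phase
      Thread.sleep(100);
      long millis = t.stop();      // elapsed milliseconds since start()
      System.err.printf("Phase took %.2f sec%n", millis / 1000.0);

      t.start();
      Thread.sleep(50);
      long soFar = t.report();     // read elapsed time; the clock keeps running
      System.err.println("Elapsed so far: " + soFar + " ms");
    }
  }

Example: timing model loading and initialization in the neural dependency parser (loadModelFile).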


  public void loadModelFile(String modelFile) {
    loadModelFile(modelFile, true);
  }

  private void loadModelFile(String modelFile, boolean verbose) {
    Timing t = new Timing();
    try {
      // System.err.println(Config.SEPARATOR);
      System.err.println("Loading depparse model file: " + modelFile + " ... ");
      String s;
      BufferedReader input = IOUtils.readerFromString(modelFile);

      int nDict, nPOS, nLabel;
      int eSize, hSize, nTokens, nPreComputed;
      nDict = nPOS = nLabel = eSize = hSize = nTokens = nPreComputed = 0;

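      // The model file begins with seven key=value header lines giving the
      // vocabulary sizes (words, POS tags, labels) and the network dimensions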
      for (int k = 0; k < 7; ++k) {
        s = input.readLine();
        if (verbose) {
          System.err.println(s);
        }
        int number = Integer.parseInt(s.substring(s.indexOf('=') + 1));
        switch (k) {
          case 0:
            nDict = number;
            break;
          case 1:
            nPOS = number;
            break;
          case 2:
            nLabel = number;
            break;
          case 3:
            eSize = number;
            break;
          case 4:
            hSize = number;
            break;
          case 5:
            nTokens = number;
            break;
          case 6:
            nPreComputed = number;
            break;
          default:
            break;
        }
      }


      knownWords = new ArrayList<String>();
      knownPos = new ArrayList<String>();
      knownLabels = new ArrayList<String>();
      double[][] E = new double[nDict + nPOS + nLabel][eSize];
      String[] splits;
      int index = 0;

      for (int k = 0; k < nDict; ++k) {
        s = input.readLine();
        splits = s.split(" ");
        knownWords.add(splits[0]);
        for (int i = 0; i < eSize; ++i)
          E[index][i] = Double.parseDouble(splits[i + 1]);
        index = index + 1;
      }
      for (int k = 0; k < nPOS; ++k) {
        s = input.readLine();
        splits = s.split(" ");
        knownPos.add(splits[0]);
        for (int i = 0; i < eSize; ++i)
          E[index][i] = Double.parseDouble(splits[i + 1]);
        index = index + 1;
      }
      for (int k = 0; k < nLabel; ++k) {
        s = input.readLine();
        splits = s.split(" ");
        knownLabels.add(splits[0]);
        for (int i = 0; i < eSize; ++i)
          E[index][i] = Double.parseDouble(splits[i + 1]);
        index = index + 1;
      }
      generateIDs();

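      // Each line of the file holds one column of W1 (likewise for W2 below),
      // so the matrices are read column by column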
      double[][] W1 = new double[hSize][eSize * nTokens];
      for (int j = 0; j < W1[0].length; ++j) {
        s = input.readLine();
        splits = s.split(" ");
        for (int i = 0; i < W1.length; ++i)
          W1[i][j] = Double.parseDouble(splits[i]);
      }

      double[] b1 = new double[hSize];
      s = input.readLine();
      splits = s.split(" ");
      for (int i = 0; i < b1.length; ++i)
        b1[i] = Double.parseDouble(splits[i]);

      double[][] W2 = new double[nLabel * 2 - 1][hSize];
      for (int j = 0; j < W2[0].length; ++j) {
        s = input.readLine();
        splits = s.split(" ");
        for (int i = 0; i < W2.length; ++i)
          W2[i][j] = Double.parseDouble(splits[i]);
      }

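      // IDs of the features whose hidden-layer products will be precomputed by the classifier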
      preComputed = new ArrayList<Integer>();
      while (preComputed.size() < nPreComputed) {
        s = input.readLine();
        splits = s.split(" ");
        for (String split : splits) {
          preComputed.add(Integer.parseInt(split));
        }
      }
      input.close();
      classifier = new Classifier(config, E, W1, b1, W2, preComputed);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }

    // initialize the loaded parser
    initialize(verbose);
    t.done("Initializing dependency parser");
  }
Example: timing evaluation on a CoNLL-X test set (testCoNLL)


   *  @param outFile File to write results to in CoNLL-X format.  If null, no output is written
   *  @return The LAS score on the dataset
   */
  public double testCoNLL(String testFile, String outFile) {
    System.err.println("Test File: " + testFile);
    Timing timer = new Timing();
    List<CoreMap> testSents = new ArrayList<>();
    List<DependencyTree> testTrees = new ArrayList<DependencyTree>();
    Util.loadConllFile(testFile, testSents, testTrees);
    // count how much to parse
    int numWords = 0;
    int numSentences = 0;
    for (CoreMap testSent : testSents) {
      numSentences += 1;
      numWords += testSent.get(CoreAnnotations.TokensAnnotation.class).size();
    }

    List<DependencyTree> predicted = testSents.stream().map(this::predictInner).collect(toList());
    Map<String, Double> result = system.evaluate(testSents, predicted, testTrees);
    double lasNoPunc = result.get("LASwoPunc");
    System.err.printf("UAS = %.4f%n", result.get("UASwoPunc"));
    System.err.printf("LAS = %.4f%n", lasNoPunc);
    long millis = timer.stop();
    double wordspersec = numWords / (((double) millis) / 1000);
    double sentspersec = numSentences / (((double) millis) / 1000);
    System.err.printf("%s parsed %d words in %d sentences in %.1fs at %.1f w/s, %.1f sent/s.%n",
            StringUtils.getShortClassName(this), numWords, numSentences, millis / 1000.0, wordspersec, sentspersec);

Example: reusing one Timing instance to time tagging and parsing as separate phases

    preprocessor.setSentenceFinalPuncWords(config.tlp.sentenceFinalPunctuationWords());
    preprocessor.setEscaper(config.escaper);
    preprocessor.setSentenceDelimiter(config.sentenceDelimiter);
    preprocessor.setTokenizerFactory(config.tlp.getTokenizerFactory());

    Timing timer = new Timing();

    MaxentTagger tagger = new MaxentTagger(config.tagger);
    List<List<TaggedWord>> tagged = new ArrayList<>();
    for (List<HasWord> sentence : preprocessor) {
      tagged.add(tagger.tagSentence(sentence));
    }

    System.err.printf("Tagging completed in %.2f sec.%n",
        timer.stop() / 1000.0);

    timer.start();

    int numSentences = 0;
    for (List<TaggedWord> taggedSentence : tagged) {
      GrammaticalStructure parse = predict(taggedSentence);

      Collection<TypedDependency> deps = parse.typedDependencies();
      for (TypedDependency dep : deps)
        output.println(dep);
      output.println();

      numSentences++;
    }

    long millis = timer.stop();
    double seconds = millis / 1000.0;
    System.err.printf("Parsed %d sentences in %.2f seconds (%.2f sents/sec).%n",
        numSentences, seconds, numSentences / seconds);
  }
Example: enforcing a wall-clock training budget with Timing.report() in a DVModel training loop

    // we use QN to minimize the cost function for the model
    // to do this minimization, we turn all of the matrices in the
    //   DVModel into one big Theta, which is the set of variables to
    //   be optimized by the QN.

    Timing timing = new Timing();
    long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000;
    int batchCount = 0;
    int debugCycle = 0;
    double bestLabelF1 = 0.0;

    if (op.trainOptions.useContextWords) {
      for (Tree tree : sentences) {
        Trees.convertToCoreLabels(tree);
        tree.setSpans();
      }
    }

    // for AdaGrad
    double[] sumGradSquare = new double[dvModel.totalParamSize()];
    Arrays.fill(sumGradSquare, 1.0);

    int numBatches = sentences.size() / op.trainOptions.batchSize + 1;
    System.err.println("Training on " + sentences.size() + " trees in " + numBatches + " batches");
    System.err.println("Times through each training batch: " + op.trainOptions.trainingIterations);
    System.err.println("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
    for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
      List<Tree> shuffledSentences = new ArrayList<Tree>(sentences);
      Collections.shuffle(shuffledSentences, dvModel.rand);
      for (int batch = 0; batch < numBatches; ++batch) {
        ++batchCount;
        // This did not help performance
        //System.err.println("Setting AdaGrad's sum of squares to 1...");
        //Arrays.fill(sumGradSquare, 1.0);

        System.err.println("======================================");
        System.err.println("Iteration " + iter + " batch " + batch);

        // Each batch will be of the specified batch size, except the
        // last batch will include any leftover trees at the end of
        // the list
        int startTree = batch * op.trainOptions.batchSize;
        int endTree = (batch + 1) * op.trainOptions.batchSize;
        if (endTree + op.trainOptions.batchSize > shuffledSentences.size()) {
          endTree = shuffledSentences.size();
        }

        executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);

        long totalElapsed = timing.report();
        System.err.println("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");

        if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
          // no need to debug output, we're done now
          break;
        }

        if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
          System.err.println("Finished " + batchCount + " total batches, running evaluation cycle");
          // Time for debugging output!
          double tagF1 = 0.0;
          double labelF1 = 0.0;
          if (testTreebank != null) {
            EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
            evaluator.testOnTreebank(testTreebank);
            labelF1 = evaluator.getLBScore();
            tagF1 = evaluator.getTagScore();
            if (labelF1 > bestLabelF1) {
              bestLabelF1 = labelF1;
            }
            System.err.println("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
          }

          String tempName = null;
          if (modelPath != null) {
            tempName = modelPath;
            if (modelPath.endsWith(".ser.gz")) {
              tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
            }
            saveModel(tempName);
          }

          String statusLine = ("CHECKPOINT:" +
                               " iteration " + iter +
                               " batch " + batch +
                               " labelF1 " + NF.format(labelF1) +
                               " tagF1 " + NF.format(tagF1) +
                               " bestLabelF1 " + NF.format(bestLabelF1) +
                               " model " + tempName +
                               op.trainOptions +
                               " word vectors: " + op.lexOptions.wordVectorFile +
                               " numHid: " + op.lexOptions.numHid);
          System.err.println(statusLine);
          if (resultsRecordPath != null) {
            FileWriter fout = new FileWriter(resultsRecordPath, true); // append
            fout.write(statusLine);
            fout.write("\n");
            fout.close();
          }

          ++debugCycle;
        }
      }
      long totalElapsed = timing.report();

      if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
        // no need to debug output, we're done now
        System.err.println("Max training time exceeded, exiting");
        break;
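
In the training loop above, timing.report() reads the total elapsed time without stopping the clock, so one Timing instance can enforce a wall-clock budget across all batches. A minimal sketch of that budget pattern, assuming a hypothetical doWork() helper and a 100 ms budget:

  import edu.stanford.nlp.util.Timing;

  public class TimeBudgetDemo {
    // hypothetical stand-in for one batch of training work
    static void doWork(int batch) throws InterruptedException {
      Thread.sleep(10);
    }

    public static void main(String[] args) throws InterruptedException {
      Timing timing = new Timing();
      long maxTrainTimeMillis = 100;  // hypothetical wall-clock budget
      for (int batch = 0; batch < 1000; ++batch) {
        doWork(batch);
        long totalElapsed = timing.report();  // does not reset the clock
        if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
          System.err.println("Max training time exceeded after batch " + batch);
          break;
        }
      }
    }
  }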
Example: bracketing a conversion step with doing()/done() in executeOneTrainingBatch


  static final int MINIMIZER = 3;

  public void executeOneTrainingBatch(List<Tree> trainingBatch, IdentityHashMap<Tree, byte[]> compressedParses, double[] sumGradSquare) {
    Timing convertTiming = new Timing();
    convertTiming.doing("Converting trees");
    IdentityHashMap<Tree, List<Tree>> topParses = CacheParseHypotheses.convertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);
    convertTiming.done();

    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);
    double[] theta = dvModel.paramsToVector();

    //maxFuncIter = 10;
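
The "one big Theta" described in the previous example's comment is produced by dvModel.paramsToVector(), which concatenates every model matrix into a single flat array for the minimizer to update. A hedged sketch of that flatten/restore round trip; the two small parameter blocks here are illustrative, not DVModel's actual layout:

  public class ThetaDemo {
    // two illustrative parameter blocks standing in for a model's weights
    static double[][] W1 = {{1, 2}, {3, 4}};
    static double[] b1 = {5, 6};

    // concatenate all parameters into one flat vector for the optimizer
    static double[] paramsToVector() {
      double[] theta = new double[W1.length * W1[0].length + b1.length];
      int i = 0;
      for (double[] row : W1)
        for (double v : row) theta[i++] = v;
      for (double v : b1) theta[i++] = v;
      return theta;
    }

    // copy an updated vector back into the individual parameter blocks
    static void vectorToParams(double[] theta) {
      int i = 0;
      for (double[] row : W1)
        for (int j = 0; j < row.length; j++) row[j] = theta[i++];
      for (int j = 0; j < b1.length; j++) b1[j] = theta[i++];
    }

    public static void main(String[] args) {
      double[] theta = paramsToVector();
      theta[0] += 0.1;              // pretend the minimizer took a step
      vectorToParams(theta);
      System.out.println(W1[0][0]); // prints 1.1
    }
  }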
Example: timing how long it takes to read a POS tagger model (readModelAndInit)

   *  @param printLoading Whether to print a message saying what model file is being loaded and how long it took when finished.
   *  @throws RuntimeIOException if I/O or serialization errors occur
   */
  protected void readModelAndInit(Properties config, DataInputStream rf, boolean printLoading) {
    try {
      Timing t = new Timing();
      if (printLoading) {
        String source = null;
        if (config != null) {
          // TODO: "model"
          source = config.getProperty("model");
        }
        if (source == null) {
          source = "data stream";
        }
        t.doing("Reading POS tagger model from " + source);
      }
      TaggerConfig taggerConfig = TaggerConfig.readConfig(rf);
      if (config != null) {
        taggerConfig.setProperties(config);
      }
      // then init tagger
      init(taggerConfig);

      xSize = rf.readInt();
      ySize = rf.readInt();
      // dict = new Dictionary();  // this method is called in constructor, and it's initialized as empty already
      dict.read(rf);

      if (VERBOSE) {
        System.err.println(" dictionary read ");
      }
      tags.read(rf);
      readExtractors(rf);
      dict.setAmbClasses(ambClasses, veryCommonWordThresh, tags);

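      // numFA[k] counts how many feature associations are read for extractor k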
      int[] numFA = new int[extractors.size() + extractorsRare.size()];
      int sizeAssoc = rf.readInt();
      fAssociations = Generics.newArrayList();
      for (int i = 0; i < extractors.size() + extractorsRare.size(); ++i) {
        fAssociations.add(Generics.<String, int[]>newHashMap());
      }
      if (VERBOSE) System.err.printf("Reading %d feature keys...%n",sizeAssoc);
      PrintFile pfVP = null;
      if (VERBOSE) {
        pfVP = new PrintFile("pairs.txt");
      }
      for (int i = 0; i < sizeAssoc; i++) {
        int numF = rf.readInt();
        FeatureKey fK = new FeatureKey();
        fK.read(rf);
        numFA[fK.num]++;

        // TODO: rewrite the writing / reading code to store
        // fAssociations in a cleaner manner?  Only do this when
        // rebuilding all the tagger models anyway.  When we do that, we
        // can get rid of FeatureKey
        Map<String, int[]> fValueAssociations = fAssociations.get(fK.num);
        int[] fTagAssociations = fValueAssociations.get(fK.val);
        if (fTagAssociations == null) {
          fTagAssociations = new int[ySize];
          for (int j = 0; j < ySize; ++j) {
            fTagAssociations[j] = -1;
          }
          fValueAssociations.put(fK.val, fTagAssociations);
        }
        fTagAssociations[tags.getIndex(fK.tag)] = numF;
      }
      if (VERBOSE) {
        IOUtils.closeIgnoringExceptions(pfVP);
      }
      if (VERBOSE) {
        for (int k = 0; k < numFA.length; k++) {
          System.err.println(" Number of features of kind " + k + ' ' + numFA[k]);
        }
      }
      prob = new LambdaSolveTagger(rf);
      if (VERBOSE) {
        System.err.println(" prob read ");
      }
      if (printLoading) t.done();
    } catch (IOException | ClassNotFoundException e) {
      throw new RuntimeIOException("Unrecoverable error while loading a tagger model", e);
    }
Example: timing tagger evaluation with stop()


    try {
      MaxentTagger tagger = new MaxentTagger(config.getModel(), config);

      Timing t = new Timing();
      TestClassifier testClassifier = new TestClassifier(tagger);
      long millis = t.stop();
      printErrWordsPerSec(millis, testClassifier.getNumWords());
      testClassifier.printModelAndAccuracy(tagger);
    } catch (Exception e) {
      System.err.println("An error occurred while testing the tagger.");
      e.printStackTrace();
Example: timing POS tagger training end to end

  {
    Date now = new Date();

    System.err.println("## tagger training invoked at " + now + " with arguments:");
    config.dump();
    Timing tim = new Timing();

    PrintFile log = new PrintFile(config.getModel() + ".props");
    log.println("## tagger training invoked at " + now + " with arguments:");
    config.dump(log);
    log.close();

    trainAndSaveModel(config);
    tim.done("Training POS tagger");
  }
Example: timing multi-threaded tree scoring with doing()/done()


    // Some optimization methods print a line without a terminating
    // newline, so our debugging statements would otherwise be misaligned
    Timing scoreTiming = new Timing();
    scoreTiming.doing("Scoring trees");
    int treeNum = 0;
    MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper = new MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>>(op.trainOptions.trainingThreads, new ScoringProcessor());
    for (Tree tree : trainingBatch) {
      wrapper.put(tree);
    }
    wrapper.join();
    scoreTiming.done();
    while (wrapper.peek()) {
      Pair<DeepTree, DeepTree> result = wrapper.poll();
      DeepTree goldTree = result.first;
      DeepTree bestTree = result.second;
Example: timing retagging and transition-sequence creation in shift-reduce parser training

    int nThreads = op.trainOptions.trainingThreads;
    nThreads = nThreads <= 0 ? Runtime.getRuntime().availableProcessors() : nThreads;

    Tagger tagger = null;
    if (op.testOptions.preTag) {
      Timing retagTimer = new Timing();
      tagger = Tagger.loadModel(op.testOptions.taggerSerializedFile);
      redoTags(binarizedTrees, tagger, nThreads);
      retagTimer.done("Retagging");
    }

    Set<String> knownStates = findKnownStates(binarizedTrees);
    Set<String> rootStates = findRootStates(binarizedTrees);
    Set<String> rootOnlyStates = findRootOnlyStates(binarizedTrees, rootStates);

    System.err.println("Known states: " + knownStates);
    System.err.println("States which occur at the root: " + rootStates);
    System.err.println("States which only occur at the root: " + rootOnlyStates);

    Timing transitionTimer = new Timing();
    List<List<Transition>> transitionLists = CreateTransitionSequence.createTransitionSequences(binarizedTrees, op.compoundUnaries, rootStates, rootOnlyStates);
    Index<Transition> transitionIndex = new HashIndex<Transition>();
    for (List<Transition> transitions : transitionLists) {
      transitionIndex.addAll(transitions);
    }
    transitionTimer.done("Converting trees into transition lists");
    System.err.println("Number of transitions: " + transitionIndex.size());

    Random random = new Random(op.trainOptions.randomSeed);

    Treebank devTreebank = null;