Package org.fnlp.nlp.parser.dep.train

Source Code of org.fnlp.nlp.parser.dep.train.JointParerTester

/**
*  This file is part of FNLP (formerly FudanNLP).
*  FNLP is free software: you can redistribute it and/or modify
*  it under the terms of the GNU Lesser General Public License as published by
*  the Free Software Foundation, either version 3 of the License, or
*  (at your option) any later version.
*  FNLP is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*  You should have received a copy of the GNU General Public License
*  along with FudanNLP.  If not, see <http://www.gnu.org/licenses/>.
*  Copyright 2009-2014 www.fnlp.org. All rights reserved.
*/

package org.fnlp.nlp.parser.dep.train;

import java.io.BufferedWriter;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.fnlp.nlp.parser.Sentence;
import org.fnlp.nlp.parser.Target;
import org.fnlp.nlp.parser.dep.JointParser;
import org.fnlp.nlp.parser.dep.analysis.AnalysisTest;
import org.fnlp.nlp.parser.dep.reader.CoNLLReader;
import org.fnlp.nlp.parser.dep.reader.FNLPReader;
import org.fnlp.nlp.parser.dep.reader.Malt2Reader;
import org.fnlp.util.exception.LoadModelException;

/**
* 性能测试类
*
* @version 
*/
public class JointParerTester {

  JointParser parser;
  boolean finaltest = true;

  /**
   * 构造函数
   *
   * @param modelfile
   *            模型目录
   * @throws LoadModelException
   */
  public JointParerTester(String modelfile) throws LoadModelException {
    parser = new JointParser(modelfile);
  }

  /**
   * 测试阶段
   *
   * 对于输入文件的所有句子作依赖文法分析
   *
   * @param testFile
   *            测试文件
   * @param resultFile
   *            结果文件
   * @throws Exception
   */
  public void test(String testFile, String resultFile, String charset)
      throws Exception {
    // HashMap<String, HashMap<String, Integer>> featureAlphabetByPos =
    // buildFeatureAlphabet(testFile);
    int error = 0;
    int dError = 0;
    int total = 0;
    int errsent = 0;
    int totsent = 0;
    int errroot = 0;

    System.out.print("Beginning the test ... ");
    // 输入
    FNLPReader reader = new FNLPReader (testFile);

    // 输出
    BufferedWriter writer = null;
    if (resultFile != null)
      writer = new BufferedWriter(new OutputStreamWriter(
          new FileOutputStream(resultFile), charset));

    long beginTime = System.currentTimeMillis();
    int count = 0;
    while (reader.hasNext()) {
//      System.out.println(count++);
      Sentence instance = (Sentence) reader.next();
     
      Target targets = (Target) instance.getTarget();

      int[] heads = targets.getHeads();     
      String[] depClass = targets.getRelations();
      Target t = parser.parse2R(instance);     
      String[] dPreds = (String[]) t.getRelations();
      int[] preds = (int[]) t.getHeads();
      int depCurerr = diffDepClas(depClass,dPreds);
      int curerr = diff(heads, preds);
      if(depCurerr != 0 ){
        dError += depCurerr;
      }
      if (curerr != 0) {
        errsent++;
        error += curerr;
      }
      errroot += diffRoot(heads, preds);
      totsent++;
      total += heads.length;

      if (writer != null) {
        writeTo(writer, instance, t);
      }
    }
    if (writer != null)
      writer.close();

    long endTime = System.currentTimeMillis();

    parser = null;

    float time = (endTime - beginTime) / 1000.0f;
    System.out.println("finish! =]");
    System.out.printf("total time:\t%.2f(s)\n", time);
    System.out.printf("average speed:\t%.4f(s/word)\t%.4f(s/sent)",  total
        / time, totsent / time);
    System.out.println();
    AnalysisTest at = new AnalysisTest();
    at.test(resultFile);
    /*System.out.printf("accuracy(depClass):\t%.8f\ttotal(words):\t%d\n",  1.0-1.0
        * dError / total, total);
    System.out.printf("accuracy(heads):\t%.8f\ttotal(words):\t%d\n",  1.0-1.0
        * error / total, total);
    System.out.printf("accuracy(sents):\t%.8f\ttotal(sents):\t%d\n", 1.0-1.0
        * errsent / totsent, totsent);
    System.out.printf("accuracy(root):\t%.8f\ttotal(root):\t%d\n", 1.0- 1.0
        * errroot / totsent, totsent);*/   
  }

  private void writeTo(BufferedWriter writer, Sentence instance, Target t)
      throws IOException {
   
    StringBuffer buf = new StringBuffer();
 
    String[] words = instance.getWords();
    String[] tags = instance.getTags();
    int[] heads = ((Target) instance.getTarget()).getHeads();
    String[] relations = ((Target) instance.getTarget()).getRelations();
    int[] predheads = t.getHeads();     
    String[] predRel= t.getRelations();
   
    for (int i = 0; i < words.length; i++) {
      buf.append(words[i]);
      buf.append("\t");
      buf.append(tags[i]);
      buf.append("\t");
      buf.append(heads[i]);
      buf.append("\t");
      buf.append(relations[i]);
      buf.append("\t");
      buf.append(predheads[i]);
      buf.append("\t");
      buf.append(predRel[i]);
      buf.append("\n");
    }
    writer.write(buf.toString());
    writer.newLine();
    writer.flush();
  }

  /**
   * 比较函数
   *
   * @param golds
   *            标准依存关系树
   * @param preds
   *            预测的依存关系树
   * @return 不同的依存关系的数量
   */
  private int diff(int[] golds, int[] preds) {
    int ret = 0;

    int[] ref = golds;
    if (golds.length > preds.length)
      ref = preds;
    for (int i = 0; i < ref.length; i++)
      if (golds[i] != preds[i])
        ret++;

    return ret;
  }
  private int diffDepClas(String[] golds, String[] preds) {
    int ret = 0;

    String[] ref = golds;
    if (golds.length > preds.length)
      ref = preds;
    for (int i = 0; i < ref.length; i++)
      if (!golds[i].equals(preds[i]))
        ret++;

    return ret;
  }


  private int diffRoot(int[] golds, int[] preds) {
    int ret = 0;
    for (int i = 0; i < golds.length; i++) {
      if (golds[i] == -1) {
        if (preds[i] != -1)
          ret = 1;
        break;
      }
    }
    return ret;
  }

  public static void main(String[] args) throws Exception {

    Options opt = new Options();

    opt.addOption("h", false, "Print help for this application");

    BasicParser parser = new BasicParser();
    CommandLine cl;
    try {
      cl = parser.parse(opt, args);
    } catch (Exception e) {
      System.err.println("Parameters format error");
      return;
    }

    if (args.length == 0 || cl.hasOption('h')) {
      HelpFormatter f = new HelpFormatter();
      f.printHelp(
          "Tagger:\n"
              + "ParserTester [option] model_file test_file result_file;\n",
          opt);
      return;
    }

    String[] args1 = cl.getArgs();
    String modelfile  = args1[0];
    String testfile = args1[1];
    String resultfile = args1[2];

    JointParerTester tester = new JointParerTester(modelfile);
    tester.test(testfile, resultfile, "UTF-8");
  }
}
TOP

Related Classes of org.fnlp.nlp.parser.dep.train.JointParerTester

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.