Package org.apache.ctakes.relationextractor.eval

Source Code of org.apache.ctakes.relationextractor.eval.SHARPXMI$Options

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.relationextractor.eval;

import java.io.File;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.ctakes.core.ae.SHARPKnowtatorXMLReader;
import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
import org.apache.ctakes.typesystem.type.structured.DocumentID;
import org.apache.uima.UIMAFramework;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
import org.apache.uima.cas.impl.XmiCasSerializer;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.jcas.JCas;
import org.apache.uima.util.XMLInputSource;
import org.apache.uima.util.XMLParser;
import org.apache.uima.util.XMLSerializer;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.util.ViewURIUtil;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;
import org.uimafit.component.JCasAnnotator_ImplBase;
import org.uimafit.component.ViewCreatorAnnotator;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.CollectionReaderFactory;
import org.uimafit.factory.TypeSystemDescriptionFactory;
import org.uimafit.pipeline.JCasIterable;
import org.xml.sax.ContentHandler;

import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.Option;

public class SHARPXMI {

  public static List<File> getTrainTextFiles(File batchesDirectory) {
    // seed_set1: batches 2, 3, 4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 18, 19
    // seed_set2: batches 1, 2, 3, 4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 18, 19
    // seed_set3: batches 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 15, 16, 18, 19
    // seed_set4: batches 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 15, 16, 18, 19
    return getTextFilesFor(
        batchesDirectory,
        Pattern.compile("^(ss[1234]_batch0[2-9]|ss[1234]_batch1[56]"
            + "|ss[1234]_batch1[89]|ss[123]_batch01"
            + "|ss[12]_batch1[34]|ss[34]_batch1[12])$"));
  }

  public static List<File> getDevTextFiles(File batchesDirectory) {
    // seed_set1: batches 10, 17
    // seed_set2: batches 10, 17
    // seed_set3: batches 10, 17
    // seed_set4: batches 10, 17
    return getTextFilesFor(batchesDirectory, Pattern.compile("^(ss[1234]_batch1[07])$"));
  }

  public static List<File> getTestTextFiles(File batchesDirectory) {
    // seed_set1: batches 11, 12
    // seed_set2: batches 11, 12
    // seed_set3: batches 13, 14
    // seed_set4: batches 13, 14
    return getTextFilesFor(
        batchesDirectory,
        Pattern.compile("^(ss[12]_batch1[12]|ss[34]_batch1[34])$"));
  }

  public static List<File> getAllTextFiles(File batchesDirectory) {
    return getTextFilesFor(batchesDirectory, Pattern.compile(""));
  }

  private static List<File> getTextFilesFor(File batchesDirectory, Pattern pattern) {
    List<File> files = Lists.newArrayList();
    for (File batchDir : batchesDirectory.listFiles()) {
      if (batchDir.isDirectory() && !batchDir.isHidden()) {
        if (pattern.matcher(batchDir.getName()).find()) {
          File textDirectory = new File(batchDir, "Knowtator/text");
          for (File textFile : textDirectory.listFiles()) {
            if (textFile.isFile() && !textFile.isHidden()) {
              files.add(textFile);
            }
          }
        }
      }
    }
    return files;
  }

  public static List<File> toXMIFiles(Options options, List<File> textFiles) {
    List<File> xmiFiles = Lists.newArrayList();
    for (File textFile : textFiles) {
      xmiFiles.add(toXMIFile(options, textFile));
    }
    return xmiFiles;
  }

  private static File toXMIFile(Options options, File textFile) {
    return new File(options.getXMIDirectory(), textFile.getName() + ".xmi");
  }

  public static interface Options {
    @Option(
        longName = "batches-dir",
        description = "directory containing ssN_batchNN directories, each of which should contain "
            + "a Knowtator directory and a Knowtator_XML directory")
    public File getBatchesDirectory();

    @Option(
        longName = "xmi-dir",
        defaultValue = "target/xmi",
        description = "directory to store and load XMI serialization of annotations")
    public File getXMIDirectory();

    @Option(
        longName = "generate-xmi",
        description = "read in the gold annotations and serialize them as XMI")
    public boolean getGenerateXMI();
  }

  public static final String GOLD_VIEW_NAME = "GoldView";

  public static void generateXMI(Options options) throws Exception {
    // if necessary, write the XMIs first
    if (options.getGenerateXMI()) {
      if (!options.getXMIDirectory().exists()) {
        options.getXMIDirectory().mkdirs();
      }

      // create a collection reader that loads URIs for all Knowtator text files
      List<File> files = Lists.newArrayList();
      files.addAll(getTrainTextFiles(options.getBatchesDirectory()));
      files.addAll(getDevTextFiles(options.getBatchesDirectory()));
      files.addAll(getTestTextFiles(options.getBatchesDirectory()));
      CollectionReader reader = UriCollectionReader.getCollectionReaderFromFiles(files);

      // load the text from the URI, run the preprocessor, then run the
      // Knowtator XML reader
      AggregateBuilder builder = new AggregateBuilder();
      builder.add(UriToDocumentTextAnnotator.getDescription());
      File preprocessDescFile = new File("desc/analysis_engine/RelationExtractorPreprocessor.xml");
      XMLParser parser = UIMAFramework.getXMLParser();
      XMLInputSource source = new XMLInputSource(preprocessDescFile);
      builder.add(parser.parseAnalysisEngineDescription(source));
      builder.add(AnalysisEngineFactory.createPrimitiveDescription(
          ViewCreatorAnnotator.class,
          ViewCreatorAnnotator.PARAM_VIEW_NAME,
          GOLD_VIEW_NAME));
      builder.add(AnalysisEngineFactory.createPrimitiveDescription(CopyDocumentTextToGoldView.class));
      builder.add(
          AnalysisEngineFactory.createPrimitiveDescription(DocumentIDAnnotator.class),
          CAS.NAME_DEFAULT_SOFA,
          GOLD_VIEW_NAME);
      builder.add(
          AnalysisEngineFactory.createPrimitiveDescription(SHARPKnowtatorXMLReader.class),
          CAS.NAME_DEFAULT_SOFA,
          GOLD_VIEW_NAME);

      // write out an XMI for each file
      for (JCas jCas : new JCasIterable(reader, builder.createAggregate())) {
        JCas goldView = jCas.getView(GOLD_VIEW_NAME);
        String documentID = DocumentIDAnnotationUtil.getDocumentID(goldView);
        if (documentID == null) {
          throw new IllegalArgumentException("No documentID for CAS:\n" + jCas);
        }
        File outFile = toXMIFile(options, new File(documentID));
        FileOutputStream stream = new FileOutputStream(outFile);
        ContentHandler handler = new XMLSerializer(stream).getContentHandler();
        new XmiCasSerializer(jCas.getTypeSystem()).serialize(jCas.getCas(), handler);
        stream.close();
      }
    }
  }

  public enum EvaluateOn {
    TRAIN, DEV, TEST
  }

  public static interface EvaluationOptions extends Options {
    @Option(
        longName = "evalute-on",
        defaultValue = "DEV",
        description = "perform evaluation using the training (TRAIN), development (DEV) or test "
            + "(TEST) data.")
    public EvaluateOn getEvaluteOn();

    @Option(
        longName = "grid-search",
        description = "run a grid search to select the best parameters")
    public boolean getGridSearch();
  }
 
  public static abstract class Evaluation_ImplBase extends org.cleartk.eval.Evaluation_ImplBase<File, AnnotationStatistics<String>> {

    public Evaluation_ImplBase(File baseDirectory) {
      super(baseDirectory);
    }

    @Override
    public CollectionReader getCollectionReader(List<File> items) throws Exception {
      return CollectionReaderFactory.createCollectionReader(
          XMIReader.class,
          TypeSystemDescriptionFactory.createTypeSystemDescription(),
          XMIReader.PARAM_FILES,
          items);
    }

  }

  public static void validate(EvaluationOptions options) throws Exception {
    // error on invalid option combinations
    if (options.getEvaluteOn().equals(EvaluateOn.TEST) && options.getGridSearch()) {
      throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
    }
  }

  public static <T extends Evaluation_ImplBase> void evaluate(
      EvaluationOptions options,
      ParameterSettings bestSettings,
      List<ParameterSettings> gridOfSettings,
      Function<ParameterSettings, T> getEvaluation) throws Exception {
    // define the set of possible training parameters
    List<ParameterSettings> possibleParams;
    if (options.getGridSearch()) {
      possibleParams = gridOfSettings;
    } else {
      possibleParams = Lists.newArrayList(bestSettings);
    }

    // run an evaluation for each set of parameters
    Map<ParameterSettings, Double> scoredParams = new HashMap<ParameterSettings, Double>();
    for (ParameterSettings params : possibleParams) {
      Evaluation_ImplBase evaluation = getEvaluation.apply(params);

      List<File> trainFiles, devFiles, testFiles;
      switch (options.getEvaluteOn()) {
      case TRAIN:
        // run n-fold cross-validation on the training set
        trainFiles = getTrainTextFiles(options.getBatchesDirectory());
        trainFiles = toXMIFiles(options, trainFiles);
        List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
        params.stats = AnnotationStatistics.addAll(foldStats);
        break;
      case DEV:
        // train on the training set and evaluate on the dev set
        trainFiles = getTrainTextFiles(options.getBatchesDirectory());
        trainFiles = toXMIFiles(options, trainFiles);
        devFiles = getDevTextFiles(options.getBatchesDirectory());
        devFiles = toXMIFiles(options, devFiles);
        params.stats = evaluation.trainAndTest(trainFiles, devFiles);
        break;
      case TEST:
        // train on the training set + dev set and evaluate on the test set
        List<File> allTrainFiles = new ArrayList<File>();
        allTrainFiles.addAll(getTrainTextFiles(options.getBatchesDirectory()));
        allTrainFiles.addAll(getDevTextFiles(options.getBatchesDirectory()));
        allTrainFiles = toXMIFiles(options, allTrainFiles);
        testFiles = getTestTextFiles(options.getBatchesDirectory());
        testFiles = toXMIFiles(options, testFiles);
        params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
        break;
      default:
        throw new IllegalArgumentException("Invalid EvaluateOn: " + options.getEvaluteOn());
      }
      scoredParams.put(params, params.stats.f1());
    }

    // print parameters sorted by F1
    List<ParameterSettings> list = new ArrayList<ParameterSettings>(scoredParams.keySet());
    Function<ParameterSettings, Double> getCount = Functions.forMap(scoredParams);
    Collections.sort(list, Ordering.natural().onResultOf(getCount));

    // print performance of each set of parameters
    if (list.size() > 1) {
      System.err.println("Summary");
      for (ParameterSettings params : list) {
        System.err.printf(
            "F1=%.3f P=%.3f R=%.3f %s\n",
            params.stats.f1(),
            params.stats.precision(),
            params.stats.recall(),
            params);
      }
      System.err.println();
    }

    // print overall best model
    if (!list.isEmpty()) {
      ParameterSettings lastParams = list.get(list.size() - 1);
      System.err.println("Best model:");
      System.err.print(lastParams.stats);
      System.err.println(lastParams);
      System.err.println(lastParams.stats.confusions());
      System.err.println();
    }
  }

  public static class DocumentIDAnnotator extends JCasAnnotator_ImplBase {

    @Override
    public void process(JCas jCas) throws AnalysisEngineProcessException {
      String documentID = new File(ViewURIUtil.getURI(jCas)).getPath();
      DocumentID documentIDAnnotation = new DocumentID(jCas);
      documentIDAnnotation.setDocumentID(documentID);
      documentIDAnnotation.addToIndexes();
    }
  }

  public static class CopyDocumentTextToGoldView extends JCasAnnotator_ImplBase {
    @Override
    public void process(JCas jCas) throws AnalysisEngineProcessException {
      try {
        JCas goldView = jCas.getView(GOLD_VIEW_NAME);
        goldView.setDocumentText(jCas.getDocumentText());
      } catch (CASException e) {
        throw new AnalysisEngineProcessException(e);
      }
    }
  }
}
TOP

Related Classes of org.apache.ctakes.relationextractor.eval.SHARPXMI$Options

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.