Package org.apache.ctakes.temporal.eval

Source Code of org.apache.ctakes.temporal.eval.EvaluationOfMetaTimeExpressionExtractor

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.temporal.eval;

import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;

import org.apache.ctakes.temporal.ae.BackwardsTimeAnnotator;
import org.apache.ctakes.temporal.ae.CRFTimeAnnotator;
import org.apache.ctakes.temporal.ae.ConstituencyBasedTimeAnnotator;
import org.apache.ctakes.temporal.ae.MetaTimeAnnotator;
import org.apache.ctakes.temporal.ae.TimeAnnotator;
import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat;
import org.apache.ctakes.typesystem.type.textsem.TimeMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.classifier.crfsuite.CRFSuiteStringOutcomeDataWriter;
import org.cleartk.classifier.jar.JarClassifierBuilder;
import org.cleartk.eval.AnnotationStatistics;
import org.uimafit.component.JCasAnnotator_ImplBase;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.pipeline.JCasIterable;
import org.uimafit.pipeline.SimplePipeline;

import com.google.common.collect.Maps;
import com.lexicalscope.jewel.cli.CliFactory;

public class EvaluationOfMetaTimeExpressionExtractor extends EvaluationOfAnnotationSpans_ImplBase {
  public static int nFolds = 5;
  private List<Integer> allTrain = null;
 
  public EvaluationOfMetaTimeExpressionExtractor(File baseDirectory,
      File rawTextDirectory, File xmlDirectory,
      org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat xmlFormat,
      File xmiDirectory, File treebankDirectory,
      List<Integer> allTrain, Class<? extends Annotation> annotationClass) {
    super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory,
        treebankDirectory, annotationClass);
    this.allTrain = allTrain;
  }

  public static void main(String[] args) throws Exception {
    Options options = CliFactory.parseArguments(Options.class, args);
    List<Integer> patientSets = options.getPatients().getList();
    List<Integer> trainItems = null;
    List<Integer> devItems = null;
    List<Integer> testItems = null;
    if(options.getXMLFormat() == XMLFormat.I2B2){
      trainItems = I2B2Data.getTrainPatientSets(options.getXMLDirectory());
      devItems = I2B2Data.getDevPatientSets(options.getXMLDirectory());
      testItems = I2B2Data.getTestPatientSets(options.getXMLDirectory());
    }else{
      trainItems = THYMEData.getTrainPatientSets(patientSets);
      devItems = THYMEData.getDevPatientSets(patientSets);
      testItems = THYMEData.getTestPatientSets(patientSets);
    }
    List<Integer> allTrain = new ArrayList<>(trainItems);
    List<Integer> allTest = null;
   
    if(options.getTest()){
      allTrain.addAll(devItems);
      allTest = new ArrayList<>(testItems);
    }else{
      allTest = new ArrayList<>(devItems);
    }

    EvaluationOfMetaTimeExpressionExtractor eval = new
        EvaluationOfMetaTimeExpressionExtractor(
            new File("target/eval/time-spans"),
            options.getRawTextDirectory(),
            options.getXMLDirectory(),
            options.getXMLFormat(),
            options.getXMIDirectory(),
            options.getTreebankDirectory(),
            allTrain,
            TimeMention.class);
    if(options.getI2B2Output()!=null) eval.setI2B2Output(options.getI2B2Output());
    AnnotationStatistics<String> stats = eval.trainAndTest(allTrain, allTest);
    System.out.println(stats.toString());
  }

  @Override
  protected void train(CollectionReader collectionReader, File directory)
      throws Exception {
   
    Class<? extends JCasAnnotator_ImplBase>[] annotatorClasses = MetaTimeAnnotator.getComponents();
   
    // add more annotator types?
    Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments = Maps.newHashMap();
    annotatorTrainingArguments.put(BackwardsTimeAnnotator.class, new String[]{"-c", "0.3"});
    annotatorTrainingArguments.put(TimeAnnotator.class, new String[]{"-c", "0.1"});
    annotatorTrainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[]{"-c", "0.3"});
    annotatorTrainingArguments.put(CRFTimeAnnotator.class, new String[]{"-p", "c2=0.03"});
   
    JCasIterable[] casIters = new JCasIterable[nFolds];
    for (int fold = 0; fold < nFolds; ++fold) {
      List<Integer> xfoldTrain = selectTrainItems(allTrain, nFolds, fold);
      List<Integer> xfoldTest = selectTestItems(allTrain, nFolds, fold);
      AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
      File modelDirectory = getModelDirectory(new File("target/eval/time-spans/fold_"+fold));
      for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses) {
        EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(
            new File("target/eval/time-spans/" ),
            this.rawTextDirectory,
            this.xmlDirectory,
            this.xmlFormat,
            this.xmiDirectory,
            this.treebankDirectory,
            1,
            0,
            annotatorClass,
            false,
            annotatorTrainingArguments.get(annotatorClass));
        evaluation.prepareXMIsFor(allTrain);
        String name = String.format("%s.errors", annotatorClass.getSimpleName());
        evaluation.setLogging(Level.FINE, new File("target/eval", name));

        // train on 4 of the folds of the training data:
        evaluation.train(evaluation.getCollectionReader(xfoldTrain), modelDirectory);
        if(fold == 0){
          // train the main model as well:
          evaluation.train(evaluation.getCollectionReader(allTrain), directory);
        }
       
      }
      casIters[fold] = new JCasIterable(getCollectionReader(xfoldTest), aggregateBuilder.createAggregate());
    }
    // run meta data-writer for this fold:
    AggregateBuilder writerBuilder = new AggregateBuilder();
    writerBuilder.add(CopyFromGold.getDescription(TimeMention.class));
    writerBuilder.add(this.getDataWriterDescription(directory));
    AnalysisEngine writer = writerBuilder.createAggregate();
    for(JCasIterable casIter : casIters){
      for(JCas jcas : casIter){
        SimplePipeline.runPipeline(jcas, writer);
      }
    }
    writer.collectionProcessComplete();
    JarClassifierBuilder.trainAndPackage(getModelDirectory(directory), new String[]{"-p", "c2=0.3"});
  }
 
  private static List<Integer> selectTrainItems(List<Integer> items, int numFolds, int fold) {
    List<Integer> trainItems = new ArrayList<>();
    for (int i = 0; i < items.size(); ++i) {
      if (i % numFolds != fold) {
        trainItems.add(items.get(i));
      }
    }
    return trainItems;
  }
 
  private static List<Integer> selectTestItems(List<Integer> items, int numFolds, int fold) {
    List<Integer> trainItems = new ArrayList<>();
    for (int i = 0; i < items.size(); ++i) {
      if (i % numFolds == fold) {
        trainItems.add(items.get(i));
      }
    }
    return trainItems;
  }
 

  @Override
  protected AnalysisEngineDescription getDataWriterDescription(File directory)
      throws ResourceInitializationException {
    return MetaTimeAnnotator.getDataWriterDescription(CRFSuiteStringOutcomeDataWriter.class, directory);         
  }

  @Override
  protected void trainAndPackage(File directory) throws Exception {
    JarClassifierBuilder.trainAndPackage(getModelDirectory(directory), "-p", "c2=0.3");
  }

  @Override
  protected AnalysisEngineDescription getAnnotatorDescription(File directory)
      throws ResourceInitializationException {
    return MetaTimeAnnotator.getAnnotatorDescription(directory);
  }

  @Override
  protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas,
      Segment segment) {
    return selectExact(jCas, TimeMention.class, segment);
  }

  @Override
  protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas,
      Segment segment) {
    return selectExact(jCas, TimeMention.class, segment);
  }

  private static File getModelDirectory(File directory) {
    return new File(directory, "MetaTimeAnnotator");
  }

}
TOP

Related Classes of org.apache.ctakes.temporal.eval.EvaluationOfMetaTimeExpressionExtractor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.