if (partition != -1)
MachineReadingProperties.logger.info("In partition #" + partition);
String modelName = MachineReadingProperties.serializedRelationExtractorPath;
if (partition != -1)
modelName += "." + partition;
Annotation predicted = null;
if (MachineReadingProperties.useRelationExtractionModelMerging) {
String[] modelNames = MachineReadingProperties.serializedRelationExtractorPath.split(",");
if (partition != -1) {
for (int i = 0; i < modelNames.length; i++) {
modelNames[i] += "." + partition;
}
}
relationExtractor = ExtractorMerger.buildRelationExtractorMerger(modelNames);
} else if (!this.forceRetraining&& new File(modelName).exists()) {
MachineReadingProperties.logger.info("Loading relation extraction model from " + modelName + " ...");
//TODO change this to load any type of BasicRelationExtractor
relationExtractor = BasicRelationExtractor.load(modelName);
} else {
RelationFeatureFactory rff = makeRelationFeatureFactory(MachineReadingProperties.relationFeatureFactoryClass, MachineReadingProperties.relationFeatures, MachineReadingProperties.doNotLexicalizeFirstArg);
Execution.fillOptions(rff, args);
if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
// generate predicted entities
assert(entityExtractor != null);
predicted = AnnotationUtils.deepMentionCopy(training);
entityExtractor.annotate(predicted);
for (ResultsPrinter rp : entityResultsPrinterSet){
String msg = rp.printResults(training, predicted);
MachineReadingProperties.logger.info("Training relation extraction using predicted entitities: entity scores using printer " + rp.getClass() + ":\n" + msg);
}
// change relation mentions to use predicted entity mentions rather than gold ones
try {
changeGoldRelationArgsToPredicted(predicted);
} catch (Exception e) {
// we may get here for an unknown EntityMentionComparator class
throw new RuntimeException(e);
}
}
Annotation dataset;
if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
dataset = predicted;
} else {
dataset = training;
}
Set<String> relationsToSkip = new HashSet<String>(StringUtils.split(MachineReadingProperties.relationsToSkipDuringTraining, ","));
List<List<RelationMention>> backedUpRelations = new ArrayList<List<RelationMention>>();
if (relationsToSkip.size() > 0) {
// we need to back up the relations, since removeSkippableRelations modifies the dataset in place and we can't duplicate CoreMaps safely (or can we?)
for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
backedUpRelations.add(relationMentions);
}
removeSkippableRelations(dataset, relationsToSkip);
}
//relationExtractor = new BasicRelationExtractor(rff, MachineReadingProperties.createUnrelatedRelations, makeRelationMentionFactory(MachineReadingProperties.relationMentionFactoryClass));
relationExtractor = makeRelationExtractor(MachineReadingProperties.relationClassifier, rff, MachineReadingProperties.createUnrelatedRelations,
makeRelationMentionFactory(MachineReadingProperties.relationMentionFactoryClass));
Execution.fillOptions(relationExtractor, args);
//Arguments.parse(args,relationExtractor);
MachineReadingProperties.logger.info("Training relation extraction model...");
relationExtractor.train(dataset);
MachineReadingProperties.logger.info("Serializing relation extraction model to " + modelName + " ...");
relationExtractor.save(modelName);
if (relationsToSkip.size() > 0) {
// restore backed up relations into dataset
int sentenceIndex = 0;
for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
List<RelationMention> relationMentions = backedUpRelations.get(sentenceIndex);
sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relationMentions);
sentenceIndex++;
}
}
}
}
//
// train event extraction -- currently just works with MSTBasedEventExtractor
//
if (MachineReadingProperties.extractEvents) {
MachineReadingProperties.logger.info("Training event extraction model(s)");
if (partition != -1) MachineReadingProperties.logger.info("In partition #" + partition);
String modelName = MachineReadingProperties.serializedEventExtractorPath;
if (partition != -1) modelName += "." + partition;
File modelFile = new File(modelName);
Annotation predicted = null;
if(!this.forceRetraining&& modelFile.exists()) {
MachineReadingProperties.logger.info("Loading event extraction model from " + modelName + " ...");
Method mstLoader = (Class.forName("MSTBasedEventExtractor")).getMethod("load", String.class);
eventExtractor = (Extractor) mstLoader.invoke(null, modelName);
} else {