Package edu.stanford.nlp.parser.dvparser

Source Code of edu.stanford.nlp.parser.dvparser.AverageDVModels

package edu.stanford.nlp.parser.dvparser;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.util.CollectionUtils;
import java.util.function.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;

/**
* Given a list of input DVParser models, this tool will output a new
* DVParser which is the average of all of those models.  Sadly, this
* does not actually seem to help; the resulting model is generally
* worse than the input models, and definitely worse than the models
* used in combination.
*
* @author John Bauer
*/
public class AverageDVModels {
  public static TwoDimensionalSet<String, String> getBinaryMatrixNames(List<TwoDimensionalMap<String, String, SimpleMatrix>> maps) {
    TwoDimensionalSet<String, String> matrixNames = new TwoDimensionalSet<String, String>();
    for (TwoDimensionalMap<String, String, SimpleMatrix> map : maps) {
      for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : map) {
        matrixNames.add(entry.getFirstKey(), entry.getSecondKey());
      }
    }
    return matrixNames;
  }

  public static Set<String> getUnaryMatrixNames(List<Map<String, SimpleMatrix>> maps) {
    Set<String> matrixNames = Generics.newHashSet();
    for (Map<String, SimpleMatrix> map : maps) {
      for (Map.Entry<String, SimpleMatrix> entry : map.entrySet()) {
        matrixNames.add(entry.getKey());
      }
    }
    return matrixNames;
  }

  public static TwoDimensionalMap<String, String, SimpleMatrix> averageBinaryMatrices(List<TwoDimensionalMap<String, String, SimpleMatrix>> maps) {
    TwoDimensionalMap<String, String, SimpleMatrix> averages = TwoDimensionalMap.treeMap();
    for (Pair<String, String> binary : getBinaryMatrixNames(maps)) {
      int count = 0;
      SimpleMatrix matrix = null;
      for (TwoDimensionalMap<String, String, SimpleMatrix> map : maps) {
        if (!map.contains(binary.first(), binary.second())) {
          continue;
        }
        SimpleMatrix original = map.get(binary.first(), binary.second());
        ++count;
        if (matrix == null) {
          matrix = original;
        } else {
          matrix = matrix.plus(original);
        }
      }
      matrix = matrix.divide(count);
      averages.put(binary.first(), binary.second(), matrix);
    }
    return averages;
  }

  public static Map<String, SimpleMatrix> averageUnaryMatrices(List<Map<String, SimpleMatrix>> maps) {
    Map<String, SimpleMatrix> averages = Generics.newTreeMap();
    for (String name : getUnaryMatrixNames(maps)) {
      int count = 0;
      SimpleMatrix matrix = null;
      for (Map<String, SimpleMatrix> map : maps) {
        if (!map.containsKey(name)) {
          continue;
        }
        SimpleMatrix original = map.get(name);
        ++count;
        if (matrix == null) {
          matrix = original;
        } else {
          matrix = matrix.plus(original);
        }
      }
      matrix = matrix.divide(count);
      averages.put(name, matrix);
    }
    return averages;
  }

  /**
   * Command line arguments for this program:
   * <br>
   * -output: the model file to output
   * -input: a list of model files to input
   */
  public static void main(String[] args) {
    String outputModelFilename = null;
    List<String> inputModelFilenames = Generics.newArrayList();
   
    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-output")) {
        outputModelFilename = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-input")) {
        for (++argIndex; argIndex < args.length && !args[argIndex].startsWith("-"); ++argIndex) {
          inputModelFilenames.addAll(Arrays.asList(args[argIndex].split(",")));
        }
      } else {
        throw new RuntimeException("Unknown argument " + args[argIndex]);
      }
    }

    if (outputModelFilename == null) {
      System.err.println("Need to specify output model name with -output");
      System.exit(2);
    }

    if (inputModelFilenames.size() == 0) {
      System.err.println("Need to specify input model names with -input");
      System.exit(2);
    }

    System.err.println("Averaging " + inputModelFilenames);
    System.err.println("Outputting result to " + outputModelFilename);

    LexicalizedParser lexparser = null;
    List<DVModel> models = Generics.newArrayList();
    for (String filename : inputModelFilenames) {
      LexicalizedParser parser = LexicalizedParser.loadModel(filename);
      if (lexparser == null) {
        lexparser = parser;
      }
      models.add(DVParser.getModelFromLexicalizedParser(parser));
    }

    List<TwoDimensionalMap<String, String, SimpleMatrix>> binaryTransformMaps =
      CollectionUtils.transformAsList(models, model -> model.binaryTransform);

    List<TwoDimensionalMap<String, String, SimpleMatrix>> binaryScoreMaps =
      CollectionUtils.transformAsList(models, model -> model.binaryScore);

    List<Map<String, SimpleMatrix>> unaryTransformMaps =
      CollectionUtils.transformAsList(models, model -> model.unaryTransform);

    List<Map<String, SimpleMatrix>> unaryScoreMaps =
      CollectionUtils.transformAsList(models, model -> model.unaryScore);

    List<Map<String, SimpleMatrix>> wordMaps =
      CollectionUtils.transformAsList(models, model -> model.wordVectors);

    TwoDimensionalMap<String, String, SimpleMatrix> binaryTransformAverages = averageBinaryMatrices(binaryTransformMaps);
    TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreAverages = averageBinaryMatrices(binaryScoreMaps);
    Map<String, SimpleMatrix> unaryTransformAverages = averageUnaryMatrices(unaryTransformMaps);
    Map<String, SimpleMatrix> unaryScoreAverages = averageUnaryMatrices(unaryScoreMaps);
    Map<String, SimpleMatrix> wordAverages = averageUnaryMatrices(wordMaps);

    DVModel newModel = new DVModel(binaryTransformAverages, unaryTransformAverages,
                                   binaryScoreAverages, unaryScoreAverages,
                                   wordAverages, lexparser.getOp());
    DVParser newParser = new DVParser(newModel, lexparser);
    newParser.saveModel(outputModelFilename);
  }
}
TOP

Related Classes of edu.stanford.nlp.parser.dvparser.AverageDVModels

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.