Package edu.stanford.nlp.parser.lexparser

Source Code of edu.stanford.nlp.parser.lexparser.HTKLatticeReader$LatticeWord

package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.trees.Tree;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class HTKLatticeReader {

  public final boolean DEBUG;
  public final boolean PRETTYPRINT;
  public static final boolean USESUM = true;
  public static final boolean USEMAX = false;
  private final boolean mergeType;
  public static final String SILENCE = "<SIL>";

  private int numStates;
  private List<HTKLatticeReader.LatticeWord> latticeWords;
  private int[] nodeTimes;
  private ArrayList<LatticeWord>[] wordsAtTime;
  private ArrayList<LatticeWord>[] wordsStartAt;
  private ArrayList<LatticeWord>[] wordsEndAt;

  private void readInput(BufferedReader in) throws Exception {

    // GET RID OF COMMENT LINES
    String line = in.readLine();
    while (line.trim().startsWith("#")) {
      line = in.readLine();
    }

    // READ LATTICE
    latticeWords = new ArrayList<HTKLatticeReader.LatticeWord>();

    Pattern wordLinePattern = Pattern.compile("(\\d+)\\s+(\\d+)\\s+lm=(-?\\d+\\.\\d+),am=(-?\\d+\\.\\d+)\\s+([^( ]+)(?:\\((\\d+)\\))?.*");
    Matcher wordLineMatcher = wordLinePattern.matcher(line);

    while (wordLineMatcher.matches()) {
      int startNode = Integer.parseInt(wordLineMatcher.group(1)) - 1;
      int endNode = Integer.parseInt(wordLineMatcher.group(2)) - 1;
      double lm = Double.parseDouble(wordLineMatcher.group(3));
      double am = Double.parseDouble(wordLineMatcher.group(4));
      String word = wordLineMatcher.group(5).toLowerCase();
      String pronun = wordLineMatcher.group(6);

      if (word.equalsIgnoreCase("<s>")) {
        line = in.readLine();
        wordLineMatcher = wordLinePattern.matcher(line);
        continue;
      }
      if (word.equalsIgnoreCase("</s>")) {
        word = Lexicon.BOUNDARY;
      }

      int pronunciation;
      if (pronun == null) {
        pronunciation = 0;
      } else {
        pronunciation = Integer.parseInt(pronun);
      }

      LatticeWord lw = new LatticeWord(word, startNode, endNode, lm, am, pronunciation, mergeType);
      if (DEBUG) {
        System.err.println(lw);
      }
      latticeWords.add(lw);

      line = in.readLine();
      wordLineMatcher = wordLinePattern.matcher(line);
    }

    // GET NUMBER OF NODES
    numStates = Integer.parseInt(line.trim());
    if (DEBUG) {
      System.err.println(numStates);
    }

    // READ NODE TIMES
    nodeTimes = new int[numStates];

    Pattern nodeTimePattern = Pattern.compile("(\\d+)\\s+t=(\\d+)\\s*");
    Matcher nodeTimeMatcher;

    for (int i = 0; i < numStates; i++) {
      nodeTimeMatcher = nodeTimePattern.matcher(in.readLine());

      if (!nodeTimeMatcher.matches()) {
        System.err.println("Input File Error");
        System.exit(1);
      }

      // assert ((Integer.parseInt(nodeTimeMatcher.group(1))-1) == i) ;

      nodeTimes[i] = Integer.parseInt(nodeTimeMatcher.group(2));

      if (DEBUG) {
        System.err.println(i + "\tt=" + nodeTimes[i]);
      }
    }
  }

  private void mergeSimultaneousNodes() {

    int[] indexMap = new int[nodeTimes.length];

    indexMap[0] = 0;
    int prevNode = 0;
    int prevTime = nodeTimes[0];
    if (DEBUG) {
      System.err.println(0 + " (" + nodeTimes[0] + ")" + "-->" + 0 + " (" + nodeTimes[0] + ") ++");
    }
    for (int i = 1; i < nodeTimes.length; i++) {
      if (prevTime == nodeTimes[i]) {
        indexMap[i] = prevNode;
        if (DEBUG) {
          System.err.println(i + " (" + nodeTimes[i] + ")" + "-->" + prevNode + " (" + nodeTimes[prevNode] + ") **");
        }
      } else {
        indexMap[i] = prevNode = i;
        prevTime = nodeTimes[i];
        if (DEBUG) {
          System.err.println(i + " (" + nodeTimes[i] + ")" + "-->" + prevNode + " (" + nodeTimes[prevNode] + ") ++");
        }
      }
    }

    for  (LatticeWord lw : latticeWords) {
      lw.startNode = indexMap[lw.startNode];
      lw.endNode = indexMap[lw.endNode];
      if (DEBUG) {
        System.err.println(lw);
      }
    }
  }

  private void removeEmptyNodes() {
    int[] indexMap = new int[numStates];
    int j = 0;
    for (int i = 0; i < numStates; i++) {
      indexMap[i] = j;
      if (wordsStartAt[i].size() != 0 || wordsEndAt[i].size() != 0) {
        j++;
      }
    }

    for (HTKLatticeReader.LatticeWord lw : latticeWords) {
      wordsStartAt[lw.startNode].remove(lw);
      wordsEndAt[lw.endNode].remove(lw);
      for (int i = lw.startNode; i < lw.endNode; i++) {
        wordsAtTime[i].remove(lw);
      }

      lw.startNode = indexMap[lw.startNode];
      lw.endNode = indexMap[lw.endNode];
      wordsStartAt[lw.startNode].add(lw);
      wordsEndAt[lw.endNode].add(lw);
      for (int i = lw.startNode; i < lw.endNode; i++) {
        wordsAtTime[i].add(lw);
      }
    }

    numStates = j;
    ArrayList<LatticeWord>[] tmp = wordsAtTime;
    wordsAtTime = new ArrayList[numStates];
    System.arraycopy(tmp, 0, wordsAtTime, 0, numStates);

    tmp = wordsStartAt;
    wordsStartAt = new ArrayList[numStates];
    System.arraycopy(tmp, 0, wordsStartAt, 0, numStates);

    tmp = wordsEndAt;
    wordsEndAt = new ArrayList[numStates];
    System.arraycopy(tmp, 0, wordsEndAt, 0, numStates);

  }

  private void buildWordTimeArrays() {
    buildWordsAtTime();
    buildWordsStartAt();
    buildWordsEndAt();
  }

  private void buildWordsAtTime() {
    wordsAtTime = new ArrayList[numStates];
    for (int i = 0; i < wordsAtTime.length; i++) {
      wordsAtTime[i] = new ArrayList<LatticeWord>();
    }

    for (LatticeWord lw : latticeWords) {
      for (int j = lw.startNode; j <= lw.endNode; j++) {
        wordsAtTime[j].add(lw);
      }
    }
  }

  private void buildWordsStartAt() {
    wordsStartAt = new ArrayList[numStates];
    for (int i = 0; i < wordsStartAt.length; i++) {
      wordsStartAt[i] = new ArrayList<LatticeWord>();
    }

    for (LatticeWord lw : latticeWords) {
      wordsStartAt[lw.startNode].add(lw);
    }
  }

  private void buildWordsEndAt() {
    wordsEndAt = new ArrayList[numStates];
    for (int i = 0; i < wordsEndAt.length; i++) {
      wordsEndAt[i] = new ArrayList<LatticeWord>();
    }

    for (LatticeWord lw : latticeWords) {
      wordsEndAt[lw.endNode].add(lw);
    }
  }

  private void removeRedundency() {

    boolean changed = true;

    while (changed) {
      changed = false;
      for (int i = 0; i < wordsAtTime.length; i++) {
        if (wordsAtTime[i].size() < 2) {
          continue;
        }
        INNER: for (int j = 0; j < wordsAtTime[i].size() - 1; j++) {
          LatticeWord w1 = wordsAtTime[i].get(j);
          for (int k = j + 1; k < wordsAtTime[i].size(); k++) {
            LatticeWord w2 = wordsAtTime[i].get(k);
            if (w1.word.equalsIgnoreCase(w2.word)) {
              if (removeRedundentPair(w1, w2)) {
                //int numMerged = mergeDuplicates();
                //if (DEBUG) { System.err.println("merged " + numMerged + " identical entries."); }
                changed = true;
                //printWords();
                //j--;
                continue INNER;
                //return;
              }
            }
          }
        }
      }
    }
  }

  private boolean removeRedundentPair(LatticeWord w1, LatticeWord w2) {

    if (DEBUG) {
      System.err.println("trying to remove:");
      System.err.println(w1);
      System.err.println(w2);
    }

    int w1Start = w1.startNode;
    int w2Start = w2.startNode;
    int w1End = w1.endNode;
    int w2End = w2.endNode;

    // we must pick new start and end times that are legal
    int newStart, oldStart;
    if (w1Start < w2Start) {
      newStart = w2Start;
      oldStart = w1Start;
    } else {
      newStart = w1Start;
      oldStart = w2Start;
    }

    int newEnd, oldEnd;
    if (w1End < w2End) {
      newEnd = w1End;
      oldEnd = w2End;
    } else {
      newEnd = w2End;
      oldEnd = w1End;
    }

    // check legality (illegality not guarenteed)
    for (LatticeWord lw : wordsStartAt[oldStart]) {
      if (lw.endNode < newStart || ((lw.endNode == newStart) && (lw.endNode != lw.startNode))) {
        if (DEBUG) {
          System.err.println("failed");
        }
        return false;
      }
    }
    for (LatticeWord lw : wordsEndAt[oldEnd]) {
      if (lw.startNode > newEnd || ((lw.startNode == newEnd) && (lw.endNode != lw.startNode))) {
        if (DEBUG) {
          System.err.println("failed");
        }
        return false;
      }
    }

    // change start/end times of adjacent entries
    changeStartTimes(wordsStartAt[oldEnd], newEnd);
    changeEndTimes(wordsEndAt[oldStart], newStart);

    // change start/end times of words adjacent to adjacent entries
    changeStartTimes(wordsStartAt[oldStart], newStart);
    changeEndTimes(wordsEndAt[oldEnd], newEnd);

    if (DEBUG) {
      System.err.println("succeeded");
    }
    return true;
  }


  private void changeStartTimes(List<LatticeWord> words, int newStartTime) {
    ArrayList<LatticeWord> toRemove = new ArrayList<LatticeWord>();
    for (LatticeWord lw : words) {
      latticeWords.remove(lw);
      int oldStartTime = lw.startNode;
      lw.startNode = newStartTime;

      if (latticeWords.contains(lw)) {
        if (DEBUG) {
          System.err.println("duplicate found");
        }
        LatticeWord twin = latticeWords.get(latticeWords.indexOf(lw));
        // assert (twin != lw) ;
        lw.startNode = oldStartTime;
        twin.merge(lw);
        //wordsStartAt[lw.startNode].remove(lw);
        toRemove.add(lw);
        wordsEndAt[lw.endNode].remove(lw);
        for (int i = lw.startNode; i <= lw.endNode; i++) {
          wordsAtTime[i].remove(lw);
        }
      } else {
        if (oldStartTime < newStartTime) {
          for (int i = oldStartTime; i < newStartTime; i++) {
            wordsAtTime[i].remove(lw);
          }
        } else {
          for (int i = newStartTime; i < oldStartTime; i++) {
            wordsAtTime[i].add(lw);
          }
        }
        latticeWords.add(lw);
        if (oldStartTime != newStartTime) {
          //wordsStartAt[oldStartTime].remove(lw);
          toRemove.add(lw);
          wordsStartAt[newStartTime].add(lw);
        }
      }
    }
    words.removeAll(toRemove);
  }

  private void changeEndTimes(List<LatticeWord> words, int newEndTime) {
    ArrayList<LatticeWord> toRemove = new ArrayList<LatticeWord>();
    for (LatticeWord lw : words) {
      latticeWords.remove(lw);
      int oldEndTime = lw.endNode;
      lw.endNode = newEndTime;

      if (latticeWords.contains(lw)) {
        if (DEBUG) {
          System.err.println("duplicate found");
        }
        LatticeWord twin = latticeWords.get(latticeWords.indexOf(lw));
        // assert (twin != lw) ;
        lw.endNode = oldEndTime;
        twin.merge(lw);
        wordsStartAt[lw.startNode].remove(lw);
        //wordsEndAt[lw.endNode].remove(lw);
        toRemove.add(lw);
        for (int i = lw.startNode; i <= lw.endNode; i++) {
          wordsAtTime[i].remove(lw);
        }
      } else {
        if (oldEndTime > newEndTime) {
          for (int i = newEndTime + 1; i <= oldEndTime; i++) {
            wordsAtTime[i].remove(lw);
          }
        } else {
          for (int i = oldEndTime + 1; i <= newEndTime; i++) {
            wordsAtTime[i].add(lw);
          }
        }
        latticeWords.add(lw);
        if (oldEndTime != newEndTime) {
          //wordsEndAt[oldEndTime].remove(lw);
          toRemove.add(lw);
          wordsEndAt[newEndTime].add(lw);
        }
      }
    }
    words.removeAll(toRemove);
  }

  private void removeSilence() {
    ArrayList<HTKLatticeReader.LatticeWord> silences = new ArrayList<HTKLatticeReader.LatticeWord>();
    for (LatticeWord lw : latticeWords) {
      if (lw.word.equalsIgnoreCase(SILENCE)) {
        silences.add(lw);
      }
    }
    for (LatticeWord lw : silences) {
      //if (lw.endNode == numStates) {
      changeEndTimes(wordsEndAt[lw.startNode], lw.endNode);
      //} else {
      //changeStartTimes(wordsStartAt[lw.endNode], lw.startNode);
      //}
    }
    silences.clear();
    for (HTKLatticeReader.LatticeWord lw : latticeWords) {
      if (lw.word.equalsIgnoreCase(SILENCE)) {
        silences.add(lw);
      }
    }
    for (LatticeWord lw : silences) {
      if (lw.word.equalsIgnoreCase(SILENCE)) {
        latticeWords.remove(lw);
        wordsStartAt[lw.startNode].remove(lw);
        wordsEndAt[lw.endNode].remove(lw);
        for (int j = lw.startNode; j <= lw.endNode; j++) {
          wordsAtTime[j].remove(lw);
        }
      }
    }
  }

  private int mergeDuplicates() {
    int numMerged = 0;
    for (int i = 0; i < latticeWords.size() - 1; i++) {
      LatticeWord first = latticeWords.get(i);
      for (int j = i + 1; j < latticeWords.size(); j++) {
        LatticeWord second = latticeWords.get(j);
        if (first.equals(second)) {
          if (DEBUG) {
            System.err.println("removed duplicate");
          }
          first.merge(second);
          latticeWords.remove(j);
          wordsStartAt[second.startNode].remove(second);
          wordsEndAt[second.endNode].remove(second);
          for (int k = second.startNode; k <= second.endNode; k++) {
            wordsAtTime[k].remove(second);
          }
          numMerged++;
          j--;
        }
      }
    }
    return numMerged;
  }

  public void printWords() {
    Collections.sort(latticeWords);
    System.out.println("Words: ");
    for (LatticeWord lw : latticeWords) {
      System.out.println(lw);
    }
  }

  private double getProb(LatticeWord lw) {
    return lw.am * 100.0 + lw.lm;
  }

  //     private LatticeWord[][] nBest(int n) {

  //     }

  public void processLattice() {
    // System.err.println(1);
    buildWordTimeArrays();
    //System.err.println(2);
    removeSilence();
    //System.err.println(3);
    mergeDuplicates();
    //System.err.println(4);
    removeRedundency();
    //System.err.println(5);
    removeEmptyNodes();
    //System.err.println(6);
    if (PRETTYPRINT) {
      printWords();
    }

  }


  public HTKLatticeReader(String filename) throws Exception {
    this(filename, USESUM, false, false);
  }

  public HTKLatticeReader(String filename, boolean mergeType) throws Exception {
    this(filename, mergeType, false, false);
  }

  public HTKLatticeReader(String filename, boolean mergeType, boolean debug, boolean prettyPrint) throws Exception {
    this.DEBUG = debug;
    this.PRETTYPRINT = prettyPrint;
    this.mergeType = mergeType;

    BufferedReader in = new BufferedReader(new FileReader(filename));
    //System.err.println(-1);
    readInput(in);
    //System.err.println(0);
    if (PRETTYPRINT) {
      printWords();
    }

    processLattice();

  }

  public List<HTKLatticeReader.LatticeWord> getLatticeWords() {
    return latticeWords;
  }

  public int getNumStates() {
    return numStates;
  }

  public List<HTKLatticeReader.LatticeWord> getWordsOverSpan(int a, int b) {
    ArrayList<HTKLatticeReader.LatticeWord> words = new ArrayList<HTKLatticeReader.LatticeWord>();
    for (LatticeWord lw : wordsStartAt[a]) {
      if (lw.endNode == b) {
        words.add(lw);
      }
    }
    return words;
  }

  public static void main(String[] args) throws Exception {

    boolean mergeType = USESUM;
    boolean prettyPrint = true;
    boolean debug = false;
    String parseGram = null;
    String filename = args[0];

    for (int i = 1; i < args.length; i++) {
      if (args[i].equalsIgnoreCase("-debug")) {
        debug = true;
      } else if (args[i].equalsIgnoreCase("-useMax")) {
        mergeType = USEMAX;
      } else if (args[i].equalsIgnoreCase("-useSum")) {
        mergeType = USESUM;
      } else if (args[i].equalsIgnoreCase("-noPrettyPrint")) {
        prettyPrint = false;
      } else if (args[i].equalsIgnoreCase("-parser")) {
        parseGram = args[++i];
      } else {
        System.err.println("unrecognized flag: " + args[i]);
        System.err.println("usage: java LatticeReader <file> [ -debug ] [ -useMax ] [ -useSum ] [ -noPrettyPrint ] [ -parser parserFile ]");
        System.exit(0);
      }
    }

    HTKLatticeReader lr = new HTKLatticeReader(filename, mergeType, debug, prettyPrint);

    if (parseGram != null) {
      Options op = new Options();
      // TODO: these options all get clobbered by the Options object
      // stored in the LexicalizedParser (unless it's a text file?)
      op.doDep = false;
      op.testOptions.maxLength = 80;
      op.testOptions.maxSpanForTags = 80;
      LexicalizedParser lp = LexicalizedParser.loadModel(parseGram, op);
      // TODO: somehow merge this into ParserQuery instead of being
      // LexicalizedParserQuery specific
      LexicalizedParserQuery pq = lp.lexicalizedParserQuery();
      pq.parse(lr);
      Tree t = pq.getBestParse();
      t.pennPrint();
    }
    //lr.processLattice();
  }

  public static class LatticeWord implements Comparable<LatticeWord> {
    public String word;
    public int startNode, endNode;
    public double lm, am;
    public int pronunciation;
    public final boolean mergeType;

    public LatticeWord(String word, int startNode, int endNode, double lm, double am, int pronunciation, boolean mergeType) {

      this.word = word;
      this.startNode = startNode;
      this.endNode = endNode;
      this.lm = lm;
      this.am = am;
      this.pronunciation = pronunciation;
      this.mergeType = mergeType;
    }

    public void merge(LatticeWord lw) {
      if (mergeType == USEMAX) {
        am = Math.max(am, lw.am);
        lw.am = am;
      } else if (mergeType == USESUM) {
        double tmp = lw.am;
        lw.am += am;
        am += tmp;
      }
    }

    @Override
    public String toString() {
      StringBuffer sb = new StringBuffer();
      sb.append(startNode).append("\t");
      sb.append(endNode).append("\t");
      sb.append("lm=").append(lm).append(",");
      sb.append("am=").append(am).append("\t");
      sb.append(word);//.append("(").append(pronunciation).append(")");
      return sb.toString();
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof LatticeWord)) {
        return false;
      }
      LatticeWord other = (LatticeWord) o;
      if (!word.equalsIgnoreCase(other.word)) {
        return false;
      }
      if (startNode != other.startNode) {
        return false;
      }
      if (endNode != other.endNode) {
        return false;
      }
      //if (pronunciation != other.pronunciation) { return false; }
      return true;
    }

    public int compareTo(LatticeWord other) {
      if (startNode < other.startNode) {
        return -1;
      } else if (startNode > other.startNode) {
        return 1;
      }

      if (endNode < other.endNode) {
        return -1;
      } else if (endNode > other.endNode) {
        return 1;
      }

      return 0;
    }

  }

}
TOP

Related Classes of edu.stanford.nlp.parser.lexparser.HTKLatticeReader$LatticeWord

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by ORACLE Inc. Contact coftware#gmail.com.