/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.corpus.vocab;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;
import joshua.corpus.suffix_array.BasicPhrase;
import joshua.decoder.ff.tm.hiero.HieroFormatReader;
import joshua.util.io.BinaryIn;
import joshua.util.io.LineReader;
/**
 * Vocabulary is the class that keeps track of the unique words
 * that occur in a corpus of text for a particular language. It
 * assigns integer IDs to Words, which is useful when we are creating
 * suffix arrays or doing similar things.
 * <p>
 * Terminals and nonterminals are kept in separate word-to-ID maps.
 * IDs handed out by {@link #addTerminal(String)} are positive
 * (1, 2, 3, ...) and IDs handed out by
 * {@link #addNonterminal(String)} are negative (-1, -2, -3, ...).
 * A single ID-to-word map covers both kinds of symbol.
 *
 * @author Chris Callison-Burch
 * @since 8 February 2005
 * @author Lane Schwartz
 * @version $LastChangedDate:2008-07-30 17:15:52 -0400 (Wed, 30 Jul 2008) $
 */
public class Vocabulary extends AbstractExternalizableSymbolTable
		implements Iterable<String>, ExternalizableSymbolTable {

	//===============================================================
	// Constants
	//===============================================================

	/** Logger for this class. */
	private static final Logger logger = Logger.getLogger(Vocabulary.class.getName());

	/**
	 * The value returned by this class's <code>hashCode</code>
	 * method.
	 */
	protected static final int HASH_CODE = 42;

	static final long serialVersionUID = 1L;

	//===============================================================
	// Member variables
	//===============================================================

	/** Maps each nonterminal symbol to its (negative) integer ID. */
	protected final Map<String,Integer> nonterminalToInt;

	/** Maps each terminal symbol to its (positive) integer ID. */
	protected final Map<String,Integer> terminalToInt;

	/** Maps every known ID, terminal or nonterminal, back to its string. */
	protected final Map<Integer,String> intToString;

	//===============================================================
	// Constructor(s)
	//===============================================================

	/**
	 * Constructor creates a vocabulary containing only the default
	 * symbols: the five standard nonterminals followed by the
	 * terminals <code>&lt;unk&gt;</code>, <code>&lt;s&gt;</code>,
	 * <code>&lt;/s&gt;</code> and <code>-pau-</code>, registered
	 * in that order.
	 */
	public Vocabulary() {
		nonterminalToInt = new HashMap<String,Integer>();
		terminalToInt = new HashMap<String,Integer>();
		intToString = new HashMap<Integer,String>();

		// Registration order matters: it determines the IDs
		// (-1, -2, ... for nonterminals; 1, 2, ... for terminals).
		addNonterminal(X_STRING);
		addNonterminal(X1_STRING);
		addNonterminal(X2_STRING);
		addNonterminal(S_STRING);
		addNonterminal(S1_STRING);

		this.addTerminal("<unk>");
		this.addTerminal("<s>");
		this.addTerminal("</s>");
		this.addTerminal("-pau-");
	}

	/**
	 * Constructor creates a vocabulary initialized with the default
	 * symbols plus the given set of words, each added as a terminal.
	 *
	 * @param words the words to add as terminals
	 */
	public Vocabulary(Set<String> words) {
		this();
		for (String word : words) {
			this.addTerminal(word);
		}
	}

	/**
	 * Constructs a vocabulary using the words from an SRILM
	 * language model file.
	 * <p>
	 * Each usable line consists of exactly two whitespace-separated
	 * fields: an integer ID followed by a word. Any other line,
	 * including a two-field line whose first field is not an
	 * integer, is logged as a warning and skipped.
	 *
	 * @param scanner Scanner configured to read an SRILM
	 *                language model file.
	 * @return Vocabulary initialized with the words from the
	 *         SRILM language model file.
	 */
	public static Vocabulary getVocabFromSRILM(Scanner scanner) {

		Vocabulary vocab = new Vocabulary();

		int counter = 0;
		int ignored = 0;

		while (scanner.hasNextLine()) {

			String line = scanner.nextLine();
			String[] parts = line.split("\\s+");

			Integer id = null;
			if (parts.length == 2) {
				try {
					id = Integer.valueOf(parts[0]);
				} catch (NumberFormatException e) {
					// Not an "<id> <word>" line after all;
					// fall through and treat it as malformed.
					id = null;
				}
			}

			if (id == null) {
				ignored += 1;
				logger.warning("Line is improperly formatted: " + line);
				continue;
			}

			String word = parts[1];
			vocab.intToString.put(id, word);
			if (vocab.isNonterminal(id)) {
				vocab.nonterminalToInt.put(word, id);
			} else {
				vocab.terminalToInt.put(word, id);
			}
			counter += 1;
		}

		if (logger.isLoggable(Level.FINE)) {
			int total = counter + ignored;
			logger.fine(total + " lines read, of which " + counter + " were included and " + ignored + " were ignored");
		}

		return vocab;
	}

	/**
	 * Initializes a Vocabulary by reading every line of a
	 * specified plain text file and building a phrase from it
	 * against the given vocabulary.
	 *
	 * @param inputFilename the plain text file
	 * @param vocab         the Vocabulary to which words should
	 *                      be added
	 * @param fixVocabulary currently ignored; retained so that
	 *                      existing callers keep compiling
	 * @return a two-element array containing the number of words
	 *         in the corpus followed by the number of sentences
	 *         in the corpus
	 * @throws IOException if the file cannot be read or closed
	 */
	public static int[] initializeVocabulary(String inputFilename, Vocabulary vocab, boolean fixVocabulary) throws IOException {
		int numSentences = 0;
		int numWords = 0;

		LineReader lineReader = new LineReader(inputFilename);
		try {
			for (String line : lineReader) {
				// NOTE(review): presumably the BasicPhrase constructor
				// registers each word of the line with vocab -- confirm.
				BasicPhrase sentence = new BasicPhrase(line, vocab);
				numWords += sentence.size();
				numSentences++;
				if (logger.isLoggable(Level.FINE) && numSentences % 10000 == 0) {
					logger.fine("" + numWords);
				}
			}
		} finally {
			// Release the underlying file handle even if reading fails.
			lineReader.close();
		}

		int[] numberOfWordsSentences = { numWords, numSentences };
		return numberOfWordsSentences;
	}

	//===============================================================
	// Public
	//===============================================================

	//===========================================================
	// Accessor methods (set/get)
	//===========================================================

	/**
	 * Gets an integer identifier for the word, adding the word
	 * to the vocabulary if it is not already present.
	 * <p>
	 * The word is treated as a nonterminal if it equals one of
	 * the standard nonterminal strings or if
	 * <code>HieroFormatReader.isNonTerminal</code> says so;
	 * otherwise it is treated as a terminal.
	 * <p>
	 * Note that this method never returns
	 * <code>UNKNOWN_WORD</code> for an unseen word: unseen words
	 * are assigned a fresh identifier.
	 *
	 * @param wordString the word to look up or add
	 * @return the unique integer identifier for wordString
	 */
	public int getID(String wordString) {
		boolean isNonterminal =
				SymbolTable.X_STRING.equals(wordString)
				|| SymbolTable.X1_STRING.equals(wordString)
				|| SymbolTable.X2_STRING.equals(wordString)
				|| SymbolTable.S_STRING.equals(wordString)
				|| SymbolTable.S1_STRING.equals(wordString)
				|| HieroFormatReader.isNonTerminal(wordString);

		if (isNonterminal) {
			return this.addNonterminal(wordString);
		} else {
			return this.addTerminal(wordString);
		}
	}

	/**
	 * Gets the identifier of a nonterminal, adding the
	 * nonterminal to the vocabulary if it is not already present.
	 *
	 * @param nonterminalString the nonterminal to look up or add
	 * @return the (negative) identifier of the nonterminal
	 */
	public int getNonterminalID(String nonterminalString) {
		return this.addNonterminal(nonterminalString);
	}

	/**
	 * Gets the integer identifiers for all words in the provided
	 * sentence.
	 * <p>
	 * The sentence will be trimmed and split on runs of
	 * whitespace into words, then the integer identifier for
	 * each word will be retrieved using <code>getID</code>.
	 * Splitting on runs of whitespace ensures that repeated
	 * spaces or tabs never produce empty "words".
	 *
	 * @see #getID(String)
	 * @param sentence String of words, separated by whitespace.
	 * @return Array of integer identifiers for each word in
	 *         the sentence; an empty array if the sentence is
	 *         <code>null</code> or blank
	 */
	public int[] getIDs(String sentence) {
		String trimmed = (sentence == null) ? "" : sentence.trim();
		if (trimmed.length() == 0) {
			return new int[]{};
		}

		String[] words = trimmed.split("\\s+");
		int[] wordIDs = new int[words.length];
		for (int i = 0; i < words.length; i++) {
			wordIDs[i] = getID(words[i]);
		}
		return wordIDs;
	}

	/**
	 * Gets the String that corresponds to the specified integer
	 * identifier.
	 *
	 * @param wordID the identifier to look up
	 * @return the String that corresponds to the specified
	 *         integer identifier, or <code>UNKNOWN_WORD_STRING</code>
	 *         if the identifier does not correspond to a word
	 *         in the vocabulary
	 */
	public String getWord(int wordID) {
		String word = intToString.get(wordID);
		if (word == null) {
			word = UNKNOWN_WORD_STRING;
		}
		return word;
	}

	/**
	 * @return an Iterator over all symbols (terminals and
	 *         nonterminals) in the Vocabulary, backed by the
	 *         internal ID-to-word map.
	 */
	public Iterator<String> iterator() {
		return intToString.values().iterator();
	}

	/**
	 * Gets the identifiers of all terminals in this vocabulary.
	 * Nonterminal identifiers are not included.
	 *
	 * @return a live view of the terminal identifiers
	 */
	public Collection<Integer> getAllIDs() {
		return terminalToInt.values();
	}

	/**
	 * Gets the list of all words represented by this vocabulary.
	 *
	 * @return a live view of all symbols (terminals and
	 *         nonterminals) represented by this vocabulary
	 */
	public Collection<String> getWords() {
		return intToString.values();
	}

	/**
	 * Gets the number of identifiers known to the vocabulary,
	 * counting both terminals and nonterminals.
	 *
	 * @return the number of identifiers in the vocabulary.
	 */
	public int size() {
		return intToString.size();
	}

	/**
	 * Determines if the phrase contains any words that are not
	 * in the vocabulary.
	 *
	 * @param phrase the phrase to inspect
	 * @return <code>true</code> if there are unknown words in
	 *         the phrase, <code>false</code> otherwise
	 */
	public boolean containsUnknownWords(BasicPhrase phrase) {
		for (int i = 0; i < phrase.size(); i++) {
			if (phrase.getWordID(i) == UNKNOWN_WORD) return true;
		}
		return false;
	}

	/**
	 * Checks that the symbol tables are the same, by first
	 * checking that they have the same size, and then checking
	 * that every identifier stored in this vocabulary (terminal
	 * or nonterminal) corresponds to the same word, and every
	 * such word to the same identifier, in the other table.
	 * <p>
	 * When the other object is a <code>SymbolTable</code> that is
	 * not a <code>Vocabulary</code>, its <code>getID</code>
	 * method is used for the word-to-identifier check.
	 *
	 * @param o the object to check equivalence with
	 * @return <code>true</code> if the other object is a
	 *         SymbolTable representing the same set of words
	 *         with identically assigned IDs, <code>false</code>
	 *         otherwise
	 */
	@Override
	public boolean equals(Object o) {
		if (o == this) {
			return true;
		}
		if (!(o instanceof SymbolTable)) {
			return false;
		}

		SymbolTable other = (SymbolTable) o;
		if (other.size() != this.size()) return false;

		// Walk the identifiers that are actually stored rather than
		// an assumed contiguous range: terminal IDs start at 1, and
		// tables loaded from a file need not be contiguous at all.
		for (Integer key : intToString.keySet()) {
			int i = key.intValue();

			String thisWord = this.getWord(i);
			String otherWord = other.getWord(i);
			if (thisWord == null) {
				if (otherWord != null) return false;
				continue;
			}
			if (!thisWord.equals(otherWord)) return false;

			Integer thisID = this.isNonterminal(i)
					? this.nonterminalToInt.get(thisWord)
					: this.terminalToInt.get(thisWord);

			Integer otherID;
			if (o instanceof Vocabulary) {
				Vocabulary otherVocab = (Vocabulary) other;
				otherID = otherVocab.isNonterminal(i)
						? otherVocab.nonterminalToInt.get(thisWord)
						: otherVocab.terminalToInt.get(thisWord);
			} else {
				otherID = other.getID(otherWord);
			}

			if (thisID != null && otherID != null && !thisID.equals(otherID)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * It is expected that instances of this class will never
	 * be put into a hash table.
	 * <p>
	 * Therefore, this method always returns a constant value
	 * (and trips an assertion when assertions are enabled).
	 *
	 * @return a constant value
	 */
	@Override
	public int hashCode() {
		assert false : "hashCode not designed";
		return HASH_CODE;
	}

	//===========================================================
	// Methods
	//===========================================================

	/**
	 * @return the string form of the internal ID-to-word map
	 */
	@Override
	public String toString() {
		return intToString.toString();
	}

	/**
	 * Gets the number of terminals, which is also the highest
	 * terminal identifier when all terminal identifiers were
	 * assigned by {@link #addTerminal(String)}.
	 *
	 * @return the number of terminals in this vocabulary
	 */
	public int getHighestID() {
		return terminalToInt.size();
	}

	/**
	 * Gets the negated number of nonterminals, which is also the
	 * lowest nonterminal identifier when all nonterminal
	 * identifiers were assigned by
	 * {@link #addNonterminal(String)}.
	 *
	 * @return the negated number of nonterminals
	 */
	public int getLowestID() {
		return -(nonterminalToInt.size());
	}

	/**
	 * Adds a nonterminal to the vocabulary if it is not already
	 * present.
	 *
	 * @param nonterminal the nonterminal symbol
	 * @return the existing identifier of the nonterminal, or the
	 *         newly assigned one, <code>-(n+1)</code> where n is
	 *         the number of nonterminals before the call
	 */
	public int addNonterminal(String nonterminal) {
		Integer id = nonterminalToInt.get(nonterminal);
		if (id != null) {
			return id.intValue();
		}

		id = Integer.valueOf(-(nonterminalToInt.size() + 1));
		nonterminalToInt.put(nonterminal, id);
		intToString.put(id, nonterminal);
		return id.intValue();
	}

	/**
	 * Adds a terminal to the vocabulary if it is not already
	 * present.
	 *
	 * @param terminal the terminal symbol
	 * @return the existing identifier of the terminal, or the
	 *         newly assigned one, <code>n+1</code> where n is
	 *         the number of terminals before the call
	 */
	public int addTerminal(String terminal) {
		Integer id = terminalToInt.get(terminal);
		if (id != null) {
			return id.intValue();
		}

		id = Integer.valueOf(terminalToInt.size() + 1);
		intToString.put(id, terminal);
		terminalToInt.put(terminal, id);
		return id.intValue();
	}

	/**
	 * Gets the string for an identifier; equivalent to
	 * {@link #getWord(int)}.
	 *
	 * @param wordId the identifier to look up
	 * @return the corresponding string, or
	 *         <code>UNKNOWN_WORD_STRING</code> if there is none
	 */
	public String getTerminal(int wordId) {
		return getWord(wordId);
	}

	/**
	 * Delegates to <code>getWords(wordIDs, false)</code>.
	 *
	 * @param wordIDs the identifiers to render
	 * @return the rendered string
	 */
	public String getTerminals(int[] wordIDs) {
		return getWords(wordIDs, false);
	}

	/**
	 * Delegates to <code>getWords(ids, false)</code>.
	 *
	 * @param ids the identifiers to render
	 * @return the rendered string
	 */
	public String getWords(int[] ids) {
		return getWords(ids, false);
	}

	//===========================================================
	// Externalization
	//===========================================================

	/**
	 * Reads a vocabulary from a binary file. The entries read
	 * from the file are added on top of the default symbols
	 * created by the no-argument constructor.
	 *
	 * @param binaryFileName name of the binary vocabulary file
	 * @return the vocabulary read from the file
	 * @throws FileNotFoundException declared for callers; see
	 *         <code>BinaryIn.vocabulary</code>
	 * @throws IOException if the file cannot be read or closed
	 * @throws ClassNotFoundException declared for callers
	 */
	public static Vocabulary readExternal(String binaryFileName) throws FileNotFoundException, IOException, ClassNotFoundException {
		Vocabulary vocab = new Vocabulary();
		ObjectInput in = BinaryIn.vocabulary(binaryFileName);
		try {
			vocab.readExternal(in);
		} finally {
			// Release the underlying file handle even if reading fails.
			in.close();
		}
		return vocab;
	}

	/**
	 * Reads vocabulary entries in the format produced by
	 * {@link #writeExternal(ObjectOutput)}: a total byte count
	 * (which includes its own four bytes), followed by records of
	 * <code>id</code>, <code>byte length</code> and the encoded
	 * bytes of the word.
	 *
	 * @param in the input to read from
	 * @throws IOException if reading fails or a record declares
	 *         a negative length
	 * @throws ClassNotFoundException declared by the interface
	 */
	public void readExternal(ObjectInput in) throws IOException,
			ClassNotFoundException {

		String characterEncoding = getExternalizableEncoding();

		// First read the number of bytes required to store the vocabulary data
		int totalBytes = in.readInt();
		if (logger.isLoggable(Level.FINEST)) logger.finest("Read total bytes: " + totalBytes);

		// Next, read the actual vocabulary data.
		// The total includes the four bytes of the count itself.
		int bytesRemaining = totalBytes - 4;

		while (bytesRemaining > 0) {

			// Read the integer id of the word
			int id = in.readInt();
			if (logger.isLoggable(Level.FINEST)) logger.finest("Read ID: " + id);

			// Read the number of bytes used to store the word string
			int wordLength = in.readInt();
			if (logger.isLoggable(Level.FINEST)) logger.finest("Read string length: " + wordLength);

			// We have now read eight more bytes (4 bytes per int)
			bytesRemaining -= 8;

			if (wordLength < 0) {
				throw new IOException("Corrupt vocabulary data: negative string length " + wordLength + " for ID " + id);
			}

			// Read the bytes used to store the word string
			byte[] wordBytes = new byte[wordLength];
			in.readFully(wordBytes);
			String word = new String(wordBytes, characterEncoding);
			if (logger.isLoggable(Level.FINEST)) logger.finest("Read string bytes: " + Arrays.toString(wordBytes) + " for \"" + word + "\"");

			// We have now read more bytes
			bytesRemaining -= wordBytes.length;

			// Store the word in the vocabulary
			intToString.put(id, word);
			if (logger.isLoggable(Level.FINEST)) logger.finest("Mapped int " + id + " to word \"" + word + "\"");

			if (isNonterminal(id)) {
				nonterminalToInt.put(word, id);
			} else {
				terminalToInt.put(word, id);
			}
			if (logger.isLoggable(Level.FINEST)) logger.finest("Mapped word \"" + word + "\" to int " + id);
		}
	}

	/**
	 * Writes every entry of the ID-to-word map: first the total
	 * number of bytes (including the four bytes of that count),
	 * then for each entry its id, the byte length of the encoded
	 * word, and the encoded bytes.
	 *
	 * @param out the output to write to
	 * @throws IOException if writing fails
	 */
	public void writeExternal(ObjectOutput out) throws IOException {

		String characterEncoding = getExternalizableEncoding();

		// First, calculate the number of bytes required to store the vocabulary data
		int totalBytes = 4; // 4 bytes for the total itself
		for (String word : intToString.values()) {
			// 8 bytes per record for the id and the length ints
			totalBytes += 8 + word.getBytes(characterEncoding).length;
		}

		// Now, write the total number of bytes used to store the vocabulary data
		if (logger.isLoggable(Level.FINEST)) logger.finest("Writing total bytes: " + totalBytes);
		out.writeInt(totalBytes);

		// Next, write the actual vocabulary data
		for (Map.Entry<Integer, String> entry : intToString.entrySet()) {

			int id = entry.getKey();
			String word = entry.getValue();
			byte[] wordBytes = word.getBytes(characterEncoding);

			// Write the integer id of the word
			if (logger.isLoggable(Level.FINEST)) logger.finest("Writing ID: " + id);
			out.writeInt(id);

			// Write the number of bytes to store the word
			if (logger.isLoggable(Level.FINEST)) logger.finest("Writing string length: " + wordBytes.length);
			out.writeInt(wordBytes.length);

			// Write the byte data for the string
			if (logger.isLoggable(Level.FINEST)) logger.finest("Writing string bytes: " + Arrays.toString(wordBytes) + " for \"" + word + "\"");
			out.write(wordBytes);
		}
	}
}