Package ivory.sqe.querygenerator

Source Code of ivory.sqe.querygenerator.MtNQueryGenerator

package ivory.sqe.querygenerator;

import ivory.core.tokenize.Tokenizer;
import ivory.core.tokenize.TokenizerFactory;
import ivory.sqe.retrieval.Constants;
import ivory.sqe.retrieval.PairOfFloatMap;
import ivory.sqe.retrieval.StructuredQuery;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.tartarus.snowball.SnowballStemmer;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;

import edu.umd.cloud9.io.map.HMapSFW;
import edu.umd.cloud9.util.map.HMapKI;
import edu.umd.cloud9.util.map.MapKF.Entry;

/**
 * Builds a structured query from cdec N-best (k-best) translation output.
 *
 * <p>For k &gt; 1 the result has the form
 * {@code {#combweight [w_1, phrase_1, ..., tokenWeight, {#combine [{#weight ...}, ...]}]}}.
 * Each inner {@code #weight} represents a single stemmed source token as a weighted
 * distribution over target-language tokens. That distribution mixes three sources:
 * <ul>
 *   <li>the N-best list (weight {@code mtWeight});</li>
 *   <li>bilingual-dictionary translations from {@link ProbabilisticStructuredQueryGenerator}
 *       (weight {@code bitextWeight});</li>
 *   <li>SCFG translations from {@link SCFGQueryGenerator} (weight {@code scfgWeight}).</li>
 * </ul>
 * Multi-word target phrases seen in the N-best list are prepended when
 * {@code phraseWeight > 0}.
 *
 * <p>For k = 1 only the 1-best translation is used. The result is
 * {@code {#weight [tokenWeight, {#combine bag-of-words}, phraseWeight, {#combine bag-of-phrases}]}}.
 *
 * <p>Per the original design notes, the intent is for the retrieval engine to compute a TF and
 * DF value for each source token from the alternative translations in the N-best list. That
 * computation happens outside this class.
 *
 * @author ferhanture
 *
 */
public class MtNQueryGenerator implements QueryGenerator {
  private static final Logger LOG = Logger.getLogger(MtNQueryGenerator.class);
  // Tokenizers created in init(). Document language: with stemming and a stemmed stopword list.
  // Query language: one with stemming (stemmed stopword list) and one without stemming.
  Tokenizer docLangTokenizer, queryLangTokenizerWithStemming, queryLangTokenizer;
  // NOTE(review): never assigned or read in this class.
  SnowballStemmer stemmer;
  // Passed to StructuredQuery by parseQuery().
  // NOTE(review): never assigned in this class, so it stays 0 unless another class in this
  // package sets it. Confirm whether it should be the query length.
  int length;
  // Expected number of translations per query (Constants.KBest, default 1).
  // A value of 1 selects the single-translation branch of parseQuery().
  private int kBest;
  // NOTE(review): never read in this class.
  boolean bigramSegment = false;
  // Delegate generators: bilingual-dictionary ("bitext") translations and SCFG-rule translations.
  private ProbabilisticStructuredQueryGenerator clGenerator;
  private SCFGQueryGenerator scfgGenerator;
  // Mixing weights and parameters read in init(); see init() for their defaults.
  private float mtWeight, bitextWeight,scfgWeight, tokenWeight, phraseWeight, alpha, lexProbThreshold;
  // Language identifiers for the query and the documents, passed to the tokenizer factory.
  private String queryLang, docLang;
  // If true, each token's distribution is scaled by how often the token was counted
  // across the k-best translations (see parseQuery()).
  private boolean scaling;

  /**
   * Creates an uninitialized generator. {@link #init(FileSystem, Configuration)} must be called
   * before {@link #parseQuery(String)}, since parsing uses the tokenizers and delegates it creates.
   */
  public MtNQueryGenerator() {
    super();
  }

  /**
   * Reads the configuration, then builds the tokenizers and the two delegate generators used by
   * {@link #parseQuery(String)}.
   *
   * <p>Defaults applied when a key is unset: tokenWeight=1, phraseWeight=0, alpha=1,
   * scaling=false, lexProbThreshold=0, kBest=1, mtWeight=1, bitextWeight=0, scfgWeight=0.
   *
   * @param fs file system passed to the tokenizer factory and to the delegate generators
   * @param conf configuration holding the {@link Constants} keys read here; it is also passed
   *     unchanged to both delegate generators
   * @throws IOException as declared; presumably propagated from tokenizer creation or from a
   *     delegate generator's {@code init} (TODO confirm which callee throws)
   */
  @Override
  public void init(FileSystem fs, Configuration conf) throws IOException {
    // Quiet mode turns this class's logger off entirely.
    if (conf.getBoolean(Constants.Quiet, false)) {
      LOG.setLevel(Level.OFF);
    }

    queryLang = conf.get(Constants.QueryLanguage);
    docLang = conf.get(Constants.DocLanguage);

    LOG.info("Stemmed stopword list file in query-language:" + conf.get(Constants.StemmedStopwordListQ));
    LOG.info("Stemmed stopword list file in doc-language:" + conf.get(Constants.StemmedStopwordListD));

    tokenWeight = conf.getFloat(Constants.TokenWeight, 1f);
    phraseWeight = conf.getFloat(Constants.PhraseWeight, 0f);
    alpha = conf.getFloat(Constants.Alpha, 1);
    scaling = conf.getBoolean(Constants.Scaling, false);
    lexProbThreshold = conf.getFloat(Constants.LexicalProbThreshold, 0f);

    String queryTokenizerPath = conf.get(Constants.QueryTokenizerData);
    String docTokenizerPath = conf.get(Constants.DocTokenizerData);
    kBest = conf.getInt(Constants.KBest, 1);
    LOG.info("K = " + kBest);

    mtWeight = conf.getFloat(Constants.MTWeight, 1f);
    bitextWeight = conf.getFloat(Constants.BitextWeight, 0f);
    scfgWeight = conf.getFloat(Constants.SCFGWeight, 0f);
    // NOTE(review): these log the raw configuration strings without labels. They are null for
    // keys that are not set, even though the defaults above are then in effect.
    LOG.info(conf.get(Constants.MTWeight));
    LOG.info(conf.get(Constants.BitextWeight));
    LOG.info(conf.get(Constants.SCFGWeight));

    // Judging by the field names, the boolean argument appears to be the stemming flag, and the
    // argument after it the stopword-list path (TODO confirm against TokenizerFactory).
    queryLangTokenizer = TokenizerFactory.createTokenizer(fs, conf, queryLang, queryTokenizerPath, false, null, null, null);
    queryLangTokenizerWithStemming = TokenizerFactory.createTokenizer(fs, conf, queryLang, queryTokenizerPath, true, null, conf.get(Constants.StemmedStopwordListQ), null);
    docLangTokenizer = TokenizerFactory.createTokenizer(fs, conf, docLang, docTokenizerPath, true, null, conf.get(Constants.StemmedStopwordListD), null);

    // Source of the bilingual-dictionary ("bitext") distribution; also used for full delegation.
    clGenerator = new ProbabilisticStructuredQueryGenerator();
    clGenerator.init(fs, conf);

    // Source of the SCFG-rule distribution.
    scfgGenerator = new SCFGQueryGenerator();
    scfgGenerator.init(fs, conf);
  }

  /**
   * Parses cdec k-best output into a structured query.
   *
   * <p>Input format, as implied by the parsing below:
   * <ul>
   *   <li>Translations are separated by {@code "||||"}. Each translation is
   *       {@code rules;score;sourceText}.</li>
   *   <li>Rules are separated by {@code "|||"}. Each rule is
   *       {@code ruleProb::tokenPairs::phrasePair}.</li>
   *   <li>{@code tokenPairs} are {@code source|target} pairs separated by {@code "||"}.</li>
   *   <li>{@code phrasePair} is {@code sourcePhrase|targetPhrase}.</li>
   *   <li>A {@code ruleProb} of -1 marks a pass-through rule.</li>
   *   <li>Only the {@code sourceText} of the first translation is read.</li>
   * </ul>
   *
   * <p>If mtWeight and scfgWeight are both 0 and bitextWeight is 1, the source text (prefixed
   * with {@code ";"}) is handed to the bilingual-dictionary generator and its result is returned.
   *
   * @param query cdec output in the format described above
   * @return the structured query. Its length argument is the {@code length} field, which this
   *     class never assigns.
   * @throws ArrayIndexOutOfBoundsException if a translation, rule or pair has fewer fields than
   *     the indexing below expects
   * @throws NumberFormatException if a parsed score or rule probability is not a valid float
   */
  @Override
  public StructuredQuery parseQuery(String query){
    // queryJson is the returned query. queryTJson holds the token-based part (k > 1).
    // queryPJson holds a phrase-based part (k > 1).
    // NOTE(review): queryPJson is filled below but never added to queryJson.
    JsonObject queryJson = new JsonObject();
    JsonObject queryTJson = new JsonObject();
    JsonObject queryPJson = new JsonObject();

      // k == 1: bag-of-words and bag-of-phrases target strings.
      List<String> tokensBOW = new ArrayList<String>(), tokensBOP = new ArrayList<String>();
      // k > 1: source token -> (target token -> accumulated translation probability).
      Map<String, HMapSFW> src2trg2weight = new HashMap<String, HMapSFW>();
      // NOTE(review): only ever cleared below; never populated or read in this method.
      Map<String,String> target2source = new HashMap<String,String>();
      // k > 1: multi-word target phrase -> accumulated translation probability.
      HMapSFW phrase2weight = new HMapSFW();

      // Split into k-best entries. The source query text is field 2 of the first entry.
      String[] kbestTranslations = query.trim().split("\\|\\|\\|\\|");
      String origQuery = kbestTranslations[0].split(";")[2].trim();

      // if no weighting, delegate to appropriate generator class
      if (mtWeight == 0 && scfgWeight == 0 && bitextWeight == 1) {
        return clGenerator.parseQuery(";"+origQuery);
      }

      // Stemmed source tokens. In the k > 1 branch each non-stopword token gets one #weight.
      String[] stemmedSourceTokens = queryLangTokenizerWithStemming.processContent(origQuery);
      // Mapping built from all three tokenizers (semantics defined in Utils.getStemMapping).
      // It is used below to re-stem pass-through targets, and is passed to the bitext/SCFG lookups.
      Map<String,String> stemmed2Stemmed = Utils.getStemMapping(origQuery, queryLangTokenizer, queryLangTokenizerWithStemming, docLangTokenizer);

      // Both branches read the same "rules;score;text" entries.
      // If k is 1, only the first (1-best) entry is used and its score is ignored.
      // Otherwise, every "||||"-separated entry is parsed and weighted by its normalized score.
      if (kBest == 1){
        String[] line = kbestTranslations[0].trim().split(";");
        String[] rules = line[0].trim().split("\\|\\|\\|");

        for (String rule : rules) {
          rule = rule.trim();
          // ruleArr = [ruleProb, tokenPairs, phrasePair]
          String[] ruleArr = rule.split("::");
          String phrasePair = ruleArr[2];
          String trgPhrase = phrasePair.split("\\|")[1];
          float ruleProb = Float.parseFloat(ruleArr[0]);

          // include entire RHS in the alternative phrase-based representation (even if it's not a multiword expr)

          // heuristic: remove stop words from the RHS except for the ones between content words
          // (e.g., tremblement de terre, means earthquake)
          trgPhrase = docLangTokenizer.removeBorderStopWords(trgPhrase);
          tokensBOP.add(trgPhrase);

          // add target tokens
          String[] tokenPairs= ruleArr[1].split("\\|\\|");
          for (String tokenPair : tokenPairs) {
            // token pair is source|target
            String[] arr = tokenPair.split("\\|");
            String source = arr[0];
            String target = arr[1];
            // Optional stem normalization
            // -1 is a special marker for Pass-through rules, meaning that source is OOV
            // in this case, we want to restem the source using the doc lang tokenizer
            // e.g. Emmy --> emmi with English stemmer. If we translate emmi with pass-through rule, it becomes emmi in French
            // however, Emmy --> emmy with French stemmer, so this will cause us to miss all relevant documents
            if (ruleProb == -1) {        
              target = stemmed2Stemmed.get(target);
            }

            // Drop the target in any of these cases:
            // - it has no stem mapping (null);
            // - the source is a query-language stopword;
            // - the source is the literal "NULL" (presumably cdec's unaligned-word marker; confirm);
            // - the target is a doc-language stopword.
            if (target != null && !queryLangTokenizerWithStemming.isStopWord(source) && !source.equals("NULL") && !docLangTokenizer.isStopWord(target)) {
              tokensBOW.add(target);
            }
          }
        }

        // Result: {"#weight": [tokenWeight, {"#combine": BOW}, phraseWeight, {"#combine": BOP}]}.
        // The phrase part is included even when phraseWeight is 0.
        String[] bopArr = new String[tokensBOP.size()];
        bopArr = tokensBOP.toArray(bopArr);
        JsonArray bop = Utils.createJsonArray(bopArr);
        JsonObject bopJson = new JsonObject();
        bopJson.add("#combine", bop);

        String[] bowArr = new String[tokensBOW.size()];
        bowArr = tokensBOW.toArray(bowArr);
        JsonArray bow = Utils.createJsonArray(bowArr);
        JsonObject bowJson = new JsonObject();
        bowJson.add("#combine", bow);

        JsonArray weightedQuery = new JsonArray();
        weightedQuery.add(new JsonPrimitive(tokenWeight));
        weightedQuery.add(bowJson);
        weightedQuery.add(new JsonPrimitive(phraseWeight));
        weightedQuery.add(bopJson);
        queryJson.add("#weight", weightedQuery);
      }else {     // k > 1

        // Subtract the first entry's score from every score before exponentiating, so the values
        // stay in a representable range. The first entry maps to exp(0) = 1. This assumes cdec
        // lists the best translation first.
        float discount = 0;
        String[] line = kbestTranslations[0].trim().split(";");
        discount = -Float.parseFloat(line[1]);

        // transProbs[k] = exp(alpha * (score_k - score_0)). Dividing by sumOfProbs (below)
        // yields a softmax over the k-best scores, with alpha as a scaling factor.
        float[] transProbs = new float[kbestTranslations.length];
        float sumOfProbs = 0;
        for (int k = 0; k < kbestTranslations.length; k++){
          line = kbestTranslations[k].trim().split(";");
          transProbs[k] = (float) Math.pow(Math.E, alpha * (Float.parseFloat(line[1]) + discount));
          sumOfProbs += transProbs[k];
        }

        // parse each of the k top translations
        // cumPhraseProbs: total probability mass added to phrase2weight.
        // tokenCount: how many (filtered) alignments of each source token were seen, over all translations.
        float cumPhraseProbs = 0;
        HMapKI<String> tokenCount = new HMapKI<String>();
        for (int k = 0; k < kbestTranslations.length; k++){
          // init
          // NOTE(review): target2source is never populated, so this clear() has no effect.
          target2source.clear();

          // parse input from cdec
          line = kbestTranslations[k].trim().split(";");

          // normalize prob. of k-th translation
          float transProb = transProbs[k]/sumOfProbs;

          //          String text = line[2].replaceAll("\\s+", "");
          String[] rules = line[0].trim().split("\\|\\|\\|");

          for (String rule : rules) {
            // ruleArr = [ruleProb, tokenPairs, phrasePair]; phraseArr = [sourcePhrase, targetPhrase]
            String[] ruleArr = rule.split("::");
            String[] tokenPairs = ruleArr[1].split("\\|\\|");
            String[] phraseArr = ruleArr[2].split("\\|");
            //            String sourcePhrase = phraseArr[0];
            String targetPhrase = phraseArr[1].trim();
            float ruleProb = Float.parseFloat(ruleArr[0]);
            for (String tokenPair : tokenPairs) {
              // token pair is source|target
              String[] arr2 = tokenPair.split("\\|");
              String source = arr2[0];
              String target = arr2[1];

              // Optional stem normalization (pass-through rule; see the k == 1 branch)
              if (ruleProb == -1) {
                target = stemmed2Stemmed.get(target);
              }

              //              LOG.info("assign:{"+source+"}->["+target+"]="+transProb);
              // Same filter as the k == 1 branch: unmapped targets, stopwords and "NULL" sources.
              if (target == null || queryLangTokenizerWithStemming.isStopWord(source) || source.equals("NULL") || docLangTokenizer.isStopWord(target)) {
                continue;
              }

              // if a source token is aligned to multiple target tokens,
              // treat each alignment as a separate possible translation w/ same probability
              tokenCount.increment(source);
              if (src2trg2weight.containsKey(source)) {
                if(src2trg2weight.get(source).containsKey(target)) {
                  src2trg2weight.get(source).increment(target, transProb);
                }else {
                  src2trg2weight.get(source).put(target, transProb);
                }
              }else {
                HMapSFW trg2weight = new HMapSFW();
                trg2weight.put(target, transProb);
                src2trg2weight.put(source, trg2weight);
              }
            }
            // Only multi-word target phrases go into the phrase representation.
            if (targetPhrase.split(" ").length > 1) {
              phrase2weight.increment(targetPhrase, transProb);
              cumPhraseProbs += transProb;
            }
          }
        }

        // add phrase translations into a #weight array structure
        // NOTE(review): queryPJson is never attached to queryJson, so this result is unused.
        // Check whether Utils.scaleProbMap mutates phrase2weight before removing this call,
        // because phrase2weight is scaled again further below.
        if (phraseWeight > 0) {
          JsonArray pArr = Utils.createJsonArrayFromProbabilities(Utils.scaleProbMap(lexProbThreshold, 1/cumPhraseProbs, phrase2weight));
          queryPJson.add("#weight", pArr);
        }

        // add token translations into a #combine of #weight array structures
        JsonArray tokensArr = new JsonArray();
        if (tokenWeight > 0) {
          // One #weight per non-stopword source token, in query order (repeated tokens repeat).
          for (String srcToken : stemmedSourceTokens) {
            HMapSFW nbestDist = src2trg2weight.get(srcToken);

            // skip stop words among source query words
            if (queryLangTokenizerWithStemming.isStopWord(srcToken)){
              LOG.info("Skipped stopword "+srcToken);
              continue;
            }
            JsonObject tokenWeightedArr = new JsonObject();
            LOG.info("Processing "+srcToken);

            // Tokens never aligned in any N-best translation have no N-best distribution.
            // They may still be represented through the bitext/SCFG distributions below.
            if (nbestDist == null){
              LOG.info("Unaligned in MT: "+srcToken);
            }else {
              // normalize probabilities for this token
              // This updates the map stored in src2trg2weight in place. For a repeated source
              // token the map already sums to ~1, so re-normalizing leaves it essentially unchanged.
              float normalization = 0;
              for (Entry<String> e : nbestDist.entrySet()) {
                float weight = e.getValue();
                normalization += weight;
              }
              for (Entry<String> e : nbestDist.entrySet()) {
                nbestDist.put(e.getKey(), e.getValue()/normalization);
              }             
            }

            // combine translations from N-best AND bilingual dictionary
            List<PairOfFloatMap> tokenRepresentationList = new ArrayList<PairOfFloatMap>();

            // Pr{bitext}
            HMapSFW bitextDist = clGenerator.getTranslations(srcToken, stemmed2Stemmed);
            if (bitextDist != null && !bitextDist.isEmpty() && bitextWeight > 0) {
              tokenRepresentationList.add(new PairOfFloatMap(bitextDist, bitextWeight));
            }

            // Pr{scfg}
            HMapSFW scfgDist = scfgGenerator.getTranslations(srcToken, stemmed2Stemmed);
            if (scfgDist != null && !scfgDist.isEmpty() && scfgWeight > 0) {
              tokenRepresentationList.add(new PairOfFloatMap(scfgDist, scfgWeight));
            }

            // Pr{n-best}
            if (mtWeight > 0 && nbestDist != null && !nbestDist.isEmpty()) {
              tokenRepresentationList.add(new PairOfFloatMap(nbestDist, mtWeight));
            }

            // Optional scaling: the token's alignment count divided by the number of translations.
            // This can exceed 1 when a token has several alignments per translation.
            // NOTE(review): behaviour of HMapKI.get for a token that was never counted
            // (e.g. unaligned) is not visible here; TODO confirm it returns 0 rather than throwing.
            JsonArray combinedArr;
            float scale = 1;
            if (scaling) {
              scale = scale * tokenCount.get(srcToken)/((float) kbestTranslations.length);
            }
            // lexProbThreshold is presumably a pruning threshold applied inside Utils (confirm there).
            if(tokenRepresentationList.size() == 0) {
              continue;       // if empty distr., do not represent this source token in query
            } else if(tokenRepresentationList.size() == 1) {
              combinedArr = Utils.createJsonArrayFromProbabilities(Utils.scaleProbMap(lexProbThreshold, scale, tokenRepresentationList.get(0).getMap()));
            } else {
              combinedArr = Utils.createJsonArrayFromProbabilities(Utils.combineProbMaps(lexProbThreshold, scale, tokenRepresentationList));
            }

            tokenWeightedArr.add("#weight", combinedArr);

            // optional: if this source token has occurred more than once per query, reflect this in the representation
            //  for (int i = 0; i < Math.ceil(tokenCount.get(srcToken)/(float)kBest); i++) {
            //    tokensArr.put(tokenWeightedArr);
            //  }
            tokensArr.add(tokenWeightedArr);
          }
          queryTJson.add("#combine", tokensArr);
        }

        // combine the token-based and phrase-based representations into a #combweight structure
        // Layout: [w_phrase1, phrase1, w_phrase2, phrase2, ..., tokenWeight, queryTJson].
        // Phrase weights are scaled by phraseWeight / cumPhraseProbs.
        // If no multi-word phrase was seen, cumPhraseProbs is 0 and that factor is Infinity
        // (float division); phrase2weight is then normally empty as well.
        JsonArray queryJsonArr = new JsonArray();

        HMapSFW normalizedPhrase2Weight = null;
        if (phraseWeight > 0) {
          normalizedPhrase2Weight = Utils.scaleProbMap(lexProbThreshold, phraseWeight/cumPhraseProbs, phrase2weight);     
          for (String phrase : normalizedPhrase2Weight.keySet()) {
            queryJsonArr.add(new JsonPrimitive(normalizedPhrase2Weight.get(phrase)));
            queryJsonArr.add(new JsonPrimitive(phrase));
          }
        }
        if (tokenWeight > 0) {
          queryJsonArr.add(new JsonPrimitive(tokenWeight));
          queryJsonArr.add(queryTJson);
        }
        queryJson.add("#combweight", queryJsonArr);
      }

    // NOTE(review): `length` is never assigned in this class, so 0 is passed unless it is set
    // elsewhere in the package.
    return new StructuredQuery(queryJson, length);
  }
}
TOP

Related Classes of ivory.sqe.querygenerator.MtNQueryGenerator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.