Package ivory.cascade.retrieval

Source Code of ivory.cascade.retrieval.CascadeBatchQueryRunner

/*
* Ivory: A Hadoop toolkit for web-scale information retrieval
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package ivory.cascade.retrieval;

import ivory.core.RetrievalEnvironment;
import ivory.core.eval.GradedQrels;
import ivory.core.eval.RankedListEvaluator;
import ivory.core.exception.ConfigurationException;
import ivory.core.util.ResultWriter;
import ivory.core.util.XMLTools;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.expander.MRFExpander;
import ivory.smrf.retrieval.Accumulator;
import ivory.smrf.retrieval.BatchQueryRunner;
import ivory.smrf.retrieval.QueryRunner;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;

import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;


import com.google.common.collect.Maps;

import edu.umd.cloud9.collection.DocnoMapping;

/**
* @author Lidan Wang
*/

public class CascadeBatchQueryRunner extends BatchQueryRunner{
  // Class-wide log4j logger.
  private static final Logger LOG = Logger.getLogger(CascadeBatchQueryRunner.class);

  // Model id -> cascade costs for all queries under that model. The array is indexed by the
  // integer value of the query id (see getCascadeCost). Rebuilt at the start of runQueries().
  private HashMap<String, float[]> cascadeCosts = Maps.newHashMap();
  // Model id -> last-stage cascade costs for all queries; values are read back as float[]
  // in getCascadeCost_lastStage. Rebuilt at the start of runQueries().
  private HashMap cascadeCosts_lastStage = new HashMap();

  // Per-model "internalOutputFile" / "internalInputFile" attribute values (null when the
  // attribute is missing or blank). Filled by parseModels() by position in the document's
  // <model> list and read by position in runQueries().
  private String [] internalOutputFiles;
  private String [] internalInputFiles;

  // Parallel lists filled by printResults() for queries that have qrels: ndcgValues holds the
  // NDCG as a String, costKeys holds the matching "modelID queryID" key. Entries are appended
  // and never cleared in this class.
  public LinkedList ndcgValues = new LinkedList();
  public LinkedList costKeys = new LinkedList();
  // Collection name ("wt10g", "gov2" or "clue") derived from the index path in
  // parseIndexLocation(); stays null until that method recognises the path.
  private String dataCollection = null;
  // Value of the current model's "K" attribute (0 when absent), set in runQueries(). When
  // non-zero it replaces the "hits" value and is passed on to the cascade query runner.
  private int K_val;
  // Value of the current model's "topK" attribute, set in runQueries() (which exits if it is 0).
  // Passed as the first argument to RankedListEvaluator.computeNDCG in printResults().
  private int kVal;

  public CascadeBatchQueryRunner(String[] args, FileSystem fs) throws ConfigurationException {
    super (args, fs);
    parseParameters(args);

  }

  public HashMap readInternalInputFile(String internalInputFile){
    HashMap<String, LinkedList<float[]>> savedResults= new HashMap<String, LinkedList<float[]>>();
   
    if (internalInputFile!=null){
      BufferedReader in;
      try{
        in = new BufferedReader(new InputStreamReader(fs.open(new Path(internalInputFile))));
        String line;
         //Docnos and scores for a given query
        LinkedList<float[]> results = new LinkedList<float[]>();
        String qid = "";
        float [] docno_score = new float[2];

        while ((line = in.readLine()) != null && line.trim().length()> 0) {
          String [] tokens = line.split("\\s+");
         
          //# qid internal_docno score
          if (!(qid.equals(tokens[0]))){
            if (!(qid.equals(""))){
              savedResults.put(qid, results);
            }
            qid = tokens[0];
            results = new LinkedList<float[]>();
          }
          docno_score = new float[2];
          //docno_score[0] = Integer.parseInt(qid);
          docno_score[0] = (float)(Double.parseDouble(tokens[1]));
          docno_score[1] = (float)(Double.parseDouble(tokens[2]));
          results.add(docno_score);
        }

         //put last group of query and results in
        savedResults.put(qid, results);
      }
      catch (Exception e){
                                System.out.println("Problem reading "+internalInputFile);
                                System.exit(-1);
                        }

      if (savedResults.size()<=1){
        System.out.println("Should have results for more queries.");
        System.exit(-1);
      }
    }

    HashMap savedResults_return = new HashMap();
   
    Set set = savedResults.entrySet();
    Iterator itr = set.iterator();

    while (itr.hasNext()){
      Map.Entry me = (Map.Entry) itr.next();

      String key = (String)(me.getKey());
      LinkedList val = (LinkedList) (me.getValue());

      float [][] results = new float[val.size()][2];
      float [] r;
      for (int i=0; i<val.size(); i++){
        r = (float[]) (val.get(i));
        results[i][0] = r[0];
        results[i][1] = r[1];
      }

      savedResults_return.put(key, results);
    }   
    return savedResults_return;

  }

  public void runQueries() {

    //for each model, store cascade costs for all queries
    cascadeCosts = new HashMap();
    cascadeCosts_lastStage = new HashMap();

    int modelCnt = 0;

    for (String modelID : models.keySet()) {

      String internalInputFile = internalInputFiles[modelCnt];

      //Initialize mDocSet for each query if there is internalInputFile
      HashMap savedResults_prevStage = readInternalInputFile(internalInputFile);

      Node modelNode = models.get(modelID);
      Node expanderNode = expanders.get(modelID);

         //K value for cascade
                        K_val = XMLTools.getAttributeValue(modelNode, "K", 0);
      kVal = XMLTools.getAttributeValue(modelNode, "topK", 0);

      if (kVal == 0){
        System.out.println("Should not be 0!");
        System.exit(-1);
      }
      RetrievalEnvironment.topK = kVal;

      // Initialize retrieval environment variables.
      CascadeQueryRunner runner = null;
      MRFBuilder builder = null;
      MRFExpander expander = null;

      try {
        // Get the MRF builder.
        builder = MRFBuilder.get(env, modelNode.cloneNode(true));

        // Get the MRF expander.
        expander = null;
        if (expanderNode != null) {
          expander = MRFExpander.getExpander(env, expanderNode.cloneNode(true));
        }
        if (stopwords != null && stopwords.size() != 0) {
          expander.setStopwordList(stopwords);
        }

        int numHits = XMLTools.getAttributeValue(modelNode, "hits", 1000);
        if (K_val!=0){
          numHits = K_val;
        }       
        LOG.info("number of hits: " + numHits);

        // Multi-threaded query evaluation still a bit unstable; setting
        // thread=1 for now.
        runner = new CascadeThreadedQueryRunner(builder, expander, 1, numHits, savedResults_prevStage, K_val);

        queryRunners.put(modelID, (QueryRunner)runner);

      } catch (Exception e) {
        e.printStackTrace();
      }

      for (String queryID : queries.keySet()) {
        String rawQueryText = queries.get(queryID);
        String[] queryTokens = env.tokenize(rawQueryText);

        LOG.info(String.format("query id: %s, query: \"%s\"", queryID, rawQueryText));

        // Execute the query.
        runner.runQuery(queryID, queryTokens);
      }

      // Where should we output these results?
      Node model = models.get(modelID);
      String fileName = XMLTools.getAttributeValue(model, "output", null);
      boolean compress = XMLTools.getAttributeValue(model, "compress", false);


      String internalOutputFile = internalOutputFiles[modelCnt];

      try {
        ResultWriter resultWriter = new ResultWriter(fileName, compress, fs);

        //print out representation that uses internal docno for next cascade stage if doing boosting training
        if (internalOutputFile!=null){

          ResultWriter resultWriter2 = new ResultWriter(internalOutputFile, compress, fs);
          printResults(modelID, runner, resultWriter2, true);
          resultWriter2.flush();
        }

        printResults(modelID, runner, resultWriter, false);
        resultWriter.flush();
      } catch (IOException e) {
        throw new RuntimeException("Error: Unable to write results!");
      }


      cascadeCosts.put(modelID, runner.getCascadeCostAllQueries());

      cascadeCosts_lastStage.put(modelID, runner.getCascadeCostAllQueries_lastStage());

      modelCnt++;
    }

    //Compute evaluation metric
    float totalNDCG = 0, totalCost = 0;
    for (int i=0; i<costKeys.size(); i++){
      String [] tokens = ((String) costKeys.get(i)).split("\\s+");
      float cost = getCascadeCost(tokens[0], tokens[1]);
      float ndcg = Float.parseFloat((String) (ndcgValues.get(i)));

      totalNDCG+=ndcg;
      totalCost+=cost;
    }

    if (costKeys.size()!=ndcgValues.size()){
      System.out.println("They should be equal "+costKeys.size()+" "+ndcgValues.size());
      System.exit(-1);
    }
    System.out.println("Evaluation results... NDCG Sum "+totalNDCG+" TotalCost "+totalCost+" # queries with results "+costKeys.size()+" dataCollection "+dataCollection+" kVal "+kVal);
  }

  //The cascade cost of the qid under model
  public float getCascadeCost(String model, String qid){
    float [] allQueryCosts = (float[]) (cascadeCosts.get(model));
    return allQueryCosts[Integer.parseInt(qid)];
    //return Float.parseFloat((String)(cascadeCosts.get(model+" "+qid)));
  }

  public float getCascadeCost_lastStage(String model, String qid){
    float [] allQueryCosts_lastStage = (float[]) (cascadeCosts_lastStage.get(model));
    return allQueryCosts_lastStage[Integer.parseInt(qid)];
  }

  private void printResults(String modelID, CascadeQueryRunner runner, ResultWriter resultWriter, boolean internalDocno)
      throws IOException {

    float ndcgSum = 0;
    String qrelsPath = null;

    //Set up qrelsPath.
    if (dataCollection.indexOf("wt10g")!=-1){
      if (fs.exists(new Path("/user/lidan/qrels/qrels.wt10g"))){
        qrelsPath = "/user/lidan/qrels/qrels.wt10g";
      }
      else if (fs.exists(new Path("/umd-lin/lidan/qrels/qrels.wt10g"))){
        qrelsPath = "/umd-lin/lidan/qrels/qrels.wt10g";
      }
      else if (fs.exists(new Path("/fs/clip-trec/trunk_new/docs/data/wt10g/qrels.wt10g"))){
        qrelsPath = "/fs/clip-trec/trunk_new/docs/data/wt10g/qrels.wt10g";
      }      else if (fs.exists(new Path("data/wt10g/qrels.wt10g.all"))){
        qrelsPath = "data/wt10g/qrels.wt10g.all";
      }
    }
    else if (dataCollection.indexOf("gov2")!=-1){
      if (fs.exists(new Path("/user/lidan/qrels/qrels.gov2.all"))){
        qrelsPath = "/user/lidan/qrels/qrels.gov2.all";
      }     
      else if (fs.exists(new Path("/umd-lin/lidan/qrels/qrels.gov2.all"))){
        qrelsPath = "/umd-lin/lidan/qrels/qrels.gov2.all";
      }
      else if (fs.exists(new Path("/fs/clip-trec/trunk_new/docs/data/gov2/qrels.gov2.all"))){
        qrelsPath = "/fs/clip-trec/trunk_new/docs/data/gov2/qrels.gov2.all";
      } else if (fs.exists(new Path("data/gov2/qrels.gov2.all"))){
        qrelsPath = "data/gov2/qrels.gov2.all";
      }
    }
    else if (dataCollection.indexOf("clue")!=-1){
      if (fs.exists(new Path("/user/lidan/qrels/qrels.web09catB.txt"))){
        qrelsPath = "/user/lidan/qrels/qrels.web09catB.txt";
      }
      else if (fs.exists(new Path("/umd-lin/lidan/qrels/qrels.web09catB.txt"))){
        qrelsPath = "/umd-lin/lidan/qrels/qrels.web09catB.txt";
      }
      else if (fs.exists(new Path("/fs/clip-trec/trunk_new/docs/data/clue/qrels.web09catB.txt"))){
        qrelsPath = "/fs/clip-trec/trunk_new/docs/data/clue/qrels.web09catB.txt";
      } else if (fs.exists(new Path("data/clue/qrels.web09catB.txt"))){
        qrelsPath = "data/clue/qrels.web09catB.txt";
      }

    }

    if (qrelsPath == null){
      System.out.println("Should have set qrelsPath!");
      System.exit(-1);
    }

    GradedQrels qrels = new GradedQrels(qrelsPath);
                DocnoMapping mapping = getDocnoMapping();

    if (K_val==0){
      //System.out.println("K value should be set already.");
      //System.exit(-1);
    }
    for (String queryID : queries.keySet()) {
      // Get the ranked list for this query.
      Accumulator[] list = runner.getResults(queryID);
      if (list == null) {
        LOG.info("null results for: " + queryID);
        continue;
      }

      float ndcg = (float) RankedListEvaluator.computeNDCG(kVal, list, mapping, qrels.getReldocsForQid(queryID, true));
      ndcgSum += ndcg;

      if (!internalDocno){
        if (qrels.getReldocsForQid(queryID, true).size()>0){ //if have qrels for this query
          //System.out.println("Lidan: NDCG for query "+queryID+" is "+ndcg);
          ndcgValues.add(ndcg+"");
          costKeys.add(modelID+" "+queryID); //save keys for cost. Evaluation metric is computed by the end of runQueries()
        }
      }
      //Lidan: print out -- qid internal_docno score, for next cascade stage
      if (internalDocno){
        for (int i = 0; i < list.length; i++) {
          //System.out.println("Lidan: print internal results "+queryID + " "+list[i].docno + " " + list[i].score);
          resultWriter.println(queryID + " "+list[i].docno + " " + list[i].score);
        }
      }

      else if (docnoMapping == null) {
        // Print results with internal docnos if unable to translate to
        // external docids.
        for (int i = 0; i < list.length; i++) {
          resultWriter.println(queryID + " Q0 " + list[i].docno + " " + (i + 1) + " "
              + list[i].score + " " + modelID);
        }
      } else {
        // Translate internal docnos to external docids.
        for (int i = 0; i < list.length; i++) {
          resultWriter.println(queryID + " Q0 " + docnoMapping.getDocid(list[i].docno)
              + " " + (i + 1) + " " + list[i].score + " " + modelID);

        }
      }
    }

  }


        private void parseParameters(String[] args) throws ConfigurationException {
                for (int i = 0; i < args.length; i++) {
                        String element = args[i];
                        Document d = null;
               
                        try {
     
                          d = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(
                                                        fs.open(new Path(element)));
                        } catch (SAXException e) {
                                throw new ConfigurationException(e.getMessage());
                        } catch (IOException e) {
                                throw new ConfigurationException(e.getMessage());
                        } catch (ParserConfigurationException e) {
                                throw new ConfigurationException(e.getMessage());
                        }
               
                        parseModels(d);                       
                        parseIndexLocation(d);
                }
                               
                // Make sure we have some queries to run.
                if (queries.isEmpty()) {
                        throw new ConfigurationException("Must specify at least one query!");
                }
                       
                // Make sure there are models that need evaluated.
                if (models.isEmpty()) {
                        throw new ConfigurationException("Must specify at least one model!");
                }
                                               
                // Make sure we have an index to run against.
                if (indexPath == null) {
                        throw new ConfigurationException("Must specify an index!");
                }
        }      
 
  private void parseModels(Document d) throws ConfigurationException {

    NodeList models = d.getElementsByTagName("model");

    if (models.getLength() > 0){
      internalInputFiles = new String[models.getLength()];
      internalOutputFiles = new String[models.getLength()];
    }
    for (int i = 0; i < models.getLength(); i++) {
      // Get model XML node.
      Node node = models.item(i);


      // Get model id
      String modelID = XMLTools.getAttributeValue(node, "id", null);

      //Lidan: need to save to String[] internalInputFiles, indexed by modelID, in case there are multiple different internalInputFile, one for each model.
     
      String internalInputFile = XMLTools.getAttributeValue(node, "internalInputFile", null);

      if (internalInputFile!=null && internalInputFile.trim().length() == 0){
        internalInputFile = null;
      }

      internalInputFiles[i] = internalInputFile;

      String internalOutputFile = XMLTools.getAttributeValue(node, "internalOutputFile", null);

      if (internalOutputFile!=null && internalOutputFile.trim().length() == 0){
        internalOutputFile = null;
      }
      internalOutputFiles[i] = internalOutputFile;


      if (modelID == null) {
        throw new ConfigurationException("Must specify a model id for every model!");
      }

      // Parse parent nodes.
      NodeList children = node.getChildNodes();
      for (int j = 0; j < children.getLength(); j++) {
        Node child = children.item(j);
        if ("expander".equals(child.getNodeName())) {
          if (expanders.containsKey(modelID)) {
            throw new ConfigurationException("Only one expander allowed per model!");
          }
          expanders.put(modelID, child);
        }
      }

      // Add model to internal map.
      /*
      if (mModels.get(modelID) != null) {
        throw new ConfigurationException(
            "Duplicate model ids not allowed! Already parsed model with id=" + modelID);
      }
      mModels.put(modelID, node);
      */
    }
  }


  private void parseIndexLocation(Document d) throws ConfigurationException {
    NodeList index = d.getElementsByTagName("index");

    if (index.getLength() > 0) {
      /*
      if (mIndexPath != null) {
        throw new ConfigurationException(
            "Must specify only one index! There is no support for multiple indexes!");
      }
      mIndexPath = index.item(0).getTextContent();
      */

      if (indexPath!=null){
        //System.out.println("The name of the index is "+mIndexPath);
        if (indexPath.toLowerCase().indexOf("wt10g")!=-1){
          dataCollection = "wt10g";
          RetrievalEnvironment.dataCollection = "wt10g";
        }
        else if (indexPath.toLowerCase().indexOf("gov2")!=-1){
          dataCollection = "gov2";
          RetrievalEnvironment.dataCollection = "gov2";
        }
        else if (indexPath.toLowerCase().indexOf("clue")!=-1){
          dataCollection = "clue";
          RetrievalEnvironment.dataCollection = "clue";
        }
        else{
          System.out.println("Invalid data collection "+indexPath);
          System.exit(-1);
        }
      }
    }
  }

}
TOP

Related Classes of ivory.cascade.retrieval.CascadeBatchQueryRunner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.