Package com.tamingtext.qa

Source Code of com.tamingtext.qa.PassageRankingComponent

/*
* Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
*
*    Licensed under the Apache License, Version 2.0 (the "License");
*    you may not use this file except in compliance with the License.
*    You may obtain a copy of the License at
*
*        http://www.apache.org/licenses/LICENSE-2.0
*
*    Unless required by applicable law or agreed to in writing, software
*    distributed under the License is distributed on an "AS IS" BASIS,
*    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*    See the License for the specific language governing permissions and
*    limitations under the License.
* -------------------
* To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
* http://www.manning.com/ingersoll
*/

package com.tamingtext.qa;


import com.tamingtext.texttamer.solr.NameFilter;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.TermVectorMapper;
import org.apache.lucene.index.TermVectorOffsetInfo;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.search.spans.Spans;
import org.apache.lucene.util.PriorityQueue;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.PluginInfo;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.SearchComponent;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.DocList;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.plugin.PluginInfoInitialized;
import org.apache.solr.util.plugin.SolrCoreAware;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
* Given a SpanQuery, get windows around the matches and rank those results
*/
public class PassageRankingComponent extends SearchComponent implements PluginInfoInitialized, SolrCoreAware, QAParams {
  private transient static Logger log = LoggerFactory.getLogger(PassageRankingComponent.class);

  static final String NE_PREFIX_LOWER = NameFilter.NE_PREFIX.toLowerCase();

  public static final int DEFAULT_PRIMARY_WINDOW_SIZE = 25;
  public static final int DEFAULT_ADJACENT_WINDOW_SIZE = 25;
  public static final int DEFAULT_SECONDARY_WINDOW_SIZE = 25;

  public static final float DEFAULT_ADJACENT_WEIGHT = 0.5f;
  public static final float DEFAULT_SECOND_ADJACENT_WEIGHT = 0.25f;
  public static final float DEFAULT_BIGRAM_WEIGHT = 1.0f;

  @Override
  public void init(PluginInfo pluginInfo) {

  }

  @Override
  public void inform(SolrCore solrCore) {

  }


  @Override
  public void prepare(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    if (!params.getBool(COMPONENT_NAME, false)) {
      return;
    }


  }

  @Override
  public void process(ResponseBuilder rb) throws IOException {
    SolrParams params = rb.req.getParams();
    if (!params.getBool(COMPONENT_NAME, false)) {
      return;
    }
    Query origQuery = rb.getQuery();
    //TODO: longer term, we don't have to be a span query, we could re-analyze the document
    if (origQuery != null) {
      if (origQuery instanceof SpanNearQuery == false) {
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Illegal query type.  The incoming query must be a Lucene SpanNearQuery and it was a " + origQuery.getClass().getName());
      }
      SpanNearQuery sQuery = (SpanNearQuery) origQuery;
      SolrIndexSearcher searcher = rb.req.getSearcher();
      IndexReader reader = searcher.getIndexReader();
      Spans spans = sQuery.getSpans(reader);
      //Assumes the query is a SpanQuery
      //Build up the query term weight map and the bi-gram
      Map<String, Float> termWeights = new HashMap<String, Float>();
      Map<String, Float> bigramWeights = new HashMap<String, Float>();
      createWeights(params.get(CommonParams.Q), sQuery, termWeights, bigramWeights, reader);
      float adjWeight = params.getFloat(ADJACENT_WEIGHT, DEFAULT_ADJACENT_WEIGHT);
      float secondAdjWeight = params.getFloat(SECOND_ADJ_WEIGHT, DEFAULT_SECOND_ADJACENT_WEIGHT);
      float bigramWeight = params.getFloat(BIGRAM_WEIGHT, DEFAULT_BIGRAM_WEIGHT);
      //get the passages
      int primaryWindowSize = params.getInt(QAParams.PRIMARY_WINDOW_SIZE, DEFAULT_PRIMARY_WINDOW_SIZE);
      int adjacentWindowSize = params.getInt(QAParams.ADJACENT_WINDOW_SIZE, DEFAULT_ADJACENT_WINDOW_SIZE);
      int secondaryWindowSize = params.getInt(QAParams.SECONDARY_WINDOW_SIZE, DEFAULT_SECONDARY_WINDOW_SIZE);
      WindowBuildingTVM tvm = new WindowBuildingTVM(primaryWindowSize, adjacentWindowSize, secondaryWindowSize);
      PassagePriorityQueue rankedPassages = new PassagePriorityQueue();
      //intersect w/ doclist
      DocList docList = rb.getResults().docList;
      while (spans.next() == true) {
        //build up the window
        if (docList.exists(spans.doc())) {
          tvm.spanStart = spans.start();
          tvm.spanEnd = spans.end();
          reader.getTermFreqVector(spans.doc(), sQuery.getField(), tvm);
          //The entries map contains the window, do some ranking of it
          if (tvm.passage.terms.isEmpty() == false) {
            log.debug("Candidate: Doc: {} Start: {} End: {} ",
                    new Object[]{spans.doc(), spans.start(), spans.end()});
          }
          tvm.passage.lDocId = spans.doc();
          tvm.passage.field = sQuery.getField();
          //score this window
          try {
            addPassage(tvm.passage, rankedPassages, termWeights, bigramWeights, adjWeight, secondAdjWeight, bigramWeight);
          } catch (CloneNotSupportedException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Internal error cloning Passage", e);
          }
          //clear out the entries for the next round
          tvm.passage.clear();
        }
      }
      NamedList qaResp = new NamedList();
      rb.rsp.add("qaResponse", qaResp);
      int rows = params.getInt(QA_ROWS, 5);

      SchemaField uniqField = rb.req.getSchema().getUniqueKeyField();
      if (rankedPassages.size() > 0) {
        int size = Math.min(rows, rankedPassages.size());
        Set<String> fields = new HashSet<String>();
        for (int i = size - 1; i >= 0; i--) {
          Passage passage = rankedPassages.pop();
          if (passage != null) {
            NamedList passNL = new NamedList();
            qaResp.add(("answer"), passNL);
            String idName;
            String idValue;
            if (uniqField != null) {
              idName = uniqField.getName();
              fields.add(idName);
              fields.add(passage.field);//prefetch this now, so that it is cached
              idValue = searcher.doc(passage.lDocId, fields).get(idName);
            } else {
              idName = "luceneDocId";
              idValue = String.valueOf(passage.lDocId);
            }
            passNL.add(idName, idValue);
            passNL.add("field", passage.field);
            //get the window
            String fldValue = searcher.doc(passage.lDocId, fields).get(passage.field);
            if (fldValue != null) {
              //get the window of words to display, we don't use the passage window, as that is based on the term vector
              int start = passage.terms.first().start;//use the offsets
              int end = passage.terms.last().end;
              if (start >= 0 && start < fldValue.length() &&
                      end >= 0 && end < fldValue.length()) {
                passNL.add("window", fldValue.substring(start, end + passage.terms.last().term.length()));
              } else {
                log.debug("Passage does not have correct offset information");
                passNL.add("window", fldValue);//we don't have offsets, or they are incorrect, return the whole field value
              }
            }
          } else {
            break;
          }
        }
      }
    }


  }


  protected float scoreTerms(SortedSet<WindowTerm> terms, Map<String, Float> termWeights, Set<String> covered) {
    float score = 0f;
    for (WindowTerm wTerm : terms) {
      Float tw = (Float) termWeights.get(wTerm.term);
      if (tw != null && !covered.contains(wTerm.term)) {
        score += tw.floatValue();
        covered.add(wTerm.term);
      }
    }

    return (score);
  }

  protected float scoreBigrams(SortedSet<WindowTerm> bigrams, Map<String, Float> bigramWeights, Set<String> covered) {
    float result = 0;
    for (WindowTerm bigram : bigrams) {
      Float tw = (Float) bigramWeights.get(bigram.term);
      if (tw != null && !covered.contains(bigram.term)) {
        result += tw.floatValue();
        covered.add(bigram.term);
      }
    }
    return result;
  }

  /**
   * A fairly straightforward and simple scoring approach based on http://trec.nist.gov/pubs/trec8/papers/att-trec8.pdf.
   * <br/>
   * Score the {@link com.tamingtext.qa.PassageRankingComponent.Passage} as the sum of:
   * <ul>
   * <li>The sum of the IDF values for the primary window terms ({@link com.tamingtext.qa.PassageRankingComponent.Passage#terms}</li>
   * <li>The sum of the weights of the terms of the adjacent window ({@link com.tamingtext.qa.PassageRankingComponent.Passage#prevTerms} and {@link com.tamingtext.qa.PassageRankingComponent.Passage#followTerms}) * adjWeight</li>
   * <li>The sum of the weights terms of the second adjacent window ({@link com.tamingtext.qa.PassageRankingComponent.Passage#secPrevTerms} and {@link com.tamingtext.qa.PassageRankingComponent.Passage#secFollowTerms}) * secondAdjWeight</li>
   * <li>The sum of the weights of any bigram matches for the primary window * biWeight</li>
   * </ul>
   * In laymen's terms, this is a decay function that gives higher scores to matching terms that are closer to the anchor
   * term  (where the query matched, in the middle of the window) than those that are further away.
   *
   * @param p               The {@link com.tamingtext.qa.PassageRankingComponent.Passage} to score
   * @param termWeights     The weights of the terms, key is the term, value is the inverse doc frequency (or other weight)
   * @param bigramWeights   The weights of the bigrams, key is the bigram, value is the weight
   * @param adjWeight       The weight to be applied to the adjacent window score
   * @param secondAdjWeight The weight to be applied to the secondary adjacent window score
   * @param biWeight        The weight to be applied to the bigram window score
   * @return The score of passage
   */
  //<start id="qa.scorePassage"/>
  protected float scorePassage(Passage p, Map<String, Float> termWeights,
                               Map<String, Float> bigramWeights,
                               float adjWeight, float secondAdjWeight,
                               float biWeight) {
    Set<String> covered = new HashSet<String>();
    float termScore = scoreTerms(p.terms, termWeights, covered);//<co id="prc.main"/>
    float adjScore = scoreTerms(p.prevTerms, termWeights, covered) +
            scoreTerms(p.followTerms, termWeights, covered);//<co id="prc.adj"/>
    float secondScore = scoreTerms(p.secPrevTerms, termWeights, covered)
            + scoreTerms(p.secFollowTerms, termWeights, covered);//<co id="prc.sec"/>
    //Give a bonus for bigram matches in the main window, could also
    float bigramScore = scoreBigrams(p.bigrams, bigramWeights, covered);//<co id="prc.bigrams"/>
    float score = termScore + (adjWeight * adjScore) +
            (secondAdjWeight * secondScore)
            + (biWeight * bigramScore);//<co id="prc.score"/>
    return (score);
  }
  /*
  <calloutlist>
      <callout arearefs="prc.main"><para>Score the terms in the main window</para></callout>
      <callout arearefs="prc.adj"><para>Score the terms in the window immediately to the left and right of the main window</para></callout>
      <callout arearefs="prc.sec"><para>Score the terms in the windows adjacent to the previous and following windows</para></callout>
      <callout arearefs="prc.bigrams"><para>Score any bigrams in the passage</para></callout>
      <callout arearefs="prc.score"><para>The final score for the passage is a combination of all the scores, each weighted separately.  A bonus is given for any bigram matches.</para></callout>
     
  </calloutlist>
  */
  //<end id="qa.scorePassage"/>


  /**
   * Potentially add the passage to the PriorityQueue.
   *
   * @param p               The passage to add
   * @param pq              The {@link org.apache.lucene.util.PriorityQueue} to add the passage to if it ranks high enough
   * @param termWeights     The weights of the terms
   * @param bigramWeights   The weights of the bigrams
   * @param adjWeight       The weight to be applied to the score of the adjacent window
   * @param secondAdjWeight The weight to be applied to the score of the second adjacent window
   * @param biWeight        The weight to be applied to the score of the bigrams
   * @throws CloneNotSupportedException if not cloneable
   */
  private void addPassage(Passage p, PassagePriorityQueue pq, Map<String, Float> termWeights,
                          Map<String, Float> bigramWeights,
                          float adjWeight, float secondAdjWeight, float biWeight) throws CloneNotSupportedException {
    p.score = scorePassage(p, termWeights, bigramWeights, adjWeight, secondAdjWeight, biWeight);
    Passage lowest = pq.top();
    if (lowest == null || pq.lessThan(p, lowest) == false || pq.size() < pq.capacity()) {
      //by doing this, we can re-use the Passage object
      Passage cloned = (Passage) p.clone();
      //TODO: Do we care about the overflow?
      pq.insertWithOverflow(cloned);
    }

  }

  protected void createWeights(String origQuery, SpanNearQuery parsedQuery,
                               Map<String, Float> termWeights,
                               Map<String, Float> bigramWeights, IndexReader reader) throws IOException {

    SpanQuery[] clauses = parsedQuery.getClauses();
    //we need to recurse through the clauses until we get to SpanTermQuery
    Term lastTerm = null;
    Float lastWeight = null;
    for (int i = 0; i < clauses.length; i++) {
      SpanQuery clause = clauses[i];
      if (clause instanceof SpanTermQuery) {
        Term term = ((SpanTermQuery) clause).getTerm();
        Float weight = calculateWeight(term, reader);
        termWeights.put(term.text(), weight);
        if (lastTerm != null) {//calculate the bi-grams
          //use the smaller of the two weights
          if (lastWeight.floatValue() < weight.floatValue()) {
            bigramWeights.put(lastTerm + "," + term.text(), new Float(lastWeight.floatValue() * 0.25));
          } else {
            bigramWeights.put(lastTerm + "," + term.text(), new Float(weight.floatValue() * 0.25));
          }
        }
        //last
        lastTerm = term;
        lastWeight = weight;
      } else {
        //TODO: handle the other types
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unhandled query type: " + clause.getClass().getName());
      }
    }


  }

  protected float calculateWeight(Term term, IndexReader reader) throws IOException {
    //if a term is not in the index, then it's weight is 0
    TermEnum termEnum = reader.terms(term);
    if (termEnum != null && termEnum.term() != null && termEnum.term().equals(term)) {
      return 1.0f / termEnum.docFreq();
    } else {
      log.warn("Couldn't find doc freq for term {}", term);
      return 0;
    }

  }

  class Passage implements Cloneable {
    int lDocId;
    String field;

    float score;
    SortedSet<WindowTerm> terms = new TreeSet<WindowTerm>();
    SortedSet<WindowTerm> prevTerms = new TreeSet<WindowTerm>();
    SortedSet<WindowTerm> followTerms = new TreeSet<WindowTerm>();
    SortedSet<WindowTerm> secPrevTerms = new TreeSet<WindowTerm>();
    SortedSet<WindowTerm> secFollowTerms = new TreeSet<WindowTerm>();
    SortedSet<WindowTerm> bigrams = new TreeSet<WindowTerm>();

    Passage() {
    }

    @Override
    protected Object clone() throws CloneNotSupportedException {
      Passage result = (Passage) super.clone();
      result.terms = new TreeSet<WindowTerm>();
      for (WindowTerm term : terms) {
        result.terms.add((WindowTerm) term.clone());
      }
      result.prevTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : prevTerms) {
        result.prevTerms.add((WindowTerm) term.clone());
      }
      result.followTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : followTerms) {
        result.followTerms.add((WindowTerm) term.clone());
      }
      result.secPrevTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : secPrevTerms) {
        result.secPrevTerms.add((WindowTerm) term.clone());
      }
      result.secFollowTerms = new TreeSet<WindowTerm>();
      for (WindowTerm term : secFollowTerms) {
        result.secFollowTerms.add((WindowTerm) term.clone());
      }
      result.bigrams = new TreeSet<WindowTerm>();
      for (WindowTerm term : bigrams) {
        result.bigrams.add((WindowTerm) term.clone());
      }

      return result;
    }


    public void clear() {
      terms.clear();
      prevTerms.clear();
      followTerms.clear();
      secPrevTerms.clear();
      secPrevTerms.clear();
      bigrams.clear();
    }


  }

  class PassagePriorityQueue extends PriorityQueue<Passage> {

    PassagePriorityQueue() {
      initialize(10);
    }

    PassagePriorityQueue(int maxSize) {
      initialize(maxSize);
    }

    public int capacity() {
      return getHeapArray().length;
    }

    @Override
    public boolean lessThan(Passage passageA, Passage passageB) {
      if (passageA.score == passageB.score)
        return passageA.lDocId > passageB.lDocId;
      else
        return passageA.score < passageB.score;
    }
  }


  //Not thread-safe, but should be lightweight to build

  /**
   * The PassageRankingTVM is a Lucene TermVectorMapper that builds a five different windows around a matching term.
   * This Window can then be used to rank the passages
   */
  class WindowBuildingTVM extends TermVectorMapper {
    //spanStart and spanEnd are the start and positions of where the match occurred in the document
    //from these values, we can calculate the windows
    int spanStart, spanEnd;
    Passage passage;
    private int primaryWS, adjWS, secWS;


    public WindowBuildingTVM(int primaryWindowSize, int adjacentWindowSize, int secondaryWindowSize) {
      this.primaryWS = primaryWindowSize;
      this.adjWS = adjacentWindowSize;
      this.secWS = secondaryWindowSize;
      passage = new Passage();//reuse the passage, since it will be cloned if it makes it onto the priority queue
    }

    public void map(String term, int frequency, TermVectorOffsetInfo[] offsets, int[] positions) {
      if (positions.length > 0 && term.startsWith(NameFilter.NE_PREFIX) == false && term.startsWith(NE_PREFIX_LOWER) == false) {//filter out the types, as we don't need them here
        //construct the windows, which means we need a bunch of bracketing variables to know what window we are in

        //start and end of the primary window
        int primStart = spanStart - primaryWS;
        int primEnd = spanEnd + primaryWS;
        // stores the start and end of the adjacent previous and following
        int adjLBStart = primStart - adjWS;
        int adjLBEnd = primStart - 1;//don't overlap
        int adjUBStart = primEnd + 1;//don't o
        int adjUBEnd = primEnd + adjWS;
        //stores the start and end of the secondary previous and the secondary following
        int secLBStart = adjLBStart - secWS;
        int secLBEnd = adjLBStart - 1; //don't overlap the adjacent window
        int secUBStart = adjUBEnd + 1;
        int secUBEnd = adjUBEnd + secWS;
        WindowTerm lastWT = null;
        for (int i = 0; i < positions.length; i++) {//unfortunately, we still have to loop over the positions
          //we'll make this inclusive of the boundaries, do an upfront check here so we can skip over anything that is outside of all windows
          if (positions[i] >= secLBStart && positions[i] <= secUBEnd) {
            //fill in the windows
            WindowTerm wt;
            //offsets aren't required, but they are nice to have
            if (offsets != null){
              wt = new WindowTerm(term, positions[i], offsets[i].getStartOffset(), offsets[i].getEndOffset());
            } else {
              wt = new WindowTerm(term, positions[i]);
            }
            if (positions[i] >= primStart && positions[i] <= primEnd) {//are we in the primary window
              passage.terms.add(wt);
              //we are only going to keep bigrams for the primary window.  You could do it for the other windows, too
              if (lastWT != null) {
                WindowTerm bigramWT = new WindowTerm(lastWT.term + "," + term, lastWT.position);//we don't care about offsets for bigrams
                passage.bigrams.add(bigramWT);
              }
              lastWT = wt;
            } else if (positions[i] >= secLBStart && positions[i] <= secLBEnd) {//are we in the secondary previous window?
              passage.secPrevTerms.add(wt);
            } else if (positions[i] >= secUBStart && positions[i] <= secUBEnd) {//are we in the secondary following window?
              passage.secFollowTerms.add(wt);
            } else if (positions[i] >= adjLBStart && positions[i] <= adjLBEnd) {//are we in the adjacent previous window?
              passage.prevTerms.add(wt);
            } else if (positions[i] >= adjUBStart && positions[i] <= adjUBEnd) {//are we in the adjacent following window?
              passage.followTerms.add(wt);
            }
          }
        }
      }
    }



    public void setExpectations(String field, int numTerms, boolean storeOffsets, boolean storePositions) {
      // do nothing for this example
      //See also the PositionBasedTermVectorMapper.
    }

  }

  class WindowTerm implements Cloneable, Comparable<WindowTerm> {
    String term;
    int position;
    int start, end = -1;

    WindowTerm(String term, int position, int startOffset, int endOffset) {
      this.term = term;
      this.position = position;
      this.start = startOffset;
      this.end = endOffset;
    }

    public WindowTerm(String s, int position) {
      this.term = s;
      this.position = position;
    }

    @Override
    protected Object clone() throws CloneNotSupportedException {
      return super.clone();
    }

    @Override
    public int compareTo(WindowTerm other) {
      int result = position - other.position;
      if (result == 0) {
        result = term.compareTo(other.term);
      }
      return result;
    }

    @Override
    public String toString() {
      return "WindowEntry{" +
              "term='" + term + '\'' +
              ", position=" + position +
              '}';
    }
  }

  @Override
  public String getDescription() {
    return "Question Answering PassageRanking";
  }

  @Override
  public String getVersion() {
    return "$Revision:$";
  }

  @Override
  public String getSourceId() {
    return "$Id:$";
  }

  @Override
  public String getSource() {
    return "$URL:$";
  }
}
TOP

Related Classes of com.tamingtext.qa.PassageRankingComponent

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.