Package net.bpiwowar.mg4j.extensions.adhoc

Source Code of net.bpiwowar.mg4j.extensions.adhoc.MG4JScorer$PseudoRF$TopRanked

package net.bpiwowar.mg4j.extensions.adhoc;

import bpiwowar.argparser.Argument;
import bpiwowar.argparser.handlers.ClassChooser;
import bpiwowar.argparser.handlers.ClassChooser.Choice;
import bpiwowar.argparser.handlers.XStreamHandler;
import it.unimi.di.big.mg4j.document.Document;
import it.unimi.di.big.mg4j.document.DocumentCollection;
import it.unimi.di.big.mg4j.index.Index;
import it.unimi.di.big.mg4j.index.NullTermProcessor;
import it.unimi.di.big.mg4j.index.TermProcessor;
import it.unimi.di.big.mg4j.query.Query;
import it.unimi.di.big.mg4j.query.QueryEngine;
import it.unimi.di.big.mg4j.query.SelectedInterval;
import it.unimi.di.big.mg4j.query.parser.SimpleParser;
import it.unimi.di.big.mg4j.search.DocumentIteratorBuilderVisitor;
import it.unimi.di.big.mg4j.search.score.DocumentScoreInfo;
import it.unimi.di.big.mg4j.search.score.Scorer;
import it.unimi.dsi.fastutil.Hash;
import it.unimi.dsi.fastutil.ints.IntBigList;
import it.unimi.dsi.fastutil.longs.Long2LongMap;
import it.unimi.dsi.fastutil.longs.Long2LongRBTreeMap;
import it.unimi.dsi.fastutil.longs.LongRBTreeSet;
import it.unimi.dsi.fastutil.objects.*;
import it.unimi.dsi.io.WordReader;
import it.unimi.dsi.lang.MutableString;
import net.bpiwowar.mg4j.extensions.conf.IndexConfiguration;
import net.bpiwowar.mg4j.extensions.query.Topic;
import net.bpiwowar.mg4j.extensions.utils.Output;
import net.bpiwowar.mg4j.extensions.utils.TermUtil;
import net.bpiwowar.mg4j.extensions.utils.timer.TaskTimer;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.log4j.Logger;

import java.io.IOException;
import java.io.Reader;
import java.util.*;

import static java.lang.Math.log;
import static java.lang.Math.min;

public class MG4JScorer implements RetrievalModel {
  final static Logger logger = Logger.getLogger(MG4JScorer.class);

  @Argument(name = "term-processor", help = "The file that contains the term processor description", handler = XStreamHandler.class)
    TermProcessor processor = NullTermProcessor.getInstance();

  @Argument(name = "rf", prefix = "pseudo-rf", help = "How to handle pseudo-relevance feedback", handler = ClassChooser.class)
  PseudoRF pseudoRF;

  @Argument(name = "query-types", help = "The parts of the query that we use", required = true)
  private Set<String> queryTypes = new TreeSet<>();

    {
        queryTypes.add("title");
    }

  Scorer scorer;

  /**
   * The MG4J topic. Made public to be reusable later. Invoke process()
   * before using this!
   */
  public String mg4jTopic;
 
  transient private IndexConfiguration index;

    private QueryEngine queryEngine;

//  private uk.ac.gla.dcs.renaissance.mg4j.QueryEngine queryEngine;

  @Override
  public void init(DocumentCollection collection, IndexConfiguration index) throws Exception {
    this.index = index;

    final Object2ReferenceLinkedOpenHashMap<String, Index> indexMap = new Object2ReferenceLinkedOpenHashMap<>(
        Hash.DEFAULT_INITIAL_SIZE, .5f);
    indexMap.put(index.field, index.index);
    final Reference2DoubleOpenHashMap<Index> index2Weight = new Reference2DoubleOpenHashMap<>();
    index2Weight.put(index.index, 1.);

    final SimpleParser queryParser = new SimpleParser(indexMap.keySet(),
        indexMap.firstKey(), null);

    final Reference2ReferenceMap<Index, Object> index2Parser = new Reference2ReferenceOpenHashMap<>();

        queryEngine = new QueryEngine(queryParser, new DocumentIteratorBuilderVisitor(indexMap,
                index2Parser, indexMap.get(indexMap.firstKey()),
                Query.MAX_STEMMING), indexMap);

    queryEngine.score(new Scorer[]{scorer}, new double[]{1});

    queryEngine.multiplex = true;
    queryEngine.setWeights(index2Weight);

  }


    @Override
  public void close() {
  }

  @Override
    public void process(Topic topic, ObjectArrayList<DocumentScoreInfo<Reference2ObjectMap<Index, SelectedInterval[]>>> results, int capacity, TaskTimer timer) throws Exception {

        // --- Get the MG4J topic

    TreeMap<String, MutableInt> terms = new TreeMap<>();

    for (String queryType : queryTypes) {
            net.bpiwowar.mg4j.extensions.query.Query query = topic.getTopicPart(queryType);
      // Take all the words from the topic and construct the query for
      // MG4J
      // TODO Annalina: if (query instanceof MG4JQuery) {} else...
      TermUtil.getPositiveTerms(query, terms, processor, index);
    }

    this.mg4jTopic = Output.toString(" | ", terms.entrySet(),
                weightedWordFormatter);

    logger.info(String.format("Topic %s", mg4jTopic));
    queryEngine.process(mg4jTopic, 0, pseudoRF == null ? capacity
                : pseudoRF.k, results);

    // --- Handling relevance feedback

    if (pseudoRF != null) {
      for (CharSequence term : pseudoRF.process(results))
        update(terms, term.toString());

      mg4jTopic = Output.toString(" | ", terms.entrySet(),
                    weightedWordFormatter);
      logger.info(String.format("After blind RF, topic is %s", mg4jTopic));

      results.clear();
      queryEngine.process(mg4jTopic, 0, capacity, results);
    }
  }

  /**
   * A formatter that adds a MG4J weight to each word
   */
  public static Output.Formatter<Map.Entry<String, MutableInt>> weightedWordFormatter = new Output.Formatter<Map.Entry<String, MutableInt>>() {
    @Override
    public String format(Map.Entry<String, MutableInt> t) {
      return String.format("%s{%d}", t.getKey(), t.getValue().intValue());
    }
  };


  // --- Get sets of terms

  static public void update(Map<String, MutableInt> terms, String word) {
    MutableInt v = terms.get(word);
    if (v == null)
      v = terms.put(word, new MutableInt(1));
    else
      v.add(1);

  }



  // ---- Relevance feedback parameters ----
  // ---------------------------------------
  static public abstract class PseudoRF {
    @Argument(name = "top-k", help = "Top results to consider (default: top 10)")
    int k = 10;

    /**
     * @param results
     *            The set of results
     * @throws java.io.IOException
     */
    abstract Collection<CharSequence> process(
        ObjectArrayList<DocumentScoreInfo<Reference2ObjectMap<Index, SelectedInterval[]>>> results)
        throws IOException;

    abstract void init() throws IOException;

    @Choice(name = "trec-8")
    static public class TopRanked extends PseudoRF {

      @Argument(name = "threshold", help = "Threshold")
      double c = 0;

      // log of vocabularySize log(V)
      private double logV;

      private long vocabularySize;

      private IntBigList frequencies;

      private double logN;

      /**
       * The ith entry stores log(Comb(k; i+1))
       */
      double[] combinaisons;

      private IndexConfiguration index;

      private DocumentCollection collection;

      private int fieldIndex;

      private TermProcessor processor;

      @Override
      void init() throws IOException {
        vocabularySize = index.index.numberOfTerms;
        this.frequencies = index.getFrequencies();
        logN = log(index.index.numberOfDocuments);
        logV = log(vocabularySize);

        // TODO: Annalina - create a transformer that stem & stop using
        // processor
        // queryEngine.transformer(transformer);

        // Compute log(Comb(k; i+1)) for i = 0 to k-1
        combinaisons = new double[k];
        // log(Comb(k; 1)) is log(k)
        combinaisons[0] = log(k);
        for (int i = 1; i < k; i++) {
          // C(k; i+1) = C(k; i) * (k-i) / (i+1)
          // [don't forget that combinaisons[i] is C(k; i+1)
          combinaisons[i] = combinaisons[i - 1] + log(k - i)
              - log(i + 1);
        }

        fieldIndex = collection.factory().fieldIndex(index.field);
        if (fieldIndex < 0)
          throw new RuntimeException(String.format(
              "Could not find field %s in index", index.field));

      }

      @Override
      public Collection<CharSequence> process(
          ObjectArrayList<DocumentScoreInfo<Reference2ObjectMap<Index, SelectedInterval[]>>> results)
          throws IOException {
        // Creates a new collection of terms
        Collection<CharSequence> newTerms = new HashSet<>();

        // Get a subspace representation of the K first documents
        if (k > results.size()) {
          logger
              .warn(String.format(
                  "The number of returned documents (%d) is inferior to the number of documents (%d) for blind RF",
                  results.size(), k));
        }

        // --- Compute the frequencies of terms
        Long2LongRBTreeMap relFreq = computeTermFrequencies(results);

        // --- Select the candidates
        for (Long2LongMap.Entry entry : relFreq.long2LongEntrySet()) {
          long rt2 = entry.getLongValue();
          int nt = frequencies.getInt(entry.getLongKey());
          final double termScore = rt2 * (logN - log(nt))
              - combinaisons[(int)(rt2 - 1)] - logV;
          if (termScore > c) {
            // Add term to set
            final CharSequence term = index.getTerm(entry
                .getLongKey());
            newTerms.add(term);
            logger.info(String.format("Adding term %s", term));
          } else {
            if (logger.isDebugEnabled())
              logger.debug(String.format("Not adding term %s (%g <= %g)", index
                  .getTerm(entry.getLongKey()), termScore, c));
          }

        }

        return newTerms;
      }

      /**
       *
             * @param results
             * @return
       * @throws java.io.IOException
       */
      private Long2LongRBTreeMap computeTermFrequencies(
                    ObjectArrayList<DocumentScoreInfo<Reference2ObjectMap<Index, SelectedInterval[]>>> results)
          throws IOException {
        Long2LongRBTreeMap relFreq = new Long2LongRBTreeMap();
        relFreq.defaultReturnValue(0);

        for (int i = 0; i < min(k, results.size()); i++) {
          Document document = collection
              .document(results.get(i).document);
          Reader reader = (Reader) document.content(fieldIndex);
          WordReader wordReader = document.wordReader(fieldIndex);
          wordReader.setReader(reader);

          MutableString word = new MutableString();
          MutableString nonWord = new MutableString();
          final LongRBTreeSet set = new LongRBTreeSet();

          while (wordReader.next(word, nonWord)) {
            if (processor.processTerm(word)) {
              long termId = index.getTermId(word);
              if (termId >= 0)
                if (set.add(termId))
                  relFreq
                      .put(termId,
                          relFreq.get(termId) + 1);
            }
          }

          document.close();

        }
        return relFreq;
      }
    }

  }

}
TOP

Related Classes of net.bpiwowar.mg4j.extensions.adhoc.MG4JScorer$PseudoRF$TopRanked

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.