Package ivory.smrf.model.expander

Source Code of ivory.smrf.model.expander.MRFExpander$TfDoclengthStatistics

/*
* Ivory: A Hadoop toolkit for web-scale information retrieval
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package ivory.smrf.model.expander;

import ivory.core.RetrievalEnvironment;
import ivory.core.data.document.IntDocVector;
import ivory.core.data.document.IntDocVector.Reader;
import ivory.core.exception.ConfigurationException;
import ivory.core.exception.RetrievalException;
import ivory.core.util.XMLTools;
import ivory.smrf.model.MarkovRandomField;
import ivory.smrf.model.Parameter;
import ivory.smrf.model.VocabFrequencyPair;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.importance.ConceptImportanceModel;
import ivory.smrf.retrieval.Accumulator;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import edu.umd.cloud9.util.map.HMapIV;

/**
* @author Don Metzler
*/
public abstract class MRFExpander {
  protected RetrievalEnvironment env = null; // Ivory retrieval environment.
  protected int numFeedbackDocs;             // Number of feedback documents.
  protected int numFeedbackTerms;            // Number of feedback terms.
  protected Set<String> stopwords = null;    // Stopwords list.

  // The expansion MRF cliques should be scaled according to this weight.
  protected float expanderWeight;

  // Maximum number of candidates to consider for expansion; non-positive numbers result in all
  // candidates being considered.
  protected int maxCandidates = 0;

  /**
   * @param mrf
   * @param results
   */
  public abstract MarkovRandomField getExpandedMRF(MarkovRandomField mrf, Accumulator[] results)
      throws ConfigurationException;

  /**
   * @param words list of words to ignore when constructing expansion concepts
   */
  public void setStopwordList(Set<String> words) {
    this.stopwords = Preconditions.checkNotNull(words);
  }

  public void setMaxCandidates(int maxCandidates) {
    this.maxCandidates = maxCandidates;
  }

  /**
   * @param env
   * @param model
   * @throws ConfigurationException
   */
  public static MRFExpander getExpander(RetrievalEnvironment env, Node model)
      throws ConfigurationException {
    Preconditions.checkNotNull(env);
    Preconditions.checkNotNull(model);

    // Get model type.
    String expanderType = XMLTools.getAttributeValueOrThrowException(model, "type",
        "Expander type must be specified!");

    // Get normalized model type.
    String normExpanderType = expanderType.toLowerCase().trim();

    // Build the expander.
    MRFExpander expander = null;

    if ("unigramlatentconcept".equals(normExpanderType)) {
      int fbDocs = XMLTools.getAttributeValue(model, "fbDocs", 10);
      int fbTerms = XMLTools.getAttributeValue(model, "fbTerms", 10);
      float expanderWeight = XMLTools.getAttributeValue(model, "weight", 1.0f);

      List<Parameter> parameters = Lists.newArrayList();
      List<Node> scoreFunctionNodes = Lists.newArrayList();
      List<ConceptImportanceModel> importanceModels = Lists.newArrayList();

      // Get the expandermodel, which describes how to actually build the expanded MRF.
      NodeList children = model.getChildNodes();
      for (int i = 0; i < children.getLength(); i++) {
        Node child = children.item(i);
        if ("conceptscore".equals(child.getNodeName())) {
          String paramID = XMLTools.getAttributeValueOrThrowException(child, "id",
              "conceptscore node must specify an id attribute!");

          float weight = XMLTools.getAttributeValue(child, "weight", 1.0f);

          parameters.add(new Parameter(paramID, weight));
          scoreFunctionNodes.add(child);

          // Get concept importance source (if applicable).
          ConceptImportanceModel importanceModel = null;
          String importanceSource = XMLTools.getAttributeValue(child, "importance", null);
          if (importanceSource != null) {
            importanceModel = env.getImportanceModel(importanceSource);
            if (importanceModel == null) {
              throw new RetrievalException("Error: importancemodel " + importanceSource
                  + " not found!");
            }
          }
          importanceModels.add(importanceModel);
        }
      }

      // Make sure there's at least one expansion model specified.
      if (scoreFunctionNodes.size() == 0) {
        throw new ConfigurationException("No conceptscore specified!");
      }

      // Create the expander.
      expander = new UnigramLatentConceptExpander(env, fbDocs, fbTerms, expanderWeight, parameters,
          scoreFunctionNodes, importanceModels);

      // Maximum number of candidate expansion terms to consider per query.
      int maxCandidates = XMLTools.getAttributeValue(model, "maxCandidates", 0);
      if (maxCandidates > 0) {
        expander.setMaxCandidates(maxCandidates);
      }
    } else if ("latentconcept".equals(normExpanderType)) {
      int defaultFbDocs = XMLTools.getAttributeValue(model, "fbDocs", 10);
      int defaultFbTerms = XMLTools.getAttributeValue(model, "fbTerms", 10);

      List<Integer> gramList = new ArrayList<Integer>();
      List<MRFBuilder> builderList = new ArrayList<MRFBuilder>();
      List<Integer> fbDocsList = new ArrayList<Integer>();
      List<Integer> fbTermsList = new ArrayList<Integer>();

      // Get the expandermodel, which describes how to actually build the expanded MRF.
      NodeList children = model.getChildNodes();
      for (int i = 0; i < children.getLength(); i++) {
        Node child = children.item(i);
        if ("expansionmodel".equals(child.getNodeName())) {
          int gramSize = XMLTools.getAttributeValue(child, "gramSize", 1);
          int fbDocs = XMLTools.getAttributeValue(child, "fbDocs", defaultFbDocs);
          int fbTerms = XMLTools.getAttributeValue(child, "fbTerms", defaultFbTerms);

          // Set MRF builder parameters.
          gramList.add(gramSize);
          builderList.add(MRFBuilder.get(env, child));
          fbDocsList.add(fbDocs);
          fbTermsList.add(fbTerms);
        }
      }

      // Make sure there's at least one expansion model specified.
      if (builderList.size() == 0) {
        throw new ConfigurationException("No expansionmodel specified!");
      }

      // Create the expander.
      expander = new NGramLatentConceptExpander(env, gramList, builderList, fbDocsList,
          fbTermsList);

      // Maximum number of candidate expansion terms to consider per query.
      int maxCandidates = XMLTools.getAttributeValue(model, "maxCandidates", 0);
      if (maxCandidates > 0) {
        expander.setMaxCandidates(maxCandidates);
      }
    } else {
      throw new ConfigurationException("Unrecognized expander type -- " + expanderType);
    }

    return expander;
  }

  @SuppressWarnings("unchecked")
  protected TfDoclengthStatistics getTfDoclengthStatistics(IntDocVector[] docVecs)
      throws IOException {
    Preconditions.checkNotNull(docVecs);

    Map<String, Integer> vocab = Maps.newHashMap();
    Map<String, Short>[] tfs = new HashMap[docVecs.length];
    int[] doclens = new int[docVecs.length];

    for (int i = 0; i < docVecs.length; i++) {
      IntDocVector doc = docVecs[i];

      Map<String, Short> docTfs = new HashMap<String, Short>();
      int doclen = 0;

      Reader dvReader = doc.getReader();
      while (dvReader.hasMoreTerms()) {
        int termid = dvReader.nextTerm();
        String stem = env.getTermFromId(termid);
        short tf = dvReader.getTf();

        doclen += tf;

        if (stem != null && (stopwords == null || !stopwords.contains(stem))) {
          Integer df = vocab.get(stem);
          if (df != null) {
            vocab.put(stem, df + 1);
          } else {
            vocab.put(stem, 1);
          }
        }

        docTfs.put(stem, tf);
      }

      tfs[i] = docTfs;
      doclens[i] = doclen;
    }

    // Sort the vocab hashmap according to tf.
    VocabFrequencyPair[] entries = new VocabFrequencyPair[vocab.size()];
    int entryNum = 0;
    for (Entry<String, Integer> entry : vocab.entrySet()) {
      entries[entryNum++] = new VocabFrequencyPair(entry.getKey(), entry.getValue());
    }
    Arrays.sort(entries);

    return new TfDoclengthStatistics(entries, tfs, doclens);
  }

  /**
   * @param docVecs
   * @param gramSize
   * @throws IOException
   */
  protected VocabFrequencyPair[] getVocabulary(IntDocVector[] docVecs, int gramSize)
      throws IOException {
    Map<String, Integer> vocab = new HashMap<String, Integer>();

    for (IntDocVector doc : docVecs) {
      HMapIV<String> termMap = new HMapIV<String>();
      int maxPos = Integer.MIN_VALUE;

      Reader dvReader = doc.getReader();
      while (dvReader.hasMoreTerms()) {
        int termid = dvReader.nextTerm();
        String stem = env.getTermFromId(termid);
        int[] pos = dvReader.getPositions();
        for (int i = 0; i < pos.length; i++) {
          termMap.put(pos[i], stem);
          if (pos[i] > maxPos) {
            maxPos = pos[i];
          }
        }
      }

      // Grab all grams of size gramSize that do not contain any out of vocabulary terms.
      for (int pos = 0; pos <= maxPos + 1 - gramSize; pos++) {
        String concept = new String();
        boolean toAdd = true;
        for (int offset = 0; offset < gramSize; offset++) {
          String stem = termMap.get(pos + offset);

          if (stem == null || (stopwords != null && stopwords.contains(stem))) {
            toAdd = false;
            break;
          }

          if (offset == gramSize - 1) {
            concept += stem;
          } else {
            concept += stem + " ";
          }
        }

        if (toAdd) {
          Integer tf = vocab.get(concept);
          if (tf != null) {
            vocab.put(concept, tf + 1);
          } else {
            vocab.put(concept, 1);
          }
        }
      }
    }

    // Sort the vocab hashmap according to tf.
    VocabFrequencyPair[] entries = new VocabFrequencyPair[vocab.size()];
    int entryNum = 0;
    for (Entry<String, Integer> entry : vocab.entrySet()) {
      entries[entryNum++] = new VocabFrequencyPair(entry.getKey(), entry.getValue());
    }
    Arrays.sort(entries);

    return entries;
  }

  protected class TfDoclengthStatistics {
    private VocabFrequencyPair[] vocab = null;
    private Map<String, Short>[] tfs = null;
    private int[] doclengths = null;

    public TfDoclengthStatistics(VocabFrequencyPair[] entries, Map<String, Short>[] tfs,
        int[] doclengths) {
      this.vocab = Preconditions.checkNotNull(entries);
      this.tfs = Preconditions.checkNotNull(tfs);
      this.doclengths = Preconditions.checkNotNull(doclengths);
    }

    public VocabFrequencyPair[] getVocab() {
      return vocab;
    }

    public Map<String, Short>[] getTfs() {
      return tfs;
    }

    public int[] getDoclens() {
      return doclengths;
    }
  }
}
TOP

Related Classes of ivory.smrf.model.expander.MRFExpander$TfDoclengthStatistics

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.