Package srmdata

Source Code of srmdata.StructuredRelevanceModel$DescendingRelevanceComp

package srmdata;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermDocs;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.RAMDirectory;

public class StructuredRelevanceModel {

  Set<Integer> allDocIds;
  LinkedHashSet<Integer> testDocIds;
  Set<Integer> trainDocIds;
 
  public StructuredRelevanceModel() {
    allDocIds = new LinkedHashSet<Integer>();
    testDocIds = new LinkedHashSet<Integer>();
    trainDocIds = new LinkedHashSet<Integer>();
  }

  static class Score {
    double score;
    double relevance;
    int docID;
  }

  void predictField(String fieldToPredict, Map<String,String> testTrainFiles) throws Exception {

    for (Map.Entry<String, String> filenames : testTrainFiles.entrySet()) {

      RAMDirectory trainRAMDirectory = new RAMDirectory(FSDirectory.open(new File(filenames.getKey())));
      RAMDirectory testRAMDirectory = new RAMDirectory(FSDirectory.open(new File(filenames.getValue())));
      IndexReader trainIR = IndexReader.open(trainRAMDirectory, true);
      IndexReader testIR  = IndexReader.open(testRAMDirectory, true);

      System.out.println("Train File Name: " + filenames.getKey() + " Test File Name: " + filenames.getValue());
      int num_fields = 3;

      int nTrainDocs = trainIR.numDocs();
      int nTestDocs = testIR.numDocs();

      double[][][] scores = new double[num_fields][][];
      long t1, t2;
      t1 = System.nanoTime();
        scores[0] = computePriors(testIR, trainIR, "title");
      t2 = System.nanoTime();
      System.out.println("Time Taken Priors (title): " + ((double)(t2-t1)) / 1E9);
      t1 = System.nanoTime();
        scores[1] = computePriors(testIR, trainIR, "desc");
      t2 = System.nanoTime();
      System.out.println("Time Taken Priors (desc): " + ((double)(t2-t1)) / 1E9);
      t1 = System.nanoTime();
      scores[2] = computePriors(testIR, trainIR, "content");
      t2 = System.nanoTime();
      System.out.println("Time Taken Priors (content): " + ((double)(t2-t1)) / 1E9);

      Score[][] combined_score = new Score[nTestDocs][nTrainDocs];
      for (int i = 0; i < nTestDocs; ++i) {
        for (int j = 0; j < nTrainDocs; ++j) {
          combined_score[i][j] = new Score();
        }
      }

      for (int i = 0; i < nTrainDocs; ++i) {
        for (int j = 0; j < nTestDocs; ++j) {
          combined_score[j][i].docID = i;
          combined_score[j][i].score = scores[0][i][j] * scores[1][i][j];
        }
      }

      t1 = System.nanoTime();
      DescendingScoreComp comp = new DescendingScoreComp();
      for (int i = 0; i < nTestDocs; ++i) {
        Arrays.sort(combined_score[i], comp);
      }

      int topN = 100;
      for (int i = 0; i < nTestDocs; ++i) {
        double total_score = 0.0;
        for (int j = 0; j < topN; ++j)
          total_score += combined_score[i][j].score;
        for (int j = 0; j < topN; ++j)
          combined_score[i][j].score /= total_score;
      }
      t2 = System.nanoTime();
      System.out.println("Time Taken Normalization and Sorting: " + ((double)(t2-t1)) / 1E9);

      int numAudience = 6;
      Map<String,Integer> audienceMap = new HashMap<String,Integer>();
      audienceMap.put("learner", 0);
      audienceMap.put("educator", 1);
      audienceMap.put("researcher", 2);
      audienceMap.put("general public", 3);
      audienceMap.put("professional/practitioner", 4);
      audienceMap.put("administrator", 5);

      int numCorrect = 0;
      for (int i = 0; i < nTestDocs; ++i) {
        double[] audience_freq = new double[numAudience];
        int total_freq = 0;
        for (int j = 0; j < topN; ++j) {
          int docID = combined_score[i][j].docID;
          combined_score[i][j].relevance = 0.0;
          Document doc = trainIR.document(docID);
          String fieldValue = doc.get(fieldToPredict).toLowerCase();
          Integer index = audienceMap.get(fieldValue);
          audience_freq[index]++;
          total_freq++;
        }

        for (int j = 0; j < topN; ++j) {
          int docID = combined_score[i][j].docID;
          Document doc = trainIR.document(docID);
          String fieldValue = doc.get(fieldToPredict).toLowerCase();
          Integer index = audienceMap.get(fieldValue);
          combined_score[i][j].relevance = (audience_freq[index] / total_freq) * combined_score[i][j].score;
        }
        Arrays.sort(combined_score[i], 0, topN-1, new DescendingRelevanceComp());

        Document testDoc = testIR.document(i);
        String actualValue = testDoc.get(fieldToPredict);

        System.out.print(actualValue + " : ");
        for (int j = 0; j < 10; ++j) {
          int docID = combined_score[i][j].docID;
          combined_score[i][j].relevance = 0.0;
          Document doc = trainIR.document(docID);
          String fieldValue = doc.get(fieldToPredict);
          if (j == 0 && actualValue.equals(fieldValue))
            numCorrect++;
          System.out.print(fieldValue + " ");
        }
        System.out.println();
      }

      System.out.println("Num Correct: " + numCorrect + " out of " + nTestDocs);
     
      trainIR.close();
      testIR.close();
    }
  }

  static class DescendingScoreComp implements Comparator<Score> {
    @Override
    public int compare(Score o1, Score o2) {
      Double diff = o2.score-o1.score;
      if (diff < 0)
        return -1;
      if (diff > 0)
        return 1;
      return 0;
    }
  }
 
  static class DescendingRelevanceComp implements Comparator<Score> {
    @Override
    public int compare(Score o1, Score o2) {
      Double diff = o2.relevance-o1.relevance;
      if (diff < 0)
        return -1;
      if (diff > 0)
        return 1;
      return 0;
    }
  }

  static boolean containsNumber(String str) {
    for (int i = 0; i < str.length(); ++i) {
      if (str.charAt(i) >= '0' && str.charAt(i) <= '9')
        return true;
    }
    return false;
  }

  double[][] computePriors(IndexReader testIR, IndexReader trainIR, String fieldName) throws Exception {

    // assume there are no holes in document ids for train/test indices
    int nTrainDocs = trainIR.numDocs();
    int nTestDocs = testIR.numDocs();

    // find number of terms in all training documents for the given field
    int[] doc_lengths = new int[nTrainDocs];

    double[][] modelScores;
    modelScores = new double[nTrainDocs][nTestDocs];
    for (int i = 0; i < modelScores.length; ++i)
      for (int j = 0; j < modelScores[i].length; ++j)
        modelScores[i][j] = 1.0;

    int collectionSize = findCollectionSize(trainIR, fieldName, doc_lengths);

    double[] mle = new double[trainIR.numDocs()];
    double score[] = new double[2];
    TermEnum terms = trainIR.terms();
    while (terms.next()) {

      Term t = terms.term();
      if (!t.field().equals(fieldName) || containsNumber(t.text()))
        continue;

      compute_mlestimate(trainIR, fieldName, t, doc_lengths, collectionSize, mle);
      if (mle == null)
        continue;

      int[] termDocsArr = new int[nTestDocs];
      for (int i = 0; i < termDocsArr.length; ++i)
        termDocsArr[i] = 1;
      TermDocs termDocs = testIR.termDocs(t);
      while (termDocs.next()) {
        termDocsArr[termDocs.doc()] = 0;
      }
      termDocs.close();

//      long t1 = System.nanoTime();
      for (int md = 0; md < nTrainDocs; ++md) {
        score[0] = mle[md];
        score[1] = 1.0 - score[0];
        for (int q = 0; q < nTestDocs; ++q) {
          modelScores[md][q] *= score[termDocsArr[q]];
        }
      }
//      long t2 = System.nanoTime();
//      System.out.println("Time Taken: " + (t2-t1)/1E6);
    }

    terms.close();
    return modelScores;
  }

  static double[] compute_mlestimate(IndexReader ir, String fieldName,
      Term t, int[] doc_length, int collectionSize, double[] mlEstimates) throws Exception {

    List<Double> avgs = compute_avgs(ir, fieldName, t, doc_length);
    Double pavg = avgs.get(0);
    Double meanfreq = avgs.get(1);
    Double collectionFreq = avgs.get(2);

    if (collectionFreq == 0.0) {
      return null;
    }

    for (int i = 0; i < mlEstimates.length; ++i) {
      mlEstimates[i] = 0.0;
    }

    pavg = Math.log10(pavg);
    double term1 = meanfreq / (1.0 + meanfreq);
    double term2 = 1.0 / (1.0 + meanfreq);
    TermDocs termDocs = ir.termDocs(t);
    while (termDocs.next()) {
      int d = termDocs.doc();
      int tf = termDocs.freq();
      if (tf == 0.0) {
        mlEstimates[d] = 0.0;
        continue;
      }
      double R = term2 * Math.pow(term1,tf);
      double pml = Math.log10( ((double)tf)/doc_length[d] );
      double val = (1.0-R)*pml + R*pavg;
      mlEstimates[d] = val;
    }
    termDocs.close();

    double defaultVal = Math.log10((double)collectionFreq/collectionSize);
    for (int md = 0; md < ir.maxDoc(); ++md) {
      if (mlEstimates[md] == 0.0)
        mlEstimates[md] = defaultVal;
    }

    return mlEstimates;
  }

  static List<Double> compute_avgs(IndexReader ir, String fieldName,
      Term t, int[] doc_length) throws Exception {

    double collectionFreq = 0;
    double pavg = 0.0;
    double meanfreq = 0.0;

    int count = 0;
    TermDocs termDocs = ir.termDocs(t);
    while (termDocs.next()) {
      int d = termDocs.doc();
      int tf = termDocs.freq();
      double pml = ((double)tf) / doc_length[d];
      pavg = pavg + pml;
      meanfreq = meanfreq + tf;
      collectionFreq = collectionFreq + tf;
      count++;
    }
    termDocs.close();

    if (count == 0) {
      pavg = 0.0;
      meanfreq = 0.0;
    }
    else {
      pavg = pavg / count;
      meanfreq = meanfreq / count;
    }
    return Arrays.asList(pavg, meanfreq, collectionFreq);
  }

  /**
   * Find total number of tokens in the collection
   * @param ir
    * @param fieldName
   * @param doc_lengths
   * @throws IOException
   */
  static Integer findCollectionSize(IndexReader ir, String fieldName, int[] doc_lengths) throws IOException {
    int collectionSize = 0;
    TermEnum terms = ir.terms();
    while (terms.next()) {
      Term t = terms.term();
      if (!t.field().equals(fieldName) || containsNumber(t.text()))
        continue;
      TermDocs termDocs = ir.termDocs(t);
      while (termDocs.next()) {
        int tf = termDocs.freq();
        collectionSize += tf;
        int docID = termDocs.doc();
        doc_lengths[docID] += tf;
      }
      termDocs.close();
    }
    terms.close();
    return collectionSize;
  }
}
TOP

Related Classes of srmdata.StructuredRelevanceModel$DescendingRelevanceComp

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by Oracle. Contact coftware#gmail.com.