Package edu.stanford.nlp.patterns.surface

Source Code of edu.stanford.nlp.patterns.surface.ScorePatternsRatioModifiedFreq

package edu.stanford.nlp.patterns.surface;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.*;
import java.util.Map.Entry;
import java.util.function.Function;

import edu.stanford.nlp.patterns.surface.ConstantsAndVariables.ScorePhraseMeasures;
import edu.stanford.nlp.patterns.surface.GetPatternsFromDataMultiClass.PatternScoring;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Pair;

public class ScorePatternsRatioModifiedFreq<E> extends ScorePatterns<E> {

  public ScorePatternsRatioModifiedFreq(
      ConstantsAndVariables constVars,
      PatternScoring patternScoring,
      String label, Set<String> allCandidatePhrases,
      TwoDimensionalCounter<E, String> patternsandWords4Label,
      TwoDimensionalCounter<E, String> negPatternsandWords4Label,
      TwoDimensionalCounter<E, String> unLabeledPatternsandWords4Label,
      TwoDimensionalCounter<String, ScorePhraseMeasures> phInPatScores,
      ScorePhrases scorePhrases, Properties props) {
    super(constVars, patternScoring, label, allCandidatePhrases,  patternsandWords4Label,
        negPatternsandWords4Label, unLabeledPatternsandWords4Label,
        props);
    this.phInPatScores = phInPatScores;
    this.scorePhrases = scorePhrases;
  }

  // cached values
  private TwoDimensionalCounter<String, ScorePhraseMeasures> phInPatScores;

  private ScorePhrases scorePhrases;

  @Override
  public void setUp(Properties props) {
  }

  @Override
  Counter<E> score() throws IOException, ClassNotFoundException {

    Counter<String> externalWordWeightsNormalized = null;
    if (constVars.dictOddsWeights.containsKey(label))
      externalWordWeightsNormalized = GetPatternsFromDataMultiClass
          .normalizeSoftMaxMinMaxScores(constVars.dictOddsWeights.get(label),
              true, true, false);

    Counter<E> currentPatternWeights4Label = new ClassicCounter<E>();

    boolean useFreqPhraseExtractedByPat = false;
    if (patternScoring.equals(PatternScoring.SqrtAllRatio))
      useFreqPhraseExtractedByPat = true;
    Function<Pair<E, String>, Double> numeratorScore = x -> patternsandWords4Label.getCount(x.first(), x.second());

    Counter<E> numeratorPatWt = this.convert2OneDim(label,
        numeratorScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, false, null,
        useFreqPhraseExtractedByPat);

    Counter<E> denominatorPatWt = null;

    Function<Pair<E, String>, Double> denoScore;
    if (patternScoring.equals(PatternScoring.PosNegUnlabOdds)) {
      denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second());

      denominatorPatWt = this.convert2OneDim(label,
          denoScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, false,
          externalWordWeightsNormalized, useFreqPhraseExtractedByPat);
    } else if (patternScoring.equals(PatternScoring.RatioAll)) {
      denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second()) +
        patternsandWords4Label.getCount(x.first(), x.second());
      denominatorPatWt = this.convert2OneDim(label, denoScore,allCandidatePhrases, patternsandWords4Label,
          constVars.sqrtPatScore, false, externalWordWeightsNormalized,
          useFreqPhraseExtractedByPat);
    } else if (patternScoring.equals(PatternScoring.PosNegOdds)) {
      denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second());
      denominatorPatWt = this.convert2OneDim(label, denoScore, allCandidatePhrases, patternsandWords4Label,
          constVars.sqrtPatScore, false, externalWordWeightsNormalized,
          useFreqPhraseExtractedByPat);
    } else if (patternScoring.equals(PatternScoring.PhEvalInPat)
        || patternScoring.equals(PatternScoring.PhEvalInPatLogP)
        || patternScoring.equals(PatternScoring.LOGREG)
        || patternScoring.equals(PatternScoring.LOGREGlogP)) {
      denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second());
      denominatorPatWt = this.convert2OneDim(label,
        denoScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, true,
          externalWordWeightsNormalized, useFreqPhraseExtractedByPat);
    } else if (patternScoring.equals(PatternScoring.SqrtAllRatio)) {
      denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second());

      denominatorPatWt = this.convert2OneDim(label,
        denoScore, allCandidatePhrases, patternsandWords4Label, true, false,
          externalWordWeightsNormalized, useFreqPhraseExtractedByPat);
    } else
      throw new RuntimeException("Cannot understand patterns scoring");

    currentPatternWeights4Label = Counters.divisionNonNaN(numeratorPatWt,
        denominatorPatWt);

    //Multiplying by logP
    if (patternScoring.equals(PatternScoring.PhEvalInPatLogP) || patternScoring.equals(PatternScoring.LOGREGlogP)) {
      Counter<E> logpos_i = new ClassicCounter<E>();
      for (Entry<E, ClassicCounter<String>> en : patternsandWords4Label
          .entrySet()) {
        logpos_i.setCount(en.getKey(), Math.log(en.getValue().size()));
      }
      Counters.multiplyInPlace(currentPatternWeights4Label, logpos_i);
    }
    Counters.retainNonZeros(currentPatternWeights4Label);
    return currentPatternWeights4Label;
  }

  Counter<E> convert2OneDim(String label,
      Function<Pair<E, String>, Double> scoringFunction, Set<String> allCandidatePhrases, TwoDimensionalCounter<E, String> positivePatternsAndWords,
      boolean sqrtPatScore, boolean scorePhrasesInPatSelection,
      Counter<String> dictOddsWordWeights, boolean useFreqPhraseExtractedByPat) throws IOException, ClassNotFoundException {

    if (Data.googleNGram.size() == 0 && Data.googleNGramsFile != null) {
      Data.loadGoogleNGrams();
    }

    Counter<E> patterns = new ClassicCounter<E>();

    Counter<String> googleNgramNormScores = new ClassicCounter<String>();
    Counter<String> domainNgramNormScores = new ClassicCounter<String>();

    Counter<String> externalFeatWtsNormalized = new ClassicCounter<String>();
    Counter<String> editDistanceFromOtherSemanticBinaryScores = new ClassicCounter<String>();
    Counter<String> editDistanceFromAlreadyExtractedBinaryScores = new ClassicCounter<String>();
    double externalWtsDefault = 0.5;
    Counter<String> classifierScores = null;

    if ((patternScoring.equals(PatternScoring.PhEvalInPat) || patternScoring
        .equals(PatternScoring.PhEvalInPatLogP)) && scorePhrasesInPatSelection) {

      for (String g : allCandidatePhrases) {
        if (constVars.usePatternEvalEditDistOther) {

          editDistanceFromOtherSemanticBinaryScores.setCount(g,
              constVars.getEditDistanceScoresOtherClassThreshold(g));
        }
        if (constVars.usePatternEvalEditDistSame) {
          editDistanceFromAlreadyExtractedBinaryScores.setCount(g,
              1 - constVars.getEditDistanceScoresThisClassThreshold(label, g));
        }

        if (constVars.usePatternEvalGoogleNgram) {
          if (Data.googleNGram.containsKey(g)) {
            assert (Data.rawFreq.containsKey(g));
            googleNgramNormScores
                .setCount(
                    g,
                    ((1 + Data.rawFreq.getCount(g)
                        * Math.sqrt(Data.ratioGoogleNgramFreqWithDataFreq)) / Data.googleNGram
                        .getCount(g)));
          }
        }
        if (constVars.usePatternEvalDomainNgram) {
          // calculate domain-ngram wts
          if (Data.domainNGramRawFreq.containsKey(g)) {
            assert (Data.rawFreq.containsKey(g));
            domainNgramNormScores.setCount(g,
                scorePhrases.phraseScorer.getDomainNgramScore(g));
          }
        }

        if (constVars.usePatternEvalWordClass) {
          Integer num = constVars.getWordClassClusters().get(g);
          if (num != null
              && constVars.distSimWeights.get(label).containsKey(num)) {
            externalFeatWtsNormalized.setCount(g,
                constVars.distSimWeights.get(label).getCount(num));
          } else
            externalFeatWtsNormalized.setCount(g, externalWtsDefault);
        }
      }
      if (constVars.usePatternEvalGoogleNgram)
        googleNgramNormScores = GetPatternsFromDataMultiClass
            .normalizeSoftMaxMinMaxScores(googleNgramNormScores, true, true,
                false);
      if (constVars.usePatternEvalDomainNgram)
        domainNgramNormScores = GetPatternsFromDataMultiClass
            .normalizeSoftMaxMinMaxScores(domainNgramNormScores, true, true,
                false);
      if (constVars.usePatternEvalWordClass)
        externalFeatWtsNormalized = GetPatternsFromDataMultiClass
            .normalizeSoftMaxMinMaxScores(externalFeatWtsNormalized, true,
                true, false);
    }

    else if ((patternScoring.equals(PatternScoring.LOGREG) || patternScoring.equals(PatternScoring.LOGREGlogP))
        && scorePhrasesInPatSelection) {
      Properties props2 = new Properties();
      props2.putAll(props);
      props2.setProperty("phraseScorerClass", "edu.stanford.nlp.patterns.surface.ScorePhrasesLearnFeatWt");
      ScorePhrases scoreclassifier = new ScorePhrases(props2, constVars);
      System.out.println("file is " + props.getProperty("domainNGramsFile"));
      Execution.fillOptions(Data.class, props2);
      classifierScores = scoreclassifier.phraseScorer.scorePhrases(label, allCandidatePhrases,  true);
      // scorePhrases(Data.sents, label, true,
      // constVars.perSelectRand, constVars.perSelectNeg, null, null,
      // dictOddsWordWeights);
      // throw new RuntimeException("Not implemented currently");
    }

    Counter<String> cachedScoresForThisIter = new ClassicCounter<String>();

    for (Map.Entry<E, ClassicCounter<String>> en: positivePatternsAndWords.entrySet()) {

        for(Entry<String, Double> en2: en.getValue().entrySet()) {
          String word = en2.getKey();
          Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<ScorePhraseMeasures>();
          double score = 1;
          if ((patternScoring.equals(PatternScoring.PhEvalInPat) || patternScoring
            .equals(PatternScoring.PhEvalInPatLogP))
            && scorePhrasesInPatSelection) {
            if (cachedScoresForThisIter.containsKey(word)) {
              score = cachedScoresForThisIter.getCount(word);
            } else {
              if (constVars.getOtherSemanticClassesWords().contains(word)
                || constVars.getCommonEngWords().contains(word))
                score = 1;
              else {

                if (constVars.usePatternEvalSemanticOdds) {
                  double semanticClassOdds = 1;
                  if (dictOddsWordWeights.containsKey(word))
                    semanticClassOdds = 1 - dictOddsWordWeights.getCount(word);
                  scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS,
                    semanticClassOdds);
                }

                if (constVars.usePatternEvalGoogleNgram) {
                  double gscore = 0;
                  if (googleNgramNormScores.containsKey(word)) {
                    gscore = 1 - googleNgramNormScores.getCount(word);
                  }
                  scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore);
                }

                if (constVars.usePatternEvalDomainNgram) {
                  double domainscore;
                  if (domainNgramNormScores.containsKey(word)) {
                    domainscore = 1 - domainNgramNormScores.getCount(word);
                  } else
                    domainscore = 1 - scorePhrases.phraseScorer
                      .getPhraseWeightFromWords(domainNgramNormScores, word,
                        scorePhrases.phraseScorer.OOVDomainNgramScore);
                  scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM,
                    domainscore);
                }
                if (constVars.usePatternEvalWordClass) {
                  double externalFeatureWt = externalWtsDefault;
                  if (externalFeatWtsNormalized.containsKey(word))
                    externalFeatureWt = 1 - externalFeatWtsNormalized.getCount(word);
                  scoreslist.setCount(ScorePhraseMeasures.DISTSIM,
                    externalFeatureWt);
                }

                if (constVars.usePatternEvalEditDistOther) {
                  assert editDistanceFromOtherSemanticBinaryScores.containsKey(word) : "How come no edit distance info?";
                  scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER,
                    editDistanceFromOtherSemanticBinaryScores.getCount(word));
                }
                if (constVars.usePatternEvalEditDistSame) {
                  scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME,
                    editDistanceFromAlreadyExtractedBinaryScores.getCount(word));
                }

                // taking average
                score = Counters.mean(scoreslist);

                phInPatScores.setCounter(word, scoreslist);
              }

              cachedScoresForThisIter.setCount(word, score);
            }
          } else if ((patternScoring.equals(PatternScoring.LOGREG) || patternScoring.equals(PatternScoring.LOGREGlogP))
            && scorePhrasesInPatSelection) {
            score = 1 - classifierScores.getCount(word);
            // score = 1 - scorePhrases.scoreUsingClassifer(classifier,
            // e.getKey(), label, true, null, null, dictOddsWordWeights);
            // throw new RuntimeException("not implemented yet");
          }

          if (useFreqPhraseExtractedByPat)
            score = score * scoringFunction.apply(new Pair(en.getKey(), word));
          if (constVars.sqrtPatScore)
            patterns.incrementCount(en.getKey(), Math.sqrt(score));
          else
            patterns.incrementCount(en.getKey(), score);
        }
    }



    return patterns;
  }

}
TOP

Related Classes of edu.stanford.nlp.patterns.surface.ScorePatternsRatioModifiedFreq

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.