Package opennlp.ccg.parse.supertagger.ml

Source Code of opennlp.ccg.parse.supertagger.ml.STPriorModel

///////////////////////////////////////////////////////////////////////////////
// Copyright (C) 2010 Dennis N. Mehay
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//////////////////////////////////////////////////////////////////////////////
package opennlp.ccg.parse.supertagger.ml;

import opennlp.ccg.parse.supertagger.util.ProbPairComparator;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import opennlp.ccg.lexicon.DefaultTokenizer;
import opennlp.ccg.lexicon.Word;
import opennlp.ccg.ngrams.ConditionalProbabilityTable;
import opennlp.ccg.parse.tagger.io.SRILMFactoredBundleCorpusIterator;
import opennlp.ccg.parse.tagger.ProbIndexPair;
import opennlp.ccg.parse.supertagger.util.STTaggerPOSDictionary;
import opennlp.ccg.util.Interner;
import opennlp.ccg.util.Pair;

/**
* (c) (2009) Dennis N. Mehay
* @author Dennis N. Mehay
*
* Model for predicting p(supertag | word, pos).  Uses an ARPA-formatted
* SRILM-trained "unigram" factored LM for this, where each "unigram" is
* a bundle of word:pos:supertag.
*/
public class STPriorModel extends ConditionalProbabilityTable {

    public static final String WORD = DefaultTokenizer.WORD_ATTR;
    public static final String POS_TAG = DefaultTokenizer.POS_ATTR;
    public static final String SUPERTAG = DefaultTokenizer.SUPERTAG_ATTR;
    private Interner<Pair<String, String>> pairs = new Interner<Pair<String, String>>();
    /**
     * Re-usable list for attr-val pairs of word-pos-supertag inputs to the prior model
     * (i.e., for predicting p(STag | word, POS).
     */
    public List<Pair<String, String>> attrVals = new ArrayList<Pair<String, String>>(5);
    /**
     * A comparator for sorting Pair<Double,String>'s where the Double is a probability
     * (effectively sorts by descending order of probability).
     */
    private ProbPairComparator ppcomp = new ProbPairComparator();
    /** All the priors. Reference them when getting beta-best, beta-worst, etc. */
    List<Pair<Double, String>> priors = new ArrayList<Pair<Double, String>>(1000);
    /** String[] of all possible supertag outcomes. */
    private String[] stagVocab = null;
    /** double[] containing the probability distro over all supertags. */
    private double[] stagDistro = null;
    /**
     * POS-keyed tagging dictionary (to provide restrictions on what the prior model may consider.
     * No restrictions if null.
     */
    private STTaggerPOSDictionary posDict = null;
    /**
     * Re-usable way of containing the probabilities and a pointer back into where they came from
     * in the probability distro over all supertags.
     */
    private ProbIndexPair[] stagPointers = null;

    /** Construct a prior model with the FLM config file and corresponding vocab file. */
    public STPriorModel(String flmFile, String vocabFile) throws IOException {
        // create with a null POS dictionary (i.e., no restrictions on taggings).
        this(flmFile, vocabFile, null);
    }

    /** Construct a prior model with the FLM config file and corresponding vocab file. */
    public STPriorModel(String flmFile, String vocabFile, STTaggerPOSDictionary posDict) throws IOException {
        super(flmFile);
        this.posDict = posDict;

        String st = null;
        BufferedReader br = new BufferedReader(new FileReader(new File(vocabFile)));
        st = br.readLine().trim();

        // get next supertag from the vocab.
        while ((st != null) && !st.trim().startsWith(SUPERTAG + "-")) {
            st = br.readLine();
        }
        if (st != null) {
            st = st.trim().split("-")[1];
        }

        Collection<String> allSupertags = new HashSet<String>();

        // find out how many outcomes we have.
        int cnt = 0;
        while (st != null) {
            cnt++;
            allSupertags.add(st);
            while ((st != null) && !st.trim().startsWith(SUPERTAG + "-")) {
                st = br.readLine();
            }
            if (st != null) {
                st = st.trim().split("-")[1];
            }
        }

        // initialize the arrays to this size.
        stagVocab = new String[cnt];
        stagPointers = new ProbIndexPair[cnt];
        stagDistro = new double[cnt];

        cnt = 0;
        // fill the vocab array with all possible supertags.
        for (String stag : allSupertags) {
            stagVocab[cnt++] = stag.intern();
        }
    }

    /** Set the POS-keyed tagging dictionary. */
    public void setPOSDict(STTaggerPOSDictionary posDict) {
        this.posDict = posDict;
    }

    /** Get the prior probability of this supertag/POS/word combo. */
    public double getPriorOf(String supertag, String word, String pos) {
        attrVals.clear();
        Pair<String, String> surfaceForm = pairs.intern(new Pair<String, String>(WORD, DefaultTokenizer.escape(word).intern()));
        attrVals.add(surfaceForm);
        Pair<String, String> partOfSpeech = pairs.intern(new Pair<String, String>(POS_TAG, DefaultTokenizer.escape(pos).intern()));
        attrVals.add(partOfSpeech);
        attrVals.add(pairs.intern(new Pair<String, String>(SUPERTAG, DefaultTokenizer.escape(supertag).intern())));
        return score(attrVals);
    }

    /** Get the beta-best tags for this word, under the prior model. */
    public List<Pair<String, Double>> getBetaBestPriors(Word w, double beta) {
        List<Pair<String, Double>> allPriors = getAllPriors(w);
        List<Pair<String, Double>> betaBestPriors = new ArrayList<Pair<String, Double>>(100);
        double best = allPriors.get(0).b;
        for (Pair<String, Double> prior : allPriors) {
            if (best * beta <= prior.b) {
                betaBestPriors.add(prior);
            } else {
                break;
            }
        }
        return betaBestPriors;
    }

    /** Compute all priors, subject to the POS dict constraints. */
    public void computePriors(Word w) {
        if (posDict != null) {
            priors = getPOSRestrictedPriors(w);
        }
    }

    /** Get the POS-dict restricted prior distribution (sorted descending by prob.) */
    protected List<Pair<Double, String>> getPOSRestrictedPriors(Word w) {
        Collection<String> tagsAllowed = posDict.getEntry(w.getPOS());
        if (tagsAllowed == null || tagsAllowed.size() == 0) {
            return priors;
        } else {
            List<Pair<Double, String>> sortedTags = new ArrayList<Pair<Double, String>>(tagsAllowed.size());
            for (String tag : tagsAllowed) {
                sortedTags.add(new Pair<Double, String>(getPriorOf(tag, w.getForm(), w.getPOS()), tag));
            }
            Collections.sort(sortedTags, ppcomp);
            return sortedTags;
        }
    }

    /**
     * Get the beta-best tags (using the prior model) only from among the POS-dictionary-allowed possibilities.
     * beta-best (def'n): {t | p(t) >= beta * p(best-tag) }
     */
    public List<Pair<String, Double>> getRestrictedBetaBestPriors(Word w, double beta) {
        if (posDict == null) {
            return getBetaBestPriors(w, beta);
        } else {
            List<Pair<String,Double>> rez = new ArrayList<Pair<String,Double>>(50);
            double best = priors.get(0).a;
            for(Pair<Double,String> tg : priors) {
                if(tg.a >= (beta * best)) {
                    rez.add(new Pair<String,Double>(tg.b,tg.a));
                } else {
                    break;
                }
            }
            return rez;
        }
    }
   
    /**
     *  Get the beta-WORST tags (using the prior model) only from among the POS-dictionary-allowed possibilities.
     *  beta-best (def'n): {t | p(t) >= beta * p(best-tag) }
     *  beta-worst (def'n): {t | p(t) * beta <= p(worst-tag)}
     */
    public List<Pair<String, Double>> getRestrictedBetaWorstPriors(Word w, double beta) {
        if (posDict == null) {
            throw new UnsupportedOperationException("Cannot get beta-worst without a pos-keyed tagging dict.\nNot yet implemented.");
        } else {
            List<Pair<String,Double>> rez = new ArrayList<Pair<String,Double>>(50);
            List<Pair<Double,String>> cpy = new ArrayList<Pair<Double,String>>(priors);
            Collections.reverse(cpy);           
            double worst = cpy.get(0).a;
            for(Pair<Double,String> tg : cpy) {
                if((tg.a * beta) <= worst) {
                    rez.add(new Pair<String,Double>(tg.b,tg.a));
                } else {
                    break;
                }
            }
            return rez;
        }
    }

    public List<Pair<String, Double>> getAllPriors(Word w) {
        return getNBestPriors(w, stagVocab.length);
    }

    /** Get the n-best supertags on the prior model, given this word (with POS). */
    public List<Pair<String, Double>> getNBestPriors(Word w, int n) {
        attrVals.clear();
        Pair<String, String> surfaceForm = pairs.intern(new Pair<String, String>(WORD, DefaultTokenizer.escape(w.getForm()).intern()));
        attrVals.add(surfaceForm);
        Pair<String, String> pos = pairs.intern(new Pair<String, String>(POS_TAG, DefaultTokenizer.escape(w.getPOS()).intern()));
        attrVals.add(pos);

        int cnt = 0;
        for (String st : stagVocab) {
            // remove the last stag factor, if there.
            if (attrVals.size() == 3) {
                attrVals.remove(attrVals.size() - 1);
            }

            attrVals.add(pairs.intern(new Pair<String, String>(SUPERTAG, st)));
            // add the probability of this tag under the prior model to the distro array.
            double sc = score(attrVals);
            stagDistro[cnt] = sc;
            // add this probability with a pointer back to where it came from in the vocab.
            // (so that we can sort by probability, but then retrieve the supertag string).
            stagPointers[cnt] = new ProbIndexPair(sc, cnt);
            cnt++;

        }
        // sort descending by probability (achieved by the comparator implementation of ProbIndexPair).

        Arrays.sort(stagPointers);

        List<Pair<String, Double>> result = new ArrayList<Pair<String, Double>>(n);
        for (int i = 0; i <
                n; i++) {
            result.add(new Pair<String, Double>(stagVocab[stagPointers[i].b], stagPointers[i].a));
        }

        return result;
    }

    public static void main(String[] args) throws IOException {
        String usage = "\nSTPriorModel -vocab <vocabfile> (-c <corpus>) (-o <output>) (-u <catFreqCutoff> ) (-v [ or '-verbose'])\n";
        if (args.length > 0 && args[0].equals("-h")) {
            System.out.println(usage);
            System.exit(0);
        }

        SRILMFactoredBundleCorpusIterator in = null;
        BufferedWriter out = null;
        BufferedWriter voc = null;

        try {
            String inputCorp = "<stdin>", output = "<stdout>", vocabFile = "vocab.voc";
            int catCutoff = 10;

            for (int i = 0; i <
                    args.length; i++) {
                if (args[i].equals("-c")) { inputCorp = args[++i]; continue; }
                if (args[i].equals("-o")) { output = args[++i];    continue; }
                if (args[i].equals("-vocab")) {vocabFile = args[++i]; continue; }
                if (args[i].equals("-u")) { catCutoff = Integer.parseInt(args[++i]); continue; }
                System.out.println("Unrecognized option: " + args[i]);
            }
           
            try {
                in = new SRILMFactoredBundleCorpusIterator(
                        (inputCorp.equals("<stdin>")) ? new BufferedReader(new InputStreamReader(System.in)) : new BufferedReader(new FileReader(new File(inputCorp))));
            } catch (FileNotFoundException ex) {
                System.err.print("Input corpus " + inputCorp + " not found.  Exiting...");
                Logger.getLogger(STPriorModel.class.getName()).log(Level.SEVERE, null, ex);
                System.exit(
                        -1);
            }

            try {
                out = (output.equals("<stdout>")) ? new BufferedWriter(new OutputStreamWriter(System.out)) : new BufferedWriter(new FileWriter(new File(output)));
            } catch (IOException ex) {
                System.err.print("Output file " + output + " not found.  Exiting...");
                Logger.getLogger(STPriorModel.class.getName()).log(Level.SEVERE, null, ex);
                System.exit(
                        -1);
            }

            try {
                voc = new BufferedWriter(new FileWriter(new File(vocabFile)));
            } catch (IOException ex) {
                Logger.getLogger(STPriorModel.class.getName()).log(Level.SEVERE, null, ex);
            }

            Map<String, Integer> vocab = new HashMap<String, Integer>();
            for (List<Word> inLine : in) {
                for (Word w : inLine) {
                    String st = SUPERTAG + "-" + DefaultTokenizer.escape(w.getSupertag()),
                            pos = POS_TAG + "-" + DefaultTokenizer.escape(w.getPOS()),
                            wform = WORD + "-" + DefaultTokenizer.escape(w.getForm());

                    vocab.put(st, (vocab.get(st) == null) ? 1 : vocab.get(st) + 1);
                    vocab.put(pos, (vocab.get(pos) == null) ? 1 : vocab.get(pos) + 1);
                    vocab.put(wform, (vocab.get(wform) == null) ? 1 : vocab.get(wform) + 1);
                }

            }

            // reopen file
            try {
                in = new SRILMFactoredBundleCorpusIterator(
                        (inputCorp.equals("<stdin>")) ? new BufferedReader(new InputStreamReader(System.in)) : new BufferedReader(new FileReader(new File(inputCorp))));
            } catch (FileNotFoundException ex) {
                System.err.print("Input corpus " + inputCorp + " not found.  Exiting...");
                Logger.getLogger(STPriorModel.class.getName()).log(Level.SEVERE, null, ex);
                System.exit(
                        -1);
            }
            for (List<Word> inLine : in) {
                for (Word w : inLine) {
                    String st = SUPERTAG + "-" + DefaultTokenizer.escape(w.getSupertag()),
                            pos = POS_TAG + "-" + DefaultTokenizer.escape(w.getPOS()),
                            wform = WORD + "-" + DefaultTokenizer.escape(w.getForm());
                    if (vocab.get(st) > catCutoff) {
                        out.write(wform + ":" + pos + ":" + st + " ");
                    }
                }

                out.write(System.getProperty("line.separator"));
            }

            out.flush();

            for (String str : vocab.keySet()) {
                if (vocab.get(str) > catCutoff) {
                    voc.write(str + System.getProperty("line.separator"));
                }
            }

            voc.flush();
        } finally {
            try {
                out.close();
                in.close();
                voc.close();
            } catch (IOException ex) {
                Logger.getLogger(STPriorModel.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }
}
TOP

Related Classes of opennlp.ccg.parse.supertagger.ml.STPriorModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.