Package com.clearnlp.component.srl

Source Code of com.clearnlp.component.srl.AbstractSRLabeler

/**
* Copyright (c) 2009/09-2012/08, Regents of the University of Colorado
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
*    list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
*    this list of conditions and the following disclaimer in the documentation
*    and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/**
* Copyright 2012/09-2013/04, 2013/11-Present, University of Massachusetts Amherst
* Copyright 2013/05-2013/10, IPSoft Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.clearnlp.component.srl;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;

import com.carrotsearch.hppc.IntOpenHashSet;
import com.clearnlp.classification.feature.FtrToken;
import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.instance.StringInstance;
import com.clearnlp.classification.model.StringModel;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.train.StringTrainSpace;
import com.clearnlp.classification.vector.StringFeatureVector;
import com.clearnlp.component.AbstractStatisticalComponent;
import com.clearnlp.component.evaluation.SRLEval;
import com.clearnlp.component.state.SRLState;
import com.clearnlp.constant.universal.UNConstant;
import com.clearnlp.constant.universal.UNPunct;
import com.clearnlp.dependency.DEPArc;
import com.clearnlp.dependency.DEPLibEn;
import com.clearnlp.dependency.DEPNode;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.dependency.srl.SRLLib;
import com.clearnlp.propbank.PBLib;
import com.clearnlp.propbank.frameset.AbstractFrames;
import com.clearnlp.propbank.frameset.PBRoleset;
import com.clearnlp.propbank.frameset.PBType;
import com.clearnlp.util.UTCollection;
import com.clearnlp.util.map.Prob1DMap;
import com.clearnlp.util.pair.ObjectDoublePair;
import com.clearnlp.util.pair.StringIntPair;

/**
* @since 1.0.0
* @author Jinho D. Choi ({@code jdchoi77@gmail.com})
*/
abstract public class AbstractSRLabeler extends AbstractStatisticalComponent<SRLState>
{
  // Slots of the lexica array produced by getLexica() and consumed by initLexia().
  protected final int LEXICA_PATH_UP   = 0;
  protected final int LEXICA_PATH_DOWN = 1;
  protected final int LEXICA_FRAMES   = 2;
 
  // Direction codes compared against the numeric part of a path feature in getPath();
  // any value other than PATH_UP or PATH_DOWN selects the combined path.
  protected final int PATH_ALL     = 0;
  protected final int PATH_UP       = 1;
  protected final int PATH_DOWN     = 2;
  // Direction codes compared against the numeric part of a subcategorization feature in getSubcat();
  // any value other than SUBCAT_LEFT or SUBCAT_RIGHT selects all dependents.
  protected final int SUBCAT_ALL     = 0;
  protected final int SUBCAT_LEFT     = 1;
  protected final int SUBCAT_RIGHT   = 2;
 
  /** Label used when the current node is not an argument of the current predicate. */
  protected final String LB_NO_ARG = "N";
 
  // Path counters filled by collectDown()/collectUp(); created only by the lexica-collecting constructor.
  protected Prob1DMap     m_down, m_up;  // only for collecting
  // Path sets assigned by initLexia(); s_down is consulted by labelDown().
  protected Set<String>   s_down, s_up;
  // Frameset collection used by setRoleset(); may be null, in which case no roleset is set.
  protected AbstractFrames m_frames;
 
//  ====================================== CONSTRUCTORS ======================================
 
  /**
   * Constructs a semantic role labeler for collecting lexica.
   * Creates empty path counters ({@code m_down}, {@code m_up}) and stores the given frames.
   * @param xmls the feature templates forwarded to the superclass.
   * @param frames the frameset collection kept in {@code m_frames}.
   */
  public AbstractSRLabeler(JointFtrXml[] xmls, AbstractFrames frames)
  {
    super(xmls);
    m_down   = new Prob1DMap();
    m_up     = new Prob1DMap();
    m_frames = frames;
  }
 
  /**
   * Constructs a semantic role labeler for training.
   * All arguments are forwarded unchanged to the superclass constructor.
   * @param xmls the feature templates.
   * @param spaces the training spaces; {@code getLabel} indexes them by {@code state.getDirection()}.
   * @param lexica the lexica array; see {@code initLexia} for the slots read from it.
   */
  public AbstractSRLabeler(JointFtrXml[] xmls, StringTrainSpace[] spaces, Object[] lexica)
  {
    super(xmls, spaces, lexica);
  }
 
  /**
   * Constructs a semantic role labeler for developing.
   * Forwards the arguments to the superclass together with a fresh {@link SRLEval}.
   * @param xmls the feature templates.
   * @param models the statistical models; {@code getAutoLabel} indexes them by direction.
   * @param lexica the lexica array; see {@code initLexia} for the slots read from it.
   */
  public AbstractSRLabeler(JointFtrXml[] xmls, StringModel[] models, Object[] lexica)
  {
    super(xmls, models, lexica, new SRLEval());
  }
 
  /**
   * Constructs a semantic role labeler for decoding.
   * @param in the object stream forwarded to the superclass constructor.
   */
  public AbstractSRLabeler(ObjectInputStream in)
  {
    super(in);
  }
 
  /**
   * Constructs a semantic role labeler for bootstrapping.
   * All arguments are forwarded unchanged to the superclass constructor.
   * @param xmls the feature templates.
   * @param spaces the training spaces that receive new instances.
   * @param models the models used to predict labels while bootstrapping.
   * @param lexica the lexica array; see {@code initLexia} for the slots read from it.
   */
  public AbstractSRLabeler(JointFtrXml[] xmls, StringTrainSpace[] spaces, StringModel[] models, Object[] lexica)
  {
    super(xmls, spaces, models, lexica);
  }
 
  /**
   * Unpacks the lexica array into the fields of this labeler.
   * The casts are unchecked: each slot is trusted to hold the type used here.
   * @param lexica array indexed by {@code LEXICA_PATH_DOWN}, {@code LEXICA_PATH_UP} and {@code LEXICA_FRAMES}.
   */
  @Override @SuppressWarnings("unchecked")
  protected void initLexia(Object[] lexica)
  {
    s_down   = (Set<String>)   lexica[LEXICA_PATH_DOWN];
    s_up     = (Set<String>)   lexica[LEXICA_PATH_UP];
    m_frames = (AbstractFrames)lexica[LEXICA_FRAMES];
  }
 
//  ====================================== ABSTRACT METHODS ======================================
 
  /**
   * Called by {@code getAutoLabel} while decoding, for a predicted label other than {@code LB_NO_ARG}.
   * @return a replacement label, or {@code null} to keep the predicted label.
   */
  abstract protected String getHardLabel(SRLState state, String label);
  /**
   * Called by {@code setRoleset} when frames are available.
   * @return the type used to look up the roleset of {@code pred}, or {@code null} to skip the lookup.
   */
  abstract protected PBType getPBType(DEPNode pred);
  /** Called once by {@code label} after every predicate in the state has been processed. */
  abstract protected void postLabel(SRLState state);
  /**
   * Called by {@code labelDown} when the predicate is the current lowest common ancestor
   * and the path from {@code pred} to {@code arg} is not in {@code s_down}.
   * @return a node whose dependents are labeled next, or {@code null} to stop descending.
   */
  abstract protected DEPNode getPossibleDescendent(DEPNode pred, DEPNode arg);
  /**
   * Consulted by {@code rerankPredictions} for each prediction that the other checks did not reject.
   * @return {@code true} to reject the prediction (its score is then set to -1).
   */
  abstract protected boolean rerankFromArgument(StringPrediction prediction, DEPNode arg);
 
//  ====================================== LOAD/SAVE MODELS ======================================
 
  @Override
  public void load(ObjectInputStream in)
  {
    try
    {
      loadDefault(in);
      loadLexica(in);
      in.close();
    }
    catch (Exception e) {e.printStackTrace();}
  }

  @Override
  public void save(ObjectOutputStream out)
  {
    try
    {
      saveDefault(out);
      saveLexica(out);
      out.close();
    }
    catch (Exception e) {e.printStackTrace();}
  }
 
  private void loadLexica(ObjectInputStream in) throws Exception
  {
    LOG.info("Loading lexica.\n");
   
    Object[] lexica = {in.readObject(), in.readObject(), in.readObject()};
    initLexia(lexica);
  }
 
  /**
   * Writes the lexica to the stream as three consecutive objects:
   * {@code s_down}, then {@code s_up}, then {@code m_frames}.
   * This order is the serialized layout that {@code loadLexica} reads back.
   * @param out the stream to write to.
   * @throws Exception if writing to the stream fails.
   */
  private void saveLexica(ObjectOutputStream out) throws Exception
  {
    LOG.info("Saving lexica.\n");
   
    out.writeObject(s_down);
    out.writeObject(s_up);
    out.writeObject(m_frames);
  }
 
//  ====================================== GETTERS/SETTERS ======================================
 
  @Override
  public Object[] getLexica()
  {
    Object[] lexica = new Object[3];
   
    lexica[LEXICA_PATH_DOWN] = (isLexica()) ? m_down.toSet(f_xmls[0].getPathDownCutoff()) : s_down;
    lexica[LEXICA_PATH_UP]   = (isLexica()) ? m_up  .toSet(f_xmls[0].getPathUpCutoff())   : s_up;
    lexica[LEXICA_FRAMES]    = m_frames;
   
    return lexica;
  }
 
  /**
   * Converts the collected down-path counter into a set using the given cutoff.
   * Reads {@code m_down}, which is created only by the lexica-collecting constructor.
   * @param cutoff the cutoff passed to {@code Prob1DMap.toSet}.
   */
  public Set<String> getDownSet(int cutoff)
  {
    return m_down.toSet(cutoff);
  }
 
  /**
   * Converts the collected up-path counter into a set using the given cutoff.
   * Reads {@code m_up}, which is created only by the lexica-collecting constructor.
   * @param cutoff the cutoff passed to {@code Prob1DMap.toSet}.
   */
  public Set<String> getUpSet(int cutoff)
  {
    return m_up.toSet(cutoff);
  }
 
  /** Returns the label set provided by the superclass via {@code getDefaultLabels()}. */
  @Override
  public Set<String> getLabels()
  {
    return getDefaultLabels();
  }

//  ================================ PROCESS ================================
 
  @Override
  public void process(DEPTree tree)
  {
    SRLState state = init(tree);
    processAux(state);
   
    if (isDevelop())
      e_eval.countAccuracy(state.getTree(), state.getGoldLabels());
  }
 
  /** Called by {@link AbstractSRLabeler#process(DEPTree)}. */
  protected SRLState init(DEPTree tree)
  {
    SRLState state = new SRLState(tree);
   
    if (!isDecode())
    {
      state.setGoldLabels(tree.getSHeads());
      tree.clearSHeads();
    }
    else
      tree.initSHeads();

    return state;
  }
 
  protected void processAux(SRLState state)
  {
    if (isLexica())  addLexica(state);
    else      label(state);
  }
 
  private void addLexica(SRLState state)
  {
    DEPNode pred, head;
   
    while ((pred = state.moveToNextPredicate()) != null)
    {
      for (DEPArc arc : pred.getGrandDependents())
        collectDown(pred, arc.getNode());
   
      head = pred.getHead();
      if (head != null) collectUp(pred, head.getHead());
    }
  }
 
  private void collectDown(DEPNode pred, DEPNode arg)
  {
    if (arg.isArgumentOf(pred))
    {
      for (String path : getDUPathList(pred, arg.getHead()))
        m_down.add(path);
    }
   
    for (DEPArc arc : arg.getDependents())
      collectDown(pred, arc.getNode());
  }
 
  private void collectUp(DEPNode pred, DEPNode head)
  {
    if (head == nullreturn;
   
    for (DEPArc arc : head.getDependents())
    {
      if (arc.getNode().isArgumentOf(pred))
      {
        for (String path : getDUPathList(head, pred))
          m_up.add(path);
       
        break;
      }
    } 
   
    collectUp(pred, head.getHead());
  }
 
  /**
   * Builds the dependency-label path between {@code bottom} and {@code top} by delegating to
   * {@code getPathAux} with the dependency-label field and the down-path delimiter.
   */
  private String getDUPath(DEPNode top, DEPNode bottom)
  {
    return getPathAux(top, bottom, JointFtrXml.F_DEPREL, SRLLib.DELIM_PATH_DOWN, true);
  }
 
  private List<String> getDUPathList(DEPNode top, DEPNode bottom)
  {
    List<String> paths = new ArrayList<String>();
   
    while (bottom != top)
    {
      paths.add(getDUPath(top, bottom));
      bottom = bottom.getHead();
    }
   
    return paths;
  }
 
  private void label(SRLState state)
  {
    DEPNode pred;
   
    while ((pred = state.moveToNextPredicate()) != null)
    {
      setRoleset(pred, state);

      do
      {
        labelAux(state);
      }
      while (state.moveToNextLowestCommonAncestor());// && (pred.isDependentOf(d_lca) || s_up.contains(getDUPath(d_lca, pred))));
    }
   
    postLabel(state);
  }
 
  private void setRoleset(DEPNode pred, SRLState state)
  {
    if (m_frames != null)
    {
      PBType type = getPBType(pred);
     
      if (type != null)
        state.setRoleset(m_frames.getRoleset(type, pred.lemma, pred.getFeat(DEPLibEn.FEAT_PB)));
    }
  }
 
  /** Called by {@link AbstractSRLabeler#label(DEPTree)}. */
  private void labelAux(SRLState state)
  {
    DEPNode head = state.getLowestCommonAncestor();
   
    if (!state.isSkip(head))
    {
      state.setArgument(head);
      addArgument(getLabel(state), state)
    }
   
    labelDown(head.getDependents(), state);
  }
 
  /** Called by {@link AbstractSRLabeler#labelAux(DEPNode, IntOpenHashSet)}. */
  private void labelDown(List<DEPArc> arcs, SRLState state)
  {
    DEPNode pred = state.getCurrentPredicate();
    DEPNode arg;
   
    for (DEPArc arc : arcs)
    {
      arg = arc.getNode();
     
      if (!state.isSkip(arg))
      {
        state.setArgument(arg);
        addArgument(getLabel(state), state);
       
        if (state.isLowestCommonAncestor(pred))
        {
          if (s_down.contains(getDUPath(pred, arg)))
            labelDown(arg.getDependents(), state);
          else if ((arg = getPossibleDescendent(pred, arg)) != null)
            labelDown(arg.getDependents(), state);
        }
       
//        if (state.isLowestCommonAncestor(pred) && s_down.contains(getDUPath(pred, arg)))
//          labelDown(arg.getDependents(), state);
      }
    }
  }
 
  private StringPrediction getLabel(SRLState state)
  {
    StringFeatureVector vector = getFeatureVector(f_xmls[0], state);
    int idx = state.getDirection();
    StringPrediction p = null;
   
    if (isTrain())
    {
      p = new StringPrediction(getGoldLabel(state), 1d);
      s_spaces[idx].addInstance(new StringInstance(p.label, vector));
    }
    else if (isDevelopOrDecode())
    {
      p = getAutoLabel(idx, vector, state);
    }
    else if (isBootstrap())
    {
      p = getAutoLabel(idx, vector, state);
      s_spaces[idx].addInstance(new StringInstance(getGoldLabel(state), vector));
    }

    return p;
  }
 
  /** Called by {@link AbstractSRLabeler#getGoldLabel(byte)}. */
  private String getGoldLabel(SRLState state)
  {
    for (StringIntPair head : state.getGoldLabel())
    {
      if (head.i == state.getCurrPredicateID())
        return head.s;
    }
   
    return LB_NO_ARG;
  }

  /** Called by {@link AbstractSRLabeler#getLabel(byte)}. */
  private StringPrediction getAutoLabel(int idx, StringFeatureVector vector, SRLState state)
  {
    StringPrediction p = getBestPrediction(s_models[idx], vector, state);
   
    if (isDecode() && !p.label.equals(LB_NO_ARG))
    {
      String label = getHardLabel(state, p.label);
      if (label != nullp.label = label;
    }
   
    return p;
  }

  private void addArgument(StringPrediction p, SRLState state)
  {
    DEPNode arg = state.getCurrentArgument();
    ObjectDoublePair<DEPNode> prev;
    DEPNode node;

    state.addArgumentToSkipList();
   
    if (!p.label.equals(LB_NO_ARG))
    {
      if (PBLib.isNumberedArgument(p.label))
      {
        state.addNumberedArgument(p.label);
       
        if (PBLib.isCoreNumberedArgument(p.label))
        {
          if ((prev = state.getCoreNumberedArgument(p.label)) != null)
          {
            node = (DEPNode)prev.o;
            node.removeSHeadsByLabel(p.label);
          }
         
          state.putCoreNumberedArgument(p.label, new ObjectDoublePair<DEPNode>(arg, p.score));
        }
      }
     
      String fTag = UNConstant.EMPTY;
      PBRoleset roleset = state.getRoleset();
     
      if (!p.label.contains(UNPunct.HYPHEN) && roleset != null)
      {
        String n = PBLib.getNumber(p.label);
        fTag = roleset.getFunctionTag(n);
      }
     
      arg.addSHead(state.getCurrentPredicate(), p.label, fTag);
    }
  }
 
//  ================================ FEATURE EXTRACTION ================================

  @Override
  protected String getField(FtrToken token, SRLState state)
  {
    DEPNode node = state.getNode(token);
    if (node == nullreturn null;
    Matcher m;
   
    if (token.isField(JointFtrXml.F_FORM))
    {
      return node.form;
    }
    else if (token.isField(JointFtrXml.F_LEMMA))
    {
      return node.lemma;
    }
    else if (token.isField(JointFtrXml.F_POS))
    {
      return node.pos;
    }
    else if (token.isField(JointFtrXml.F_DEPREL))
    {
      return node.getLabel();
    }
    else if (token.isField(JointFtrXml.F_DISTANCE))
    {
      return getDistance(node, state);
    }
    else if ((m = JointFtrXml.P_ARGN.matcher(token.field)).find())
    {
      return state.getNumberedArgument(Integer.parseInt(m.group(1)));
    }
    else if ((m = JointFtrXml.P_PATH.matcher(token.field)).find())
    {
      String type = m.group(1);
      int    dir  = Integer.parseInt(m.group(2));
     
      return getPath(type, dir, state);
    }
    else if ((m = JointFtrXml.P_SUBCAT.matcher(token.field)).find())
    {
      String type = m.group(1);
      int    dir  = Integer.parseInt(m.group(2));
     
      return getSubcat(node, type, dir);
    }
    else if ((m = JointFtrXml.P_FEAT.matcher(token.field)).find())
    {
      return node.getFeat(m.group(1));
    }
    else if ((m = JointFtrXml.P_BOOLEAN.matcher(token.field)).find())
    {
      DEPNode pred = state.getCurrentPredicate();
      int    field = Integer.parseInt(m.group(1));
     
      switch (field)
      {
      case 0: return (node.isDependentOf(pred)) ? token.field : null;
      case 1: return (pred.isDependentOf(node)) ? token.field : null;
      case 2: return (pred.isDependentOf(state.getLowestCommonAncestor())) ? token.field : null;
      case 3: return (state.isLowestCommonAncestor(pred)) ? token.field : null;
      case 4: return (state.isLowestCommonAncestor(node)) ? token.field : null;
      }
    }
   
    return null;
  }
 
  @Override
  protected String[] getFields(FtrToken token, SRLState state)
  {
    DEPNode node = state.getNode(token);
    if (node == nullreturn null;
   
    if (token.isField(JointFtrXml.F_DEPREL_SET))
    {
      return getDeprelSet(node.getDependents());
    }
    else if (token.isField(JointFtrXml.F_GRAND_DEPREL_SET))
    {
      return getDeprelSet(node.getGrandDependents());
    }
   
    return null;
  }
 
  private String getDistance(DEPNode node, SRLState state)
  {
    int dist = Math.abs(state.getCurrPredicateID() - node.id);
   
    if      (dist <=  5return "0";
    else if (dist <= 10return "1";
    else if (dist <= 15return "2";
    else          return "3";
  }
 
  private String getPath(String type, int dir, SRLState state)
  {
    DEPNode pred = state.getCurrentPredicate();
    DEPNode arg  = state.getCurrentArgument();
    DEPNode lca  = state.getLowestCommonAncestor();
   
    if (dir == PATH_UP)
    {
      if (lca != pred)
        return getPathAux(lca, pred, type, SRLLib.DELIM_PATH_UP, true);
    }
    else if (dir == PATH_DOWN)
    {
      if (lca != arg)
        return getPathAux(lca, arg, type, SRLLib.DELIM_PATH_DOWN, true);
    }
    else
    {
      if (pred == lca)
        return getPathAux(pred, arg, type, SRLLib.DELIM_PATH_DOWN, true);
      else if (pred.isDescendentOf(arg))
        return getPathAux(arg, pred, type, SRLLib.DELIM_PATH_UP, true);
      else
      {
        String path = getPathAux(lca, pred, type, SRLLib.DELIM_PATH_UP, true);
        path += getPathAux(lca, arg, type, SRLLib.DELIM_PATH_DOWN, false);
       
        return path;
      }     
    }
   
    return null;
  }
 
  private String getPathAux(DEPNode top, DEPNode bottom, String type, String delim, boolean includeTop)
  {
    StringBuilder build = new StringBuilder();
    DEPNode head = bottom;
    int dist = 0;
   
    do
    {
      if (type.equals(JointFtrXml.F_POS))
      {
        build.append(delim);
        build.append(head.pos);
      }
      else if (type.equals(JointFtrXml.F_DEPREL))
      {
        build.append(delim);
        build.append(head.getLabel());
      }
      else if (type.equals(JointFtrXml.F_DISTANCE))
      {
        dist++;
      }
   
      head = head.getHead();
    }
    while (head != top && head != null);
   
    if (type.equals(JointFtrXml.F_POS))
    {
      if (includeTop)
      {
        build.append(delim);
        build.append(top.pos)
      }
    }
    else if (type.equals(JointFtrXml.F_DISTANCE))
    {
      build.append(delim);
      build.append(dist);
    }
   
    return build.length() == 0 ? null : build.toString();
  }
 
  private String getSubcat(DEPNode node, String type, int dir)
  {
    List<DEPArc>  deps  = node.getDependents();
    StringBuilder build = new StringBuilder();
    int i, size = deps.size();
    DEPNode dep;
   
    if (dir == SUBCAT_LEFT)
    {
      for (i=0; i<size; i++)
      {
        dep = deps.get(i).getNode();
        if (dep.id > node.idbreak;
        getSubcatAux(build, dep, type);
      }
    }
    else if (dir == SUBCAT_RIGHT)
    {
      for (i=size-1; i>=0; i--)
      {
        dep = deps.get(i).getNode();
        if (dep.id < node.idbreak;
        getSubcatAux(build, dep, type);
      }
    }
    else
    {
      for (i=0; i<size; i++)
      {
        dep = deps.get(i).getNode();
        getSubcatAux(build, dep, type);
      }
    }
   
    return build.length() == 0 ? null : build.substring(SRLLib.DELIM_SUBCAT.length());
  }
 
  private void getSubcatAux(StringBuilder build, DEPNode node, String type)
  {
    build.append(SRLLib.DELIM_SUBCAT);
   
    if (type.equals(JointFtrXml.F_POS))
      build.append(node.pos);
    else if (type.equals(JointFtrXml.F_DEPREL))
      build.append(node.getLabel());
  }
 
//  ================================ RERANK ================================
 
  private StringPrediction getBestPrediction(StringModel model, StringFeatureVector vector, SRLState state)
  {
    List<StringPrediction> ps = model.predictAll(vector);
    rerankPredictions(ps, state);

    return ps.get(0);
  }
 
  protected void rerankPredictions(List<StringPrediction> ps, SRLState state)
  {
    DEPNode arg = state.getCurrentArgument();
    boolean change = false;
   
    for (StringPrediction p : ps)
    {
      if (rerankFrameMismatch(p, state) || rerankRedundantNumberedArgument(p, state) || rerankFromArgument(p, arg))
      {
        p.score = -1;
        change = true;
      }
    }
   
    if (changeUTCollection.sortReverseOrder(ps);
  }
 
  protected boolean rerankFrameMismatch(StringPrediction prediction, SRLState state)
  {
    PBRoleset roleset = state.getRoleset();
   
    if (roleset != null && !roleset.isValidArgument(prediction.label))
      return true;
   
    return false;
  }
 
  protected boolean rerankRedundantNumberedArgument(StringPrediction prediction, SRLState state)
  {
    ObjectDoublePair<DEPNode> prev = state.getCoreNumberedArgument(prediction.label);
   
    if (prev != null && prev.d >= prediction.score)
      return true;
   
    return false;
  }
}
TOP

Related Classes of com.clearnlp.component.srl.AbstractSRLabeler

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.