Package cc.mallet.fst.semi_supervised

Source Code of cc.mallet.fst.semi_supervised.GELattice$LatticeNode

/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised;

import java.util.ArrayList;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.LogNumber;

/**
* Runs the dynamic programming algorithm of [Mann and McCallum 08] for
* computing the gradient of a Generalized Expectation constraint that
* considers a single label of a linear chain CRF.
*
* See:
* "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
* Gideon Mann and Andrew McCallum
* ACL 2008
*
* gdruck NOTE: This new version of GE Lattice that computes the gradient
* for all constraints simultaneously!
*
* @author Gregory Druck
* @author Gaurav Chandalia
* @author Gideon Mann
*/
public class GELattice {
  // input length + 1
  protected int latticeLength;

  // the model
  protected Transducer transducer;
  // number of states in the FST
  protected int numStates;

  // dynamic programming lattice
  protected LatticeNode[][] lattice;
  // cache of dot produce between violation and
  // constraint features
  protected LogNumber[][][] dotCache;
 
  /**
   * @param fvs Input FeatureVectorSequence
   * @param gammas Marginals over single states
   * @param xis Marginals over pairs of states
   * @param transducer Transducer
   * @param reverseTrans Source state indices for each destination state
   * @param reverseTransIndices Transition indices for each destination state
   * @param gradient Gradient to increment
   * @param constraints List of constraints
   * @param check Whether to run the debugging test to verify correctness (will be much slower if true)
   */
  public GELattice(
      FeatureVectorSequence fvs, double[][] gammas, double[][][] xis,
      Transducer transducer, int[][] reverseTrans, int[][] reverseTransIndices, CRF.Factors gradient,
      ArrayList<GEConstraint> constraints, boolean check) {
    assert(gradient != null);

    latticeLength = fvs.size() + 1;
    this.transducer = transducer;
    numStates = transducer.numStates();

    // lattice
    lattice = new LatticeNode[latticeLength][numStates];
    for (int ip = 0; ip < latticeLength; ++ip) {
      for (int a = 0; a < numStates; ++a) {
        lattice[ip][a] = new LatticeNode();
      }
    }
   
    dotCache = new LogNumber[latticeLength][numStates][numStates];
   
    // TODO maybe this should be cached?
    // Separate lists for constraints that look at one vs two states.
    ArrayList<GEConstraint> constraints1 = new ArrayList<GEConstraint>();
    ArrayList<GEConstraint> constraints2 = new ArrayList<GEConstraint>();
   
    for (GEConstraint constraint : constraints) {
      if (constraint.isOneStateConstraint()) {
        constraints1.add(constraint);
      }
      else {
        constraints2.add(constraint);
      }
    }
   
    CRF crf = (CRF)transducer;
   
    double dotEx = this.runForward(crf, constraints1, constraints2, gammas, xis, reverseTrans, fvs);
    this.runBackward(crf, gammas, xis, reverseTrans, reverseTransIndices, fvs, dotEx, gradient);
    //check(constraints,gammas,xis,fvs);
  }
 
  /**
   * Run forward pass of dynamic programming algorithm
   *
   * @param crf CRF
   * @param constraints1 Constraints that consider one state.
   * @param constraints2 Constraints that consider two states.
   * @param gammas Marginals over single states
   * @param xis Marginals over pairs of states
   * @param reverseTrans Source state indices for each destination state
   * @param fvs Input FeatureVectorSequence
   * @return
   */
  private double runForward(CRF crf, ArrayList<GEConstraint> constraints1, ArrayList<GEConstraint> constraints2, double[][] gammas,
      double[][][] xis, int[][] reverseTrans, FeatureVectorSequence fvs) {
    double dotEx = 0;
 
    LogNumber[] oneStateValueCache = new LogNumber[numStates];
    LogNumber nuAlpha = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
   
    for (int ip = 0; ip < latticeLength-1; ++ip) {
      FeatureVector fv = fvs.get(ip);
      // speed things up by giving the constraints an
      // opportunity to cache, for example, which
      // constrained input features appear in this
      // FeatureVector
      for (GEConstraint constraint : constraints1) {
        constraint.preProcess(fv);
      }
      for (GEConstraint constraint : constraints2) {
        constraint.preProcess(fv);
      }
     
      boolean[] oneStateValComputed = new boolean[numStates];
      for (int prev = 0; prev < numStates; prev++) {
        nuAlpha.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        if (ip != 0) {
          int[] prevPrevs = reverseTrans[prev];
          // calculate only once: \sum_y_{i-1} w_a(y_{i-1},y_i)
          for (int ppi = 0; ppi < prevPrevs.length; ppi++) {
            nuAlpha.plusEquals(lattice[ip-1][prevPrevs[ppi]].alpha[prev]);
          }
        }

        assert (!Double.isNaN(nuAlpha.logVal));

        CRF.State prevState = (CRF.State)crf.getState(prev);
        LatticeNode node = lattice[ip][prev];
        double[] xi = xis[ip][prev];
        double gamma = gammas[ip][prev];

        for (int ci = 0; ci < prevState.numDestinations(); ci++) {
          int curr = prevState.getDestinationState(ci).getIndex();
          double dot = 0;
          for (GEConstraint constraint : constraints2) {
            dot += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
          }

          // avoid recomputing one-state constraint features #labels times
          if (!oneStateValComputed[curr]) {
            double osVal = 0;
            for (GEConstraint constraint : constraints1) {
              osVal += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
            }
            if (osVal < 0) {
              dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
              oneStateValueCache[curr] = new LogNumber(Math.log(-osVal),false);
            }
            else if (osVal > 0) {
              dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
              oneStateValueCache[curr] = new LogNumber(Math.log(osVal),true);
            }
            else {
              oneStateValueCache[curr] = null;
            }
            oneStateValComputed[curr] = true;
          }
         
          // combine the one and two state constraint feature values
          if (dot == 0 && oneStateValueCache[curr] == null) {
            dotCache[ip][prev][curr] = null;
          }
          else if (dot == 0 && oneStateValueCache[curr] != null) {
            dotCache[ip][prev][curr] = oneStateValueCache[curr];
          }
          else {
            dotEx += Math.exp(xi[curr]) * dot;
            if (dot < 0) {
              dotCache[ip][prev][curr] = new LogNumber(Math.log(-dot),false);
            }
            else {
              dotCache[ip][prev][curr] = new LogNumber(Math.log(dot),true);
            }
            if (oneStateValueCache[curr] != null) {
              dotCache[ip][prev][curr].plusEquals(oneStateValueCache[curr]);
            }
          }
         
          // update the dynamic programming table
          if (dotCache[ip][prev][curr] != null) {
            temp.set(xi[curr],true);
            temp.timesEquals(dotCache[ip][prev][curr]);
            node.alpha[curr].plusEquals(temp);
          }
          if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
            node.alpha[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
          } else {
            temp.set(xi[curr] - gamma,true);
            temp.timesEquals(nuAlpha);
            node.alpha[curr].plusEquals(temp);
          }
          assert (!Double.isNaN(node.alpha[curr].logVal)) : "xi: " + xi[curr] + ", gamma: "
              + gamma + ", constraint feature: " + dotCache[ip][prev][curr]
              + ", nuApha: " + nuAlpha + " dot: " + dot;
        }
      }
    }
    return dotEx;
  }

  /**
   * Run backward pass of dynamic programming algorithm
   *
   * @param crf CRF
   * @param gammas Marginals over single states
   * @param xis Marginals over pairs of states
   * @param reverseTrans Source state indices for each destination state
   * @param reverseTransIndices Transition indices for each destination state
   * @param fvs Input FeatureVectorSequence
   * @param dotEx Expectation of constraint features dot violation terms
   * @param gradient Gradient to increment
   * @return
   */
  private void runBackward(CRF crf, double[][] gammas, double[][][] xis, int[][] reverseTrans, int[][] reverseTransIndices,
      FeatureVectorSequence fvs, double dotEx, CRF.Factors gradient) {
   
    LogNumber nuBeta = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber dot = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp2 = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber nextDot;
   
    for (int ip = latticeLength-2; ip >= 0; --ip) {
      for (int curr = 0; curr < numStates; ++curr) {

        nuBeta.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        dot.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        // calculate only once: \sum_y_{i+1} w_b(y_i,y+i)
       
       
        CRF.State currState = (CRF.State)crf.getState(curr);
        for (int ni = 0; ni < currState.numDestinations(); ni++){
          int next= currState.getDestinationState(ni).getIndex();
          nuBeta.plusEquals(lattice[ip+1][curr].beta[next]);
          assert(!Double.isNaN(nuBeta.logVal));

          nextDot = dotCache[ip+1][curr][next];
          if (nextDot != null) {
            double xi = xis[ip+1][curr][next];
            temp.set(xi,true);
            temp.timesEquals(nextDot);
            dot.plusEquals(temp);
          }
        }

        double gamma = gammas[ip+1][curr];

        int[] prevStates = reverseTrans[curr];
        for (int pi = 0; pi < prevStates.length; pi++) {
          int prev = prevStates[pi];
         
          CRF.State crfState = (CRF.State)crf.getState(prev);

          LatticeNode node = lattice[ip][prev];
          double xi = xis[ip][prev][curr];

          if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
            node.beta[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
          } else {
            // constraint feature values cached in Forward pass
            temp.set(dot.logVal,dot.sign);
            temp.plusEquals(nuBeta);
            temp2.set(xi-gamma,true);
            temp.timesEquals(temp2);
            node.beta[curr].plusEquals(temp);
          }
          assert(!Double.isNaN(node.beta[curr].logVal))
          : "xi: " + xi + ", gamma: " + gamma + ", xi: " + xi +
          ", log(indicatorFeat): " + dotCache[ip][curr];

          // compute and update gradient!
          double transProb = Math.exp(xi);
          double covFirstTerm = node.alpha[curr].exp() + node.beta[curr].exp();
          double contribution = (covFirstTerm - (transProb * dotEx));

          int nwi = crfState.getWeightNames(reverseTransIndices[curr][pi]).length;
          int weightsIndex;
          for (int wi = 0; wi < nwi; wi++) {
            weightsIndex = ((CRF)transducer).getWeightsIndex(crfState.getWeightNames(reverseTransIndices[curr][pi])[wi]);
            gradient.weights[weightsIndex].plusEqualsSparse (fvs.get(ip), contribution);
            gradient.defaultWeights[weightsIndex] += contribution;
          }
        }
      }
    }
  }
 
 
  /**
   * Verifies the correctness of the lattice computations.
   */
  public void check(ArrayList<GEConstraint> constraints, double[][] gammas, double[][][] xis, FeatureVectorSequence fvs) {
    // sum of marginal probabilities
    double ex1 = 0.0;
    for (int ip = 0; ip < latticeLength-1; ++ip) {
      for (int si1 = 0; si1 < numStates; si1++) {
        for (int si2 = 0; si2 < numStates; si2++) {
          double dot = 0;
          for (GEConstraint constraint : constraints) {
            dot += constraint.getCompositeConstraintFeatureValue(fvs.get(ip), ip, si1, si2);
          }
          double prob = Math.exp(xis[ip][si1][si2]);
          ex1 += prob * dot;
        }
      }
    }

    double ex2 = 0.0;
    for (int ip = 0; ip < latticeLength-1; ++ip) {
      double ex3 = 0.0;
      for (int s1 = 0; s1 < numStates; ++s1) {
        LatticeNode node = lattice[ip][s1];
        for (int s2 = 0; s2 < numStates; ++s2) {
          ex3 += node.alpha[s2].exp() + node.beta[s2].exp();
        }
      }
      // should be equal to marginal prob.
      assert(ex1 - ex3 < 1e-6) :ex1 + " " + ex3;
      ex2 += ex3;
    }
    ex2 = ex2 / (latticeLength - 1);
    // should be equal to marginal prob.
    assert(ex1 - ex2 < 1e-6) : ex1 + " " + ex2;
  }
 
  public LogNumber getAlpha(int ip, int s1, int s2) {
    return lattice[ip][s1].alpha[s2];
  }
 
  public LogNumber getBeta(int ip, int s1, int s2) {
    return lattice[ip][s1].beta[s2];
  }
 
  /**
   * Contains forward-backward vectors correspoding to an input position and a
   * state index.
   */
  protected class LatticeNode {
    // ip -> input position, a vector of doubles since for each node we need to
    // keep track of the alpha, beta values of state@(ip+1)
    protected LogNumber[] alpha;
    protected LogNumber[] beta;

    public LatticeNode() {
      alpha = new LogNumber[numStates];
      beta = new LogNumber[numStates];
      for (int si = 0; si < numStates; ++si) {
        alpha[si] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
        beta[si] new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
      }
    }
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.GELattice$LatticeNode

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.