Package cc.mallet.fst.semi_supervised.pr.constraints

Source Code of cc.mallet.fst.semi_supervised.pr.constraints.OneLabelL2PRConstraints$OneLabelPRConstraint

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.semi_supervised.pr.constraints;

import gnu.trove.TIntArrayList;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;

import java.util.BitSet;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

/**
* A set of constraints on distributions over single
* labels conditioned on the presence of input features.
*
* This is to be used with PR, and penalizes
* L_2^2 difference from target expectations.
*
* Multiple constraints are grouped together here
* to make things more efficient.
*
* @author Gregory Druck
*/

public class OneLabelL2PRConstraints implements PRConstraint {

  // maps between input feature indices and constraints
  protected TIntObjectHashMap<OneLabelPRConstraint> constraints;
  // maps between input feature indices and constraint indices
  protected TIntIntHashMap constraintIndices;
  protected StateLabelMap map;
  protected boolean normalized;
 
  // cache of set of constrained features that fire at last FeatureVector
  // provided in preprocess call
  protected TIntArrayList cache;

  public OneLabelL2PRConstraints(boolean normalized) {
    this.constraints = new TIntObjectHashMap<OneLabelPRConstraint>();
    this.constraintIndices = new TIntIntHashMap();
    this.cache = new TIntArrayList();
    this.normalized = normalized;
  }
 
  protected OneLabelL2PRConstraints(TIntObjectHashMap<OneLabelPRConstraint> constraints,
      TIntIntHashMap constraintIndices, StateLabelMap map, boolean normalized) {
    this.constraints = new TIntObjectHashMap<OneLabelPRConstraint>();
    for (int key : constraints.keys()) {
      this.constraints.put(key, constraints.get(key).copy());
    }
   
    //this.constraints = constraints;
    this.constraintIndices = constraintIndices;
    this.map = map;
    this.cache = new TIntArrayList();
    this.normalized = normalized;
  }
 
  public PRConstraint copy() {
    return new OneLabelL2PRConstraints(this.constraints, this.constraintIndices, this.map, this.normalized);
  }

  public void addConstraint(int fi, double[] target, double weight) {
    constraints.put(fi,new OneLabelPRConstraint(target,weight));
    constraintIndices.put(fi, constraintIndices.size());
  }
 
  public int numDimensions() {
    assert(map != null);
    return map.getNumLabels() * constraints.size();
  }
 
  public boolean isOneStateConstraint() {
    return true;
  }
 
  public void setStateLabelMap(StateLabelMap map) {
    this.map = map;
  }
 
  public void preProcess(FeatureVector fv) {
    cache.resetQuick();
    int fi;
    // cache constrained input features
    for (int loc = 0; loc < fv.numLocations(); loc++) {
      fi = fv.indexAtLocation(loc);
      if (constraints.containsKey(fi)) {
        cache.add(fi);
      }
    }
  }
 
  // find examples that contain constrained input features
  public BitSet preProcess(InstanceList data) {
    // count
    int ii = 0;
    int fi;
    FeatureVector fv;
    BitSet bitSet = new BitSet(data.size());
    for (Instance instance : data) {
      FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData();
      for (int ip = 0; ip < fvs.size(); ip++) {
        fv = fvs.get(ip);
        for (int loc = 0; loc < fv.numLocations(); loc++) {
          fi = fv.indexAtLocation(loc);
          if (constraints.containsKey(fi)) {
            constraints.get(fi).count += 1;
            bitSet.set(ii);
          }
        }
      }
      ii++;
    }
    return bitSet;
  }   
 
  public double getScore(FeatureVector input, int inputPosition,
      int srcIndex, int destIndex, double[] parameters) {
    double dot = 0;
    int li2 = map.getLabelIndex(destIndex);
    for (int i = 0; i < cache.size(); i++) {
      int j = constraintIndices.get(cache.getQuick(i));
      // TODO binary features
      if (normalized) {
        dot += parameters[j + constraints.size() * li2] / constraints.get(cache.getQuick(i)).count;
      }
      else {
        dot += parameters[j + constraints.size() * li2];
      }
    }
    return dot;
  }

  public void incrementExpectations(FeatureVector input, int inputPosition,
      int srcIndex, int destIndex, double prob) {
    int li2 = map.getLabelIndex(destIndex);
    for (int i = 0; i < cache.size(); i++) {
      constraints.get(cache.getQuick(i)).expectation[li2] += prob;
    }
  }
 
  public void getExpectations(double[] expectations) {
    assert(expectations.length == numDimensions());
    for (int fi : constraintIndices.keys()) {
      int ci = constraintIndices.get(fi);
      OneLabelPRConstraint constraint = constraints.get(fi);
      for (int li = 0; li < constraint.expectation.length; li++) {
        expectations[ci + li * constraints.size()] = constraint.expectation[li];
      }
    }
  }
 
  public void addExpectations(double[] expectations) {
    assert(expectations.length == numDimensions());
    for (int fi : constraintIndices.keys()) {
      int ci = constraintIndices.get(fi);
      OneLabelPRConstraint constraint = constraints.get(fi);
      for (int li = 0; li < constraint.expectation.length; li++) {
        constraint.expectation[li] += expectations[ci + li * constraints.size()];
      }
    }
  }

  public void zeroExpectations() {
    for (int fi : constraints.keys()) {
      constraints.get(fi).expectation = new double[map.getNumLabels()];
    }
  }

  public double getAuxiliaryValueContribution(double[] parameters) {
    double value = 0;
    for (int fi : constraints.keys()) {
      int ci = constraintIndices.get(fi);
      for (int li = 0; li < map.getNumLabels(); li++) {
        double param = parameters[ci + li * constraints.size()];
        value += constraints.get(fi).target[li] * param - (param * param) / (2 * constraints.get(fi).weight);
      }
    }
    return value;
  }

  // TODO
  public double getCompleteValueContribution(double[] parameters) {
    double value = 0;
    for (int fi : constraints.keys()) {
      OneLabelPRConstraint constraint = constraints.get(fi);
      for (int li = 0; li < map.getNumLabels(); li++) {
        if (normalized) {
          value +=  constraint.weight * Math.pow(constraint.target[li] - constraint.expectation[li]/constraint.count,2) / 2;
        }
        else {
          value +=  constraint.weight * Math.pow(constraint.target[li] - constraint.expectation[li],2) / 2;
        }
      }
    }
    return value;
  }

  public void getGradient(double[] parameters, double[] gradient) {
    for (int fi : constraints.keys()) {
      int ci = constraintIndices.get(fi);
      OneLabelPRConstraint constraint = constraints.get(fi);
      for (int li = 0; li < map.getNumLabels(); li++) {
        if (normalized) {
          gradient[ci + li * constraints.size()] =
            constraint.target[li] - constraint.expectation[li] / constraint.count -
            parameters[ci + li * constraints.size()] / constraint.weight;
        }
        else {
          gradient[ci + li * constraints.size()] =
            constraint.target[li] - constraint.expectation[li] -
            parameters[ci + li * constraints.size()] / constraint.weight;
        }
      }
    }
  }
 
  protected class OneLabelPRConstraint {
   
    protected double[] target;
    protected double[] expectation;
    protected double count;
    protected double weight;
   
    public OneLabelPRConstraint(double[] target, double weight) {
      this.target = target;
      this.weight = weight;
      this.expectation = null;
      this.count = 0;
    }
   
    public OneLabelPRConstraint copy() {
      OneLabelPRConstraint copy = new OneLabelPRConstraint(target,weight);
      copy.count = count;
      copy.expectation = new double[target.length];
      return copy;
    }
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.pr.constraints.OneLabelL2PRConstraints$OneLabelPRConstraint

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.