// Package edu.stanford.nlp.ie.crf
//
// Source code of edu.stanford.nlp.ie.crf.CRFClassifierNonlinear

// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: java-nlp-user@lists.stanford.edu
//    Licensing: java-nlp-support@lists.stanford.edu

package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.util.*;

import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;

/**
* Subclass of {@link edu.stanford.nlp.ie.crf.CRFClassifier} for implementing the nonlinear architecture in [Wang and Manning IJCNLP-2013 Effect of Nonlinear ...].
*
* @author Mengqiu Wang
*/
public class CRFClassifierNonlinear<IN extends CoreMap> extends CRFClassifier<IN> {

  /** Parameter weights of the classifier. */
  double[][] linearWeights;
  double[][] inputLayerWeights4Edge;
  double[][] outputLayerWeights4Edge;
  double[][] inputLayerWeights;
  double[][] outputLayerWeights;

  protected CRFClassifierNonlinear() {
    super(new SeqClassifierFlags());
  }

  public CRFClassifierNonlinear(Properties props) {
    super(props);
  }

  public CRFClassifierNonlinear(SeqClassifierFlags flags) {
    super(flags);
  }

  @Override
  public Triple<int[][][], int[], double[][][]> documentToDataAndLabels(List<IN> document) {
    Triple<int[][][], int[], double[][][]> result = super.documentToDataAndLabels(document);
    int[][][] data = result.first();
    data = transformDocData(data);

    return new Triple<int[][][], int[], double[][][]>(data, result.second(), result.third());
  }

  private int[][][] transformDocData(int[][][] docData) {
    int[][][] transData = new int[docData.length][][];
    for (int i = 0; i < docData.length; i++) {
      transData[i] = new int[docData[i].length][];
      for (int j = 0; j < docData[i].length; j++) {
        int[] cliqueFeatures = docData[i][j];
        transData[i][j] = new int[cliqueFeatures.length];
        for (int n = 0; n < cliqueFeatures.length; n++) {
          int transFeatureIndex = -1;
          if (j == 0) {
            transFeatureIndex = nodeFeatureIndicesMap.indexOf(cliqueFeatures[n]);
            if (transFeatureIndex == -1)
              throw new RuntimeException("node cliqueFeatures[n]="+cliqueFeatures[n]+" not found, nodeFeatureIndicesMap.size="+nodeFeatureIndicesMap.size());
          } else {
            transFeatureIndex = edgeFeatureIndicesMap.indexOf(cliqueFeatures[n]);
            if (transFeatureIndex == -1)
              throw new RuntimeException("edge cliqueFeatures[n]="+cliqueFeatures[n]+" not found, edgeFeatureIndicesMap.size="+edgeFeatureIndicesMap.size());
          }
          transData[i][j][n] = transFeatureIndex;
        }
      }
    }
    return transData;
  }

  @Override
  protected CliquePotentialFunction getCliquePotentialFunctionForTest() {
    if (cliquePotentialFunction == null) {
      if (flags.secondOrderNonLinear)
        cliquePotentialFunction = new NonLinearSecondOrderCliquePotentialFunction(inputLayerWeights4Edge, outputLayerWeights4Edge, inputLayerWeights, outputLayerWeights, flags);
      else
        cliquePotentialFunction = new NonLinearCliquePotentialFunction(linearWeights, inputLayerWeights, outputLayerWeights, flags);
    }
    return cliquePotentialFunction;
  }

  @Override
  protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) {
    if (flags.secondOrderNonLinear) {
      CRFNonLinearSecondOrderLogConditionalObjectiveFunction func = new CRFNonLinearSecondOrderLogConditionalObjectiveFunction(data, labels,
        windowSize, classIndex, labelIndices, map, flags, nodeFeatureIndicesMap.size(), edgeFeatureIndicesMap.size());
      cliquePotentialFunctionHelper = func;
      double[] allWeights = trainWeightsUsingNonLinearCRF(func, evaluators);
      Quadruple<double[][], double[][], double[][], double[][]> params = func.separateWeights(allWeights);
      this.inputLayerWeights4Edge = params.first();
      this.outputLayerWeights4Edge = params.second();
      this.inputLayerWeights = params.third();
      this.outputLayerWeights = params.fourth();

    } else {
      CRFNonLinearLogConditionalObjectiveFunction func = new CRFNonLinearLogConditionalObjectiveFunction(data, labels,
        windowSize, classIndex, labelIndices, map, flags, nodeFeatureIndicesMap.size(), edgeFeatureIndicesMap.size(), featureVals);
      if (flags.useAdaGradFOBOS) {
        func.gradientsOnly = true;
      }
      cliquePotentialFunctionHelper = func;
      double[] allWeights = trainWeightsUsingNonLinearCRF(func, evaluators);
      Triple<double[][], double[][], double[][]> params = func.separateWeights(allWeights);
      this.linearWeights = params.first();
      this.inputLayerWeights = params.second();
      this.outputLayerWeights = params.third();
    }

    return null;
  }

  private double[] trainWeightsUsingNonLinearCRF(AbstractCachingDiffFunction func, Evaluator[] evaluators) {
    Minimizer minimizer = getMinimizer(0, evaluators);

    double[] initialWeights;
    if (flags.initialWeights == null) {
      initialWeights = func.initial();
    } else {
      try {
        System.err.println("Reading initial weights from file " + flags.initialWeights);
        DataInputStream dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(new FileInputStream(
            flags.initialWeights))));
        initialWeights = ConvertByteArray.readDoubleArr(dis);
      } catch (IOException e) {
        throw new RuntimeException("Could not read from double initial weight file " + flags.initialWeights);
      }
    }
    System.err.println("numWeights: " + initialWeights.length);

    if (flags.testObjFunction) {
      StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(func);
      if (tester.testSumOfBatches(initialWeights, 1e-4)) {
        System.err.println("Testing complete... exiting");
        System.exit(1);
      } else {
        System.err.println("Testing failed....exiting");
        System.exit(1);
      }

    }
    //check gradient
    if (flags.checkGradient) {
      if (func.gradientCheck()) {
        System.err.println("gradient check passed");
      } else {
        throw new RuntimeException("gradient check failed");
      }
    }
    return minimizer.minimize(func, flags.tolerance, initialWeights);
  }

  @Override
  protected void serializeTextClassifier(PrintWriter pw) throws Exception {
    super.serializeTextClassifier(pw);

    pw.printf("nodeFeatureIndicesMap.size()=\t%d%n", nodeFeatureIndicesMap.size());
    for (int i = 0; i < nodeFeatureIndicesMap.size(); i++) {
      pw.printf("%d\t%d%n", i, nodeFeatureIndicesMap.get(i));
    }

    pw.printf("edgeFeatureIndicesMap.size()=\t%d%n", edgeFeatureIndicesMap.size());
    for (int i = 0; i < edgeFeatureIndicesMap.size(); i++) {
      pw.printf("%d\t%d%n", i, edgeFeatureIndicesMap.get(i));
    }

    if (flags.secondOrderNonLinear) {
      pw.printf("inputLayerWeights4Edge.length=\t%d%n", inputLayerWeights4Edge.length);
      for (double[] ws : inputLayerWeights4Edge) {
        ArrayList<Double> list = new ArrayList<Double>();
        for (double w : ws) {
          list.add(w);
        }
        pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
      }
      pw.printf("outputLayerWeights4Edge.length=\t%d%n", outputLayerWeights4Edge.length);
      for (double[] ws : outputLayerWeights4Edge) {
        ArrayList<Double> list = new ArrayList<Double>();
        for (double w : ws) {
          list.add(w);
        }
        pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
      }
    } else {
      pw.printf("linearWeights.length=\t%d%n", linearWeights.length);
      for (double[] ws : linearWeights) {
        ArrayList<Double> list = new ArrayList<Double>();
        for (double w : ws) {
          list.add(w);
        }
        pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
      }
    }
    pw.printf("inputLayerWeights.length=\t%d%n", inputLayerWeights.length);
    for (double[] ws : inputLayerWeights) {
      ArrayList<Double> list = new ArrayList<Double>();
      for (double w : ws) {
        list.add(w);
      }
      pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
    }
    pw.printf("outputLayerWeights.length=\t%d%n", outputLayerWeights.length);
    for (double[] ws : outputLayerWeights) {
      ArrayList<Double> list = new ArrayList<Double>();
      for (double w : ws) {
        list.add(w);
      }
      pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
    }
  }

  @Override
  protected void loadTextClassifier(BufferedReader br) throws Exception {
    super.loadTextClassifier(br);

    String line = br.readLine();
    String[] toks = line.split("\\t");
    if (!toks[0].equals("nodeFeatureIndicesMap.size()=")) {
      throw new RuntimeException("format error in nodeFeatureIndicesMap");
    }
    int nodeFeatureIndicesMapSize = Integer.parseInt(toks[1]);
    nodeFeatureIndicesMap = new HashIndex<Integer>();
    int count = 0;
    while (count < nodeFeatureIndicesMapSize) {
      line = br.readLine();
      toks = line.split("\\t");
      int idx = Integer.parseInt(toks[0]);
      if (count != idx) {
        throw new RuntimeException("format error");
      }
      nodeFeatureIndicesMap.add(Integer.parseInt(toks[1]));
      count++;
    }

    line = br.readLine();
    toks = line.split("\\t");
    if (!toks[0].equals("edgeFeatureIndicesMap.size()=")) {
      throw new RuntimeException("format error");
    }
    int edgeFeatureIndicesMapSize = Integer.parseInt(toks[1]);
    edgeFeatureIndicesMap = new HashIndex<Integer>();
    count = 0;
    while (count < edgeFeatureIndicesMapSize) {
      line = br.readLine();
      toks = line.split("\\t");
      int idx = Integer.parseInt(toks[0]);
      if (count != idx) {
        throw new RuntimeException("format error");
      }
      edgeFeatureIndicesMap.add(Integer.parseInt(toks[1]));
      count++;
    }

    int  weightsLength = -1;
    if (flags.secondOrderNonLinear) {
      line = br.readLine();
      toks = line.split("\\t");
      if (!toks[0].equals("inputLayerWeights4Edge.length=")) {
        throw new RuntimeException("format error");
      }
      weightsLength = Integer.parseInt(toks[1]);
      inputLayerWeights4Edge = new double[weightsLength][];
      count = 0;
      while (count < weightsLength) {
        line = br.readLine();

        toks = line.split("\\t");
        int weights2Length = Integer.parseInt(toks[0]);
        inputLayerWeights4Edge[count] = new double[weights2Length];
        String[] weightsValue = toks[1].split(" ");
        if (weights2Length != weightsValue.length) {
          throw new RuntimeException("weights format error");
        }

        for (int i2 = 0; i2 < weights2Length; i2++) {
          inputLayerWeights4Edge[count][i2] = Double.parseDouble(weightsValue[i2]);
        }
        count++;
      }
      line = br.readLine();

      toks = line.split("\\t");
      if (!toks[0].equals("outputLayerWeights4Edge.length=")) {
        throw new RuntimeException("format error");
      }
      weightsLength = Integer.parseInt(toks[1]);
      outputLayerWeights4Edge = new double[weightsLength][];
      count = 0;
      while (count < weightsLength) {
        line = br.readLine();

        toks = line.split("\\t");
        int weights2Length = Integer.parseInt(toks[0]);
        outputLayerWeights4Edge[count] = new double[weights2Length];
        String[] weightsValue = toks[1].split(" ");
        if (weights2Length != weightsValue.length) {
          throw new RuntimeException("weights format error");
        }

        for (int i2 = 0; i2 < weights2Length; i2++) {
          outputLayerWeights4Edge[count][i2] = Double.parseDouble(weightsValue[i2]);
        }
        count++;
      }
    } else {
      line = br.readLine();

      toks = line.split("\\t");
      if (!toks[0].equals("linearWeights.length=")) {
        throw new RuntimeException("format error");
      }
      weightsLength = Integer.parseInt(toks[1]);
      linearWeights = new double[weightsLength][];
      count = 0;
      while (count < weightsLength) {
        line = br.readLine();

        toks = line.split("\\t");
        int weights2Length = Integer.parseInt(toks[0]);
        linearWeights[count] = new double[weights2Length];
        String[] weightsValue = toks[1].split(" ");
        if (weights2Length != weightsValue.length) {
          throw new RuntimeException("weights format error");
        }

        for (int i2 = 0; i2 < weights2Length; i2++) {
          linearWeights[count][i2] = Double.parseDouble(weightsValue[i2]);
        }
        count++;
      }
    }

    line = br.readLine();

    toks = line.split("\\t");
    if (!toks[0].equals("inputLayerWeights.length=")) {
      throw new RuntimeException("format error");
    }
    weightsLength = Integer.parseInt(toks[1]);
    inputLayerWeights = new double[weightsLength][];
    count = 0;
    while (count < weightsLength) {
      line = br.readLine();

      toks = line.split("\\t");
      int weights2Length = Integer.parseInt(toks[0]);
      inputLayerWeights[count] = new double[weights2Length];
      String[] weightsValue = toks[1].split(" ");
      if (weights2Length != weightsValue.length) {
        throw new RuntimeException("weights format error");
      }

      for (int i2 = 0; i2 < weights2Length; i2++) {
        inputLayerWeights[count][i2] = Double.parseDouble(weightsValue[i2]);
      }
      count++;
    }
    line = br.readLine();

    toks = line.split("\\t");
    if (!toks[0].equals("outputLayerWeights.length=")) {
      throw new RuntimeException("format error");
    }
    weightsLength = Integer.parseInt(toks[1]);
    outputLayerWeights = new double[weightsLength][];
    count = 0;
    while (count < weightsLength) {
      line = br.readLine();

      toks = line.split("\\t");
      int weights2Length = Integer.parseInt(toks[0]);
      outputLayerWeights[count] = new double[weights2Length];
      String[] weightsValue = toks[1].split(" ");
      if (weights2Length != weightsValue.length) {
        throw new RuntimeException("weights format error");
      }

      for (int i2 = 0; i2 < weights2Length; i2++) {
        outputLayerWeights[count][i2] = Double.parseDouble(weightsValue[i2]);
      }
      count++;
    }
  }

  @Override
  public void serializeClassifier(ObjectOutputStream oos) {
    try {
      super.serializeClassifier(oos);
      oos.writeObject(nodeFeatureIndicesMap);
      oos.writeObject(edgeFeatureIndicesMap);
      if (flags.secondOrderNonLinear) {
        oos.writeObject(inputLayerWeights4Edge);
        oos.writeObject(outputLayerWeights4Edge);
      } else {
        oos.writeObject(linearWeights);
      }
      oos.writeObject(inputLayerWeights);
      oos.writeObject(outputLayerWeights);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  @Override
  @SuppressWarnings( { "unchecked" })
  // can't have right types in deserialization
  public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException,
      ClassNotFoundException {

    super.loadClassifier(ois, props);

    nodeFeatureIndicesMap = (Index<Integer>) ois.readObject();
    edgeFeatureIndicesMap = (Index<Integer>) ois.readObject();
    if (flags.secondOrderNonLinear) {
      inputLayerWeights4Edge = (double[][]) ois.readObject();
      outputLayerWeights4Edge = (double[][]) ois.readObject();
    } else {
      linearWeights = (double[][]) ois.readObject();
    }
    inputLayerWeights = (double[][]) ois.readObject();
    outputLayerWeights = (double[][]) ois.readObject();
  }

} // end class CRFClassifierNonlinear
// TOP
//
// Related Classes of edu.stanford.nlp.ie.crf.CRFClassifierNonlinear
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.