Package org.maltparserx.ml.lib

Source Code of org.maltparserx.ml.lib.Lib

package org.maltparserx.ml.lib;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;

import java.io.OutputStreamWriter;
import java.util.ArrayList;

import org.apache.log4j.Logger;

import java.util.LinkedHashMap;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;


import org.maltparserx.core.exception.MaltChainedException;
import org.maltparserx.core.feature.FeatureVector;
import org.maltparserx.core.feature.function.FeatureFunction;
import org.maltparserx.core.feature.value.FeatureValue;
import org.maltparserx.core.feature.value.MultipleFeatureValue;
import org.maltparserx.core.feature.value.SingleFeatureValue;
import org.maltparserx.core.syntaxgraph.DependencyStructure;
import org.maltparserx.ml.LearningMethod;
import org.maltparserx.parser.DependencyParserConfig;
import org.maltparserx.parser.guide.instance.InstanceModel;
import org.maltparserx.parser.history.action.SingleDecision;

public abstract class Lib implements LearningMethod {
  protected Verbostity verbosity;
  public enum Verbostity {
    SILENT, ERROR, ALL
  }
  protected InstanceModel owner;
  protected int learnerMode;
  protected String name;
  protected int numberOfInstances;
  protected boolean saveInstanceFiles;
  protected boolean excludeNullValues;
  protected BufferedWriter instanceOutput = null;
  protected FeatureMap featureMap;
  protected String paramString;
  protected String pathExternalTrain;
  protected LinkedHashMap<String, String> libOptions;
  protected String allowedLibOptionFlags;
  protected Logger configLogger;
  protected final Pattern tabPattern = Pattern.compile("\t");
  protected final Pattern pipePattern = Pattern.compile("\\|")
  private final StringBuilder sb = new StringBuilder();
  protected MaltLibModel model = null;
  /**
   * Constructs a Lib learner.
   *
   * @param owner the guide model owner
   * @param learnerMode the mode of the learner BATCH or CLASSIFY
   */
  public Lib(InstanceModel owner, Integer learnerMode, String learningMethodName) throws MaltChainedException {
    setOwner(owner);
    setLearnerMode(learnerMode.intValue());
    setNumberOfInstances(0);
    setLearningMethodName(learningMethodName);
    verbosity = Verbostity.SILENT;
    configLogger = owner.getGuide().getConfiguration().getConfigLogger();
    initLibOptions();
    initAllowedLibOptionFlags();
    parseParameters(getConfiguration().getOptionValue("lib", "options").toString());
    initSpecialParameters();
   
    if (learnerMode == BATCH) {
      featureMap = new FeatureMap();
      instanceOutput = new BufferedWriter(getInstanceOutputStreamWriter(".ins"));
    } else if (learnerMode == CLASSIFY) {
      featureMap = loadFeatureMap(getInputStreamFromConfigFileEntry(".map"));
    }
  }
 
 
  public void addInstance(SingleDecision decision, FeatureVector featureVector) throws MaltChainedException {
    if (featureVector == null) {
      throw new LibException("The feature vector cannot be found");
    } else if (decision == null) {
      throw new LibException("The decision cannot be found");
   
   
    try {
      sb.append(decision.getDecisionCode()+"\t");
      final int n = featureVector.size();
      for (int i = 0; i < n; i++) {
        FeatureValue featureValue = featureVector.getFeatureValue(i);
        if (featureValue == null || (excludeNullValues == true && featureValue.isNullValue())) {
          sb.append("-1");
        } else {
          if (!featureValue.isMultiple()) {
            SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
            if (singleFeatureValue.getValue() == 1) {
              sb.append(singleFeatureValue.getIndexCode());
            } else if (singleFeatureValue.getValue() == 0) {
              sb.append("-1");
            } else {
              sb.append(singleFeatureValue.getIndexCode());
              sb.append(":");
              sb.append(singleFeatureValue.getValue());
            }
          } else { //if (featureValue instanceof MultipleFeatureValue) {
            Set<Integer> values = ((MultipleFeatureValue)featureValue).getCodes();
            int j=0;
            for (Integer value : values) {
              sb.append(value.toString());
              if (j != values.size()-1) {
                sb.append("|");
              }
              j++;
            }
          }
//          else {
//            throw new LibException("Don't recognize the type of feature value: "+featureValue.getClass());
//          }
        }
        sb.append('\t');
      }
      sb.append('\n');
      instanceOutput.write(sb.toString());
      instanceOutput.flush();
      increaseNumberOfInstances();
      sb.setLength(0);
    } catch (IOException e) {
      throw new LibException("The learner cannot write to the instance file. ", e);
    }
  }

  public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { }

  public void moveAllInstances(LearningMethod method,
      FeatureFunction divideFeature,
      ArrayList<Integer> divideFeatureIndexVector)
      throws MaltChainedException {
    if (method == null) {
      throw new LibException("The learning method cannot be found. ");
    } else if (divideFeature == null) {
      throw new LibException("The divide feature cannot be found. ");
    }
   
    try {
      final BufferedReader in = new BufferedReader(getInstanceInputStreamReader(".ins"));
      final BufferedWriter out = method.getInstanceWriter();
      final StringBuilder sb = new StringBuilder(6);
      int l = in.read();
      char c;
      int j = 0;
 
      while(true) {
        if (l == -1) {
          sb.setLength(0);
          break;
        }
        c = (char)l;
        l = in.read();
        if (c == '\t') {
          if (divideFeatureIndexVector.contains(j-1)) {
            out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
            out.write('\t');
          }
          out.write(sb.toString());
          j++;
          out.write('\t');
          sb.setLength(0);
        } else if (c == '\n') {
          out.write(sb.toString());
          if (divideFeatureIndexVector.contains(j-1)) {
            out.write('\t');
            out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
          }
          out.write('\n');
          sb.setLength(0);
          method.increaseNumberOfInstances();
          this.decreaseNumberOfInstances();
          j = 0;
        } else {
          sb.append(c);
        }
      } 
      in.close();
      getFile(".ins").delete();
      out.flush();
    } catch (SecurityException e) {
      throw new LibException("The learner cannot remove the instance file. ", e);
    } catch (NullPointerException  e) {
      throw new LibException("The instance file cannot be found. ", e);
    } catch (FileNotFoundException e) {
      throw new LibException("The instance file cannot be found. ", e);
    } catch (IOException e) {
      throw new LibException("The learner read from the instance file. ", e);
    }
  }

  public void noMoreInstances() throws MaltChainedException {
    closeInstanceWriter();
  }

  public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
//    if (featureVector == null) {
//      throw new LibException("The learner cannot predict the next class, because the feature vector cannot be found. ");
//    }
    final FeatureList featureList = new FeatureList();
    final int size = featureVector.size();
    for (int i = 1; i <= size; i++) {
      final FeatureValue featureValue = featureVector.getFeatureValue(i-1)
      if (featureValue != null && !(excludeNullValues == true && featureValue.isNullValue())) {
        if (!featureValue.isMultiple()) {
          SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
          final int index = featureMap.getIndex(i, singleFeatureValue.getIndexCode());
          if (index != -1 && singleFeatureValue.getValue() != 0) {
            featureList.add(index,singleFeatureValue.getValue());
          }
        }
        else { //if (featureValue instanceof MultipleFeatureValue) {
          for (Integer value : ((MultipleFeatureValue)featureValue).getCodes()) {
            final int v = featureMap.getIndex(i, value);
            if (v != -1) {
              featureList.add(v,1);
            }
          }
        }
      }
    }
    try {
      decision.getKBestList().addList(model.predict(featureList.toArray()));
    } catch (OutOfMemoryError e) {
      throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
    }
    return true;
  }
   
//  protected abstract int[] prediction(FeatureList featureList) throws MaltChainedException;
 
  public void train(FeatureVector featureVector) throws MaltChainedException {
    if (featureVector == null) {
      throw new LibException("The feature vector cannot be found. ");
    } else if (owner == null) {
      throw new LibException("The parent guide model cannot be found. ");
    }
    long startTime = System.currentTimeMillis();
   
//    if (configLogger.isInfoEnabled()) {
//      configLogger.info("\nStart training\n");
//    }
    if (pathExternalTrain != null) {
      trainExternal(featureVector);
    } else {
      trainInternal(featureVector);
    }
//    long elapsed = System.currentTimeMillis() - startTime;
//    if (configLogger.isInfoEnabled()) {
//      configLogger.info("Time 1: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n");
//    }
    try {
//      if (configLogger.isInfoEnabled()) {
//        configLogger.info("\nSaving feature map "+getFile(".map").getName()+"\n");
//      }
      saveFeatureMap(new BufferedOutputStream(new FileOutputStream(getFile(".map").getAbsolutePath())), featureMap);
    } catch (FileNotFoundException e) {
      throw new LibException("The learner cannot save the feature map file '"+getFile(".map").getAbsolutePath()+"'. ", e);
    }
//    elapsed = System.currentTimeMillis() - startTime;
//    if (configLogger.isInfoEnabled()) {
//      configLogger.info("Time 2: " +new Formatter().format("%02d:%02d:%02d", elapsed/3600000, elapsed%3600000/60000, elapsed%60000/1000)+" ("+elapsed+" ms)\n");
//    }
  }
  /**
   * Training hook invoked by {@link #train(FeatureVector)} when {@code pathExternalTrain} is set.
   *
   * @param featureVector the feature vector passed to {@code train}
   * @throws MaltChainedException if training fails
   */
  protected abstract void trainExternal(FeatureVector featureVector) throws MaltChainedException;
  /**
   * Training hook invoked by {@link #train(FeatureVector)} when {@code pathExternalTrain} is null.
   *
   * @param featureVector the feature vector passed to {@code train}
   * @throws MaltChainedException if training fails
   */
  protected abstract void trainInternal(FeatureVector featureVector) throws MaltChainedException;
 
  public void terminate() throws MaltChainedException {
    closeInstanceWriter();
    owner = null;
    model = null;
  }

  public BufferedWriter getInstanceWriter() {
    return instanceOutput;
  }
 
  protected void closeInstanceWriter() throws MaltChainedException {
    try {
      if (instanceOutput != null) {
        instanceOutput.flush();
        instanceOutput.close();
        instanceOutput = null;
      }
    } catch (IOException e) {
      throw new LibException("The learner cannot close the instance file. ", e);
    }
  }
 
 
  /**
   * Returns the parameter string used for configure the learner
   *
   * @return the parameter string used for configure the learner
   */
  public String getParamString() {
    return paramString;
  }
 
  public InstanceModel getOwner() {
    return owner;
  }

  protected void setOwner(InstanceModel owner) {
    this.owner = owner;
  }
 
  public int getLearnerMode() {
    return learnerMode;
  }

  public void setLearnerMode(int learnerMode) throws MaltChainedException {
    this.learnerMode = learnerMode;
  }
 
  public String getLearningMethodName() {
    return name;
  }
 
  /**
   * Returns the current configuration
   *
   * @return the current configuration
   * @throws MaltChainedException
   */
  public DependencyParserConfig getConfiguration() throws MaltChainedException {
    return owner.getGuide().getConfiguration();
  }
 
  public int getNumberOfInstances() throws MaltChainedException {
    if(numberOfInstances!=0)
      return numberOfInstances;
    else{
      BufferedReader reader = new BufferedReader( getInstanceInputStreamReader(".ins"));
      try {
        while(reader.readLine()!=null){
          numberOfInstances++;
          owner.increaseFrequency();
        }
        reader.close();
      } catch (IOException e) {
        throw new MaltChainedException("No instances found in file",e);
      }
      return numberOfInstances;
    }
  }

  public void increaseNumberOfInstances() {
    numberOfInstances++;
    owner.increaseFrequency();
  }
 
  public void decreaseNumberOfInstances() {
    numberOfInstances--;
    owner.decreaseFrequency();
  }
 
  protected void setNumberOfInstances(int numberOfInstances) {
    this.numberOfInstances = 0;
  }

  protected void setLearningMethodName(String name) {
    this.name = name;
  }
 
  public String getPathExternalTrain() {
    return pathExternalTrain;
  }


  public void setPathExternalTrain(String pathExternalTrain) {
    this.pathExternalTrain = pathExternalTrain;
  }

  protected OutputStreamWriter getInstanceOutputStreamWriter(String suffix) throws MaltChainedException {
    return getConfiguration().getConfigurationDir().getAppendOutputStreamWriter(owner.getModelName()+getLearningMethodName()+suffix);
  }
 
  protected InputStreamReader getInstanceInputStreamReader(String suffix) throws MaltChainedException {
    return getConfiguration().getConfigurationDir().getInputStreamReader(owner.getModelName()+getLearningMethodName()+suffix);
  }
 
  protected InputStreamReader getInstanceInputStreamReaderFromConfigFile(String suffix) throws MaltChainedException {
    return getConfiguration().getConfigurationDir().getInputStreamReaderFromConfigFile(owner.getModelName()+getLearningMethodName()+suffix);
  }
 
  protected InputStream getInputStreamFromConfigFileEntry(String suffix) throws MaltChainedException {
    return getConfiguration().getConfigurationDir().getInputStreamFromConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
  }
 
 
  protected File getFile(String suffix) throws MaltChainedException {
    return getConfiguration().getConfigurationDir().getFile(owner.getModelName()+getLearningMethodName()+suffix);
  }
 
  protected JarEntry getConfigFileEntry(String suffix) throws MaltChainedException {
    return getConfiguration().getConfigurationDir().getConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
  }
 
  protected void initSpecialParameters() throws MaltChainedException {
    if (getConfiguration().getOptionValue("singlemalt", "null_value") != null && getConfiguration().getOptionValue("singlemalt", "null_value").toString().equalsIgnoreCase("none")) {
      excludeNullValues = true;
    } else {
      excludeNullValues = false;
    }
    saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
    if (!getConfiguration().getOptionValue("lib", "external").toString().equals("")) {
      String path = getConfiguration().getOptionValue("lib", "external").toString();
      try {
        if (!new File(path).exists()) {
          throw new LibException("The path to the external  trainer 'svm-train' is wrong.");
        }
        if (new File(path).isDirectory()) {
          throw new LibException("The option --lib-external points to a directory, the path should point at the 'train' file or the 'train.exe' file in the libsvm or the liblinear package");
        }
        if (!(path.endsWith("train") ||path.endsWith("train.exe"))) {
          throw new LibException("The option --lib-external does not specify the path to 'train' file or the 'train.exe' file in the libsvm or the liblinear package. ");
        }
        setPathExternalTrain(path);
      } catch (SecurityException e) {
        throw new LibException("Access denied to the file specified by the option --lib-external. ", e);
      }
    }
    if (getConfiguration().getOptionValue("lib", "verbosity") != null) {
      verbosity = Verbostity.valueOf(getConfiguration().getOptionValue("lib", "verbosity").toString().toUpperCase());
    }
  }
 
  public String getLibOptions() {
    final StringBuilder sb = new StringBuilder();
    for (String key : libOptions.keySet()) {
      sb.append('-');
      sb.append(key);
      sb.append(' ');
      sb.append(libOptions.get(key));
      sb.append(' ');
    }
    return sb.toString();
  }
 
  public String[] getLibParamStringArray() {
    final ArrayList<String> params = new ArrayList<String>();

    for (String key : libOptions.keySet()) {
      params.add("-"+key); params.add(libOptions.get(key));
    }
    return params.toArray(new String[params.size()]);
  }
 
  /**
   * Hook called by the constructor before the option string is parsed.
   * Presumably populates {@code libOptions} with the learner's defaults; confirm against the subclasses.
   */
  public abstract void initLibOptions();
  /**
   * Hook called by the constructor before the option string is parsed.
   * Presumably assigns {@code allowedLibOptionFlags}, which {@link #parseParameters(String)}
   * consults; confirm against the subclasses.
   */
  public abstract void initAllowedLibOptionFlags();
 
  public void parseParameters(String paramstring) throws MaltChainedException {
    if (paramstring == null) {
      return;
    }
    final String[] argv;
    try {
      argv = paramstring.split("[_\\p{Blank}]");
    } catch (PatternSyntaxException e) {
      throw new LibException("Could not split the parameter string '"+paramstring+"'. ", e);
    }
    for (int i=0; i < argv.length-1; i++) {
      if(argv[i].charAt(0) != '-') {
        throw new LibException("The argument flag should start with the following character '-', not with "+argv[i].charAt(0));
      }
      if(++i>=argv.length) {
        throw new LibException("The last argument does not have any value. ");
      }
      try {
        int index = allowedLibOptionFlags.indexOf(argv[i-1].charAt(1));
        if (index != -1) {
          libOptions.put(Character.toString(argv[i-1].charAt(1)), argv[i]);
        } else {
          throw new LibException("Unknown learner parameter: '"+argv[i-1]+"' with value '"+argv[i]+"'. ");   
        }
      } catch (ArrayIndexOutOfBoundsException e) {
        throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e);
      } catch (NumberFormatException e) {
        throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e)
      } catch (NullPointerException e) {
        throw new LibException("The learner parameter '"+argv[i-1]+"' could not convert the string value '"+argv[i]+"' into a correct numeric value. ", e)
      }
    }
  }
 
  protected void finalize() throws Throwable {
    try {
      closeInstanceWriter();
    } finally {
      super.finalize();
    }
  }
 
  public String toString() {
    final StringBuffer sb = new StringBuffer();
    sb.append("\n"+getLearningMethodName()+" INTERFACE\n");
    sb.append(getLibOptions());
    return sb.toString();
  }

  protected int binariesInstance(String line, FeatureList featureList) throws MaltChainedException {
    int y = -1;
    featureList.clear();
    try
      String[] columns = tabPattern.split(line);

      if (columns.length == 0) {
        return -1;
      }
      try {
        y = Integer.parseInt(columns[0]);
      } catch (NumberFormatException e) {
        throw new LibException("The instance file contain a non-integer value '"+columns[0]+"'", e);
      }
      for(int j = 1; j < columns.length; j++) {
        final String[] items = pipePattern.split(columns[j]);
        for (int k = 0; k < items.length; k++) {
          try {
            int colon = items[k].indexOf(':');
            if (colon == -1) {
              if (Integer.parseInt(items[k]) != -1) {
                int v = featureMap.addIndex(j, Integer.parseInt(items[k]));
                if (v != -1) {
                  featureList.add(v,1);
                }
              }
            } else {
              int index = featureMap.addIndex(j, Integer.parseInt(items[k].substring(0,colon)));
              double value;
              if (items[k].substring(colon+1).indexOf('.') != -1) {
                value = Double.parseDouble(items[k].substring(colon+1));
              } else {
                value = Integer.parseInt(items[k].substring(colon+1));
              }
              featureList.add(index,value);
            }
          } catch (NumberFormatException e) {
            throw new LibException("The instance file contain a non-numeric value '"+items[k]+"'", e);
          }
        }
      }
    } catch (ArrayIndexOutOfBoundsException e) {
      throw new LibException("Couln't read from the instance file. ", e);
    }
    return y;
  }

  protected void binariesInstances2SVMFileFormat(InputStreamReader isr, OutputStreamWriter osw) throws MaltChainedException {
    try {
      final BufferedReader in = new BufferedReader(isr);
      final BufferedWriter out = new BufferedWriter(osw);
      final FeatureList featureSet = new FeatureList();
      while(true) {
        String line = in.readLine();
        if(line == null) break;
        int y = binariesInstance(line, featureSet);
        if (y == -1) {
          continue;
        }
        out.write(Integer.toString(y));
       
            for (int k=0; k < featureSet.size(); k++) {
              MaltFeatureNode x = featureSet.get(k);
          out.write(' ');
          out.write(Integer.toString(x.getIndex()));
          out.write(':');
          out.write(Double.toString(x.getValue()));        
        }
        out.write('\n');
      }     
      in.close()
      out.close();
    } catch (NumberFormatException e) {
      throw new LibException("The instance file contain a non-numeric value", e);
    } catch (IOException e) {
      throw new LibException("Couln't read from the instance file, when converting the Malt instances into LIBSV/LIBLINEAR format. ", e);
    }
  }
 
  protected void saveFeatureMap(OutputStream os, FeatureMap map) throws MaltChainedException {
    try {
        ObjectOutputStream output = new ObjectOutputStream(os);
          try{
            output.writeObject(map);
          }
          finally{
            output.close();
          }
    } catch (IOException e) {
      throw new LibException("Save feature map error", e);
    }
  }

  protected FeatureMap loadFeatureMap(InputStream is) throws MaltChainedException {
    FeatureMap map = new FeatureMap();
    try {
        ObjectInputStream input = new ObjectInputStream(is);
        try {
          map = (FeatureMap)input.readObject();
        } finally {
          input.close();
        }
    } catch (ClassNotFoundException e) {
      throw new LibException("Load feature map error", e);
    } catch (IOException e) {
      throw new LibException("Load feature map error", e);
    }
    return map;
  }
}
TOP

Related Classes of org.maltparserx.ml.lib.Lib

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.