Examples of DoubleAD


Examples of edu.stanford.nlp.math.DoubleAD

    // Fragment of a larger method; its header and the code that follows are not
    // shown here. Over one mini-batch it accumulates the negative conditional
    // log-likelihood of a log-linear (softmax) classifier into `value`, and the
    // matching gradient into derivativeAD. It uses DoubleAD values so that each
    // quantity also carries a derivative along the direction v. Per the
    // "Hess*V" comment below, this is presumably used to form a Hessian-vector
    // product.
    // NOTE(review): x, v, data, labels, batch, curElement, dataWeights,
    // numClasses, indexOf and value are all declared outside this view; their
    // exact contracts are unconfirmed.

    // Reset the objective accumulator (`value` is not declared here; presumably a field).
    value = 0.0;

    //initialize any variables
    // derivativeAD[i] accumulates the gradient entry for parameter i.
    DoubleAD[] derivativeAD = new DoubleAD[x.length];
    for (int i = 0; i < x.length;i++) {
      derivativeAD[i] = new DoubleAD(0.0,0.0);
    }

    // Lift each parameter x[i] into a DoubleAD paired with direction v[i].
    // NOTE(review): assumes the two-argument DoubleAD constructor takes
    // (value, derivative) -- confirm against edu.stanford.nlp.math.DoubleAD.
    DoubleAD[] xAD = new DoubleAD[x.length];
    for (int i = 0; i < x.length;i++){
      xAD[i] = new DoubleAD(x[i],v[i]);
    }

    // Initialize the sums
    // sums[c] will hold the (AD) unnormalised log-score of class c for the current datum.
    DoubleAD[] sums = new DoubleAD[numClasses];
    for (int c = 0; c<numClasses;c++){
      sums[c] = new DoubleAD(0,0);
    }

    // probs[c] will hold the (AD) model probability of class c for the current datum.
    DoubleAD[] probs = new DoubleAD[numClasses];
    for (int c = 0; c<numClasses;c++) {
      probs[c] = new DoubleAD(0,0);
    }

    //long curTime = System.currentTimeMillis();
    // Set up the direction vector V (for Hess*V) in xAD and clear the gradient
    // accumulator derivativeAD (it is zeroed, not seeded with any numerator).
    // NOTE(review): these assignments repeat the values already given when the
    // arrays were allocated above, so they are redundant in this fragment.
    for (int i = 0; i < x.length;i++){
      xAD[i].set(x[i],v[i]);
      derivativeAD[i].set(0.0,0.0);
    }

    //System.err.print(System.currentTimeMillis() - curTime + " - ");
    //curTime = System.currentTimeMillis();

    // Visit batch.length consecutive data points starting at curElement,
    // wrapping around the end of `data`.
    // NOTE(review): only batch.length is read here; the contents of `batch`
    // are not used in this fragment.
    for (int d = 0; d <batch.length ; d++) {

      //Sets the index based on the current batch
      int m = (curElement + d) % data.length;

      int[] features = data[m];

      // Clear the per-class scores before accumulating this datum.
      for (int c = 0; c<numClasses;c++){
        sums[c].set(0.0,0.0);
      }


      // sums[c] = sum of xAD[indexOf(feature, c)] over this datum's features,
      // i.e. the linear score of class c.
      for (int c = 0; c < numClasses; c++) {
        for (int feature : features) {
          int i = indexOf(feature, c);
          sums[c] = ADMath.plus(sums[c], xAD[i]);
        }
      }

      // Log-normaliser over all classes (ADMath.logSum; presumably log-sum-exp).
      DoubleAD total = ADMath.logSum(sums);

      for (int c = 0; c < numClasses; c++) {
        // Softmax probability of class c: exp(score_c - logNormaliser).
        probs[c] = ADMath.exp( ADMath.minus(sums[c], total) );
        if (dataWeights != null) {
          // NOTE(review): the weight is indexed by the batch position d, not by
          // the data index m used for data[m] and labels[m]; confirm which index
          // dataWeights is keyed by.
          probs[c] = ADMath.multConst(probs[c], dataWeights[d]);
        }
        // Gradient of the negative log-likelihood: for each parameter touched by
        // (feature, c), add p(c), and additionally subtract 1 when c is the gold label.
        for (int feature : features) {
          int i = indexOf(feature, c);
          if (c == labels[m]) {
            // NOTE(review): this empirical term is not scaled by dataWeights,
            // whereas probs[c] above and dV below are; confirm this is intended.
            derivativeAD[i].plusEqualsConst(-1.0);
          }
          derivativeAD[i].plusEquals(probs[c]);
        }
      }

      // Log-probability of the gold label (plain value part only), optionally weighted.
      double dV = sums[labels[m]].getval() - total.getval();
      if (dataWeights != null) {
        // NOTE(review): same d-vs-m indexing question as for probs above.
        dV *= dataWeights[d];
      }
      // Accumulate the negative log-likelihood.
      value -= dV;
    }
View Full Code Here
TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.