Package de.jungblut.math.minimize

Source Code of de.jungblut.math.minimize.OWLQN

package de.jungblut.math.minimize;

import gnu.trove.list.array.TDoubleArrayList;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.util.FastMath;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;

/**
* Java translation of C++ code of
* "Orthant-Wise Limited-memory Quasi-Newton Optimizer for L1-regularized Objectives"
* (@see <a href=
* "http://research.microsoft.com/en-us/downloads/b1eb1016-1738-4bd5-83a9-370
* c9d498a03/">http://research.microsoft.com/</a>). <br/>
* <br/>
*
* The Orthant-Wise Limited-memory Quasi-Newton algorithm (OWL-QN) is a
* numerical optimization procedure for finding the optimum of an objective of
* the form {smooth function} plus {L1-norm of the parameters}. It has been used
* for training log-linear models (such as logistic regression) with
* L1-regularization. The algorithm is described in
* "Scalable training of L1-regularized log-linear models" by Galen Andrew and
* Jianfeng Gao. <br/>
* <br/>
*
* The Orthant-Wise Limited-memory Quasi-Newton algorithm minimizes functions of the
* form<br/>
* <br/>
*
* f(w) = loss(w) + C |w|_1<br/>
* <br/>
*
* where loss is an arbitrary differentiable convex loss function, and |w|_1 is
* the L1 norm of the weight (parameter) vector. It is based on the L-BFGS
* Quasi-Newton algorithm, with modifications to deal with the fact that the L1
* norm is not differentiable. The algorithm is very fast, and capable of
* scaling efficiently to problems with millions of parameters.<br/>
* <br/>
*
*
* This is a straightforward translation that uses my math library.
*
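* A minimal usage sketch (the cost function and the starting point are assumed
* to come from the caller; the L1 weight of 1.0 is only an example value):<br/>
* <br/>
*
* <pre>
* CostFunction f = ...; // smooth, differentiable loss with gradient
* DoubleVector start = new DenseDoubleVector(3); // zero start vector
* DoubleVector minimum = new OWLQN().setL1Weight(1.0).minimize(f, start, 100,
*     true);
* </pre>
*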
* @author thomas.jungblut
*
*/
public class OWLQN extends AbstractMinimizer {

  private static final Log LOG = LogFactory.getLog(OWLQN.class);

  private DoubleVector x, grad, newX, newGrad, dir;
  private DoubleVector steepestDescDir;
  private double[] alphas;
  private TDoubleArrayList roList;
  private TDoubleArrayList costs;
  private ArrayList<DoubleVector> sList, yList;
  private double value;
  private int m = 10;
  private double l1weight = 0;

  private double tol = 1e-4;
  private boolean gradCheck = false;

  @Override
  public DoubleVector minimize(CostFunction f, DoubleVector theta,
      int maxIterations, boolean verbose) {

    int dim = theta.getDimension();
    this.x = theta;
    this.grad = new DenseDoubleVector(dim);
    this.newX = theta.deepCopy();
    this.newGrad = new DenseDoubleVector(dim);
    this.dir = new DenseDoubleVector(dim);
    this.steepestDescDir = newGrad;

    this.alphas = new double[m];
    this.roList = new TDoubleArrayList(m);
    this.costs = new TDoubleArrayList(m);
    this.sList = new ArrayList<>();
    this.yList = new ArrayList<>();

    this.value = evaluateL1(f);
    this.grad = newGrad;

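    // main loop: build the search direction from the pseudo-gradient and the
    // L-BFGS memory, line-search along it, then update the memory and check
    // the convergence criteria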
    for (int i = 0; i < maxIterations; i++) {
      updateDir(f, verbose);
      boolean continueIterations = backTrackingLineSearch(i, f);
      shift();
      costs.add(value);

      // also break on too small average improvement over 5 iterations
      if (costs.size() > 5) {
        double first = costs.get(0);
        while (costs.size() > 5) {
          costs.removeAt(0);
        }
        double avgImprovement = (first - value) / costs.size();
        double perc = avgImprovement / Math.abs(value);
        if (perc < tol) {
          break;
        }
      }

      // break if we can't get any improvement
      if (!continueIterations) {
        break;
      }

      if (verbose) {
        LOG.info("Iteration " + i + " | Cost: " + value);
      }

    }

    // clean up the internal state of this minimizer instance
    x = null;
    grad = null;
    newGrad = null;
    dir = null;
    steepestDescDir = null;
    alphas = null;
    roList = null;
    costs = null;
    sList = null;
    yList = null;

    return newX;
  }

  private void updateDir(CostFunction f, boolean verbose) {
    makeSteepestDescDir();
    mapDirectionByInverseHessian();
    fixDirectionSigns();

    if (gradCheck) {
      testDirectionDerivation(f);
    }
  }

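  /**
   * Gradient check: compares the analytical directional derivative along dir
   * against a finite-difference estimate and logs both values together with
   * their absolute difference.
   */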
  private void testDirectionDerivation(CostFunction f) {
    double dirNorm = FastMath.sqrt(dir.dot(dir));
    // if dirNorm is 0, we probably hit the minimum, so there is no direction
    // to descend along anymore.
    if (dirNorm != 0d) {
      double eps = 1.05e-8 / dirNorm;
      getNextPoint(eps);
      double val2 = evaluateL1(f);
      double numDeriv = (val2 - value) / eps;
      double deriv = directionDerivation();
      LOG.info("GradCheck: expected= " + numDeriv + " vs. " + deriv
          + "! AbsDiff= " + Math.abs(numDeriv - deriv));
    }
  }

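  /**
   * Constrains the quasi-Newton direction to the orthant given by the steepest
   * descent (pseudo-gradient) direction: components whose sign disagrees with
   * it are zeroed.
   */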
  private void fixDirectionSigns() {
    if (l1weight > 0) {
      for (int i = 0; i < dir.getDimension(); i++) {
        if (dir.get(i) * steepestDescDir.get(i) <= 0) {
          dir.set(i, 0);
        }
      }
    }
  }

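  /**
   * Standard L-BFGS two-loop recursion: scales the current direction by an
   * approximation of the inverse Hessian built from the last m (s, y) pairs.
   * Note that roList stores s.dot(y), so alpha and beta divide by it directly.
   */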
  private void mapDirectionByInverseHessian() {
    int count = sList.size();

    if (count != 0) {
      for (int i = count - 1; i >= 0; i--) {
        alphas[i] = -sList.get(i).dot(dir) / roList.get(i);
        addMult(dir, yList.get(i), alphas[i]);
      }

      DoubleVector lastY = yList.get(count - 1);
      double yDotY = lastY.dot(lastY);
      double scalar = roList.get(count - 1) / yDotY;
      scale(dir, scalar);

      for (int i = 0; i < count; i++) {
        double beta = yList.get(i).dot(dir) / roList.get(i);
        addMult(dir, sList.get(i), -alphas[i] - beta);
      }
    }
  }

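  /**
   * Computes the negative pseudo-gradient of the L1-regularized objective: the
   * smooth part contributes -grad, and at zero coordinates the direction is
   * non-zero only where the gradient magnitude exceeds l1weight.
   */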
  private void makeSteepestDescDir() {
    if (l1weight == 0) {
      scaleInto(dir, grad, -1);
    } else {

      for (int i = 0; i < dir.getDimension(); i++) {
        if (x.get(i) < 0) {
          dir.set(i, -grad.get(i) + l1weight);
        } else if (x.get(i) > 0) {
          dir.set(i, -grad.get(i) - l1weight);
        } else {
          if (grad.get(i) < -l1weight) {
            dir.set(i, -grad.get(i) - l1weight);
          } else if (grad.get(i) > l1weight) {
            dir.set(i, -grad.get(i) + l1weight);
          } else {
            dir.set(i, 0);
          }
        }
      }
    }

    // keep a copy: fixDirectionSigns() needs the unmodified steepest descent
    // direction after mapDirectionByInverseHessian() has changed dir in place
    steepestDescDir = dir.deepCopy();
  }

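  /**
   * Backtracking line search along dir: shrinks the step size alpha by the
   * backoff factor until the sufficient decrease (Armijo) condition
   * value <= oldValue + c1 * origDirDeriv * alpha holds (or the value turns
   * NaN, which also terminates the search).
   */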
  private boolean backTrackingLineSearch(int iter, CostFunction f) {
    double origDirDeriv = directionDerivation();
    // if a non-descent direction is chosen, the line search will break anyway,
    // so throw here. The most likely reason for this is a bug in the cost
    // function's gradient computation.
    if (origDirDeriv > 0d) {
      throw new RuntimeException(
          "L-BFGS chose a non-descent direction: check your gradient!");
    } else if (origDirDeriv == 0d || Double.isNaN(origDirDeriv)) {
      LOG.info("L-BFGS apparently found the minimum. No direction to descent anymore.");
      return false;
    }

    double alpha = 1.0;
    double backoff = 0.5;
    if (iter == 0) {
      double normDir = FastMath.sqrt(dir.dot(dir));
      alpha = (1 / normDir);
      backoff = 0.1;
    }

    double c1 = 1e-4;
    double oldValue = value;

    while (true) {
      getNextPoint(alpha);
      value = evaluateL1(f);

      if (Double.isNaN(value) || value <= oldValue + c1 * origDirDeriv * alpha) {
        break;
      }

      alpha *= backoff;
    }

    return true;
  }

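  /**
   * Takes a step of size alpha along dir and projects the result back onto the
   * orthant of the current point: coordinates that would cross zero are set to
   * zero (the orthant-wise projection step of OWL-QN).
   */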
  private void getNextPoint(double alpha) {
    addMultInto(newX, x, dir, alpha);

    if (l1weight > 0) {
      for (int i = 0; i < x.getDimension(); i++) {
        if (x.get(i) * newX.get(i) < 0.0) {
          newX.set(i, 0d);
        }
      }
    }
  }

  private void addMultInto(DoubleVector a, DoubleVector b, DoubleVector c,
      double d) {
    for (int i = 0; i < a.getDimension(); i++) {
      a.set(i, b.get(i) + c.get(i) * d);
    }
  }

  private void addMult(DoubleVector a, DoubleVector b, double c) {
    for (int i = 0; i < a.getDimension(); i++) {
      a.set(i, a.get(i) + b.get(i) * c);
    }
  }

  private void scale(DoubleVector a, double b) {
    for (int i = 0; i < a.getDimension(); i++) {
      a.set(i, a.get(i) * b);
    }
  }

  void scaleInto(DoubleVector a, DoubleVector b, double c) {
    for (int i = 0; i < a.getDimension(); i++) {
      a.set(i, b.get(i) * c);
    }
  }

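  /**
   * Directional derivative of the L1-regularized objective along dir. At zero
   * coordinates the sign of the direction component selects the appropriate
   * one-sided derivative of the L1 term.
   */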
  private double directionDerivation() {
    if (l1weight == 0.0) {
      return dir.dot(grad);
    } else {
      double val = 0.0;
      for (int i = 0; i < dir.getDimension(); i++) {
        if (dir.get(i) != 0) {
          if (x.get(i) < 0) {
            val += dir.get(i) * (grad.get(i) - l1weight);
          } else if (x.get(i) > 0) {
            val += dir.get(i) * (grad.get(i) + l1weight);
          } else if (dir.get(i) < 0) {
            val += dir.get(i) * (grad.get(i) - l1weight);
          } else if (dir.get(i) > 0) {
            val += dir.get(i) * (grad.get(i) + l1weight);
          }
        }
      }

      return val;
    }
  }

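  /**
   * Evaluates the cost function at newX, stores the gradient of the smooth
   * part in newGrad and adds the L1 penalty l1weight * |newX|_1 to the
   * returned cost.
   */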
  private double evaluateL1(CostFunction f) {
    CostGradientTuple evaluateCost = f.evaluateCost(newX);
    newGrad = evaluateCost.getGradient();
    double val = evaluateCost.getCost();
    if (l1weight > 0) {
      for (int i = 0; i < newGrad.getDimension(); i++) {
        val += Math.abs(newX.get(i)) * l1weight;
      }
    }

    return val;
  }

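  /**
   * Updates the L-BFGS memory with s = newX - x and y = newGrad - grad
   * (recycling the oldest pair once m pairs are stored), records ro = s.dot(y)
   * and swaps the current and new point/gradient for the next iteration.
   */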
  private void shift() {
    DoubleVector nextS = null;
    DoubleVector nextY = null;

    int listSize = sList.size();

    if (listSize < m) {
      nextS = new DenseDoubleVector(x.getDimension());
      nextY = new DenseDoubleVector(x.getDimension());
    }

    if (nextS == null) {
      nextS = sList.get(0);
      sList.remove(0);
      nextY = yList.get(0);
      yList.remove(0);
      roList.removeAt(0);
    }

    addMultInto(nextS, newX, x, -1);
    addMultInto(nextY, newGrad, grad, -1);
    double ro = nextS.dot(nextY);

    sList.add(nextS);
    yList.add(nextY);
    roList.add(ro);

    DoubleVector tmpNewX = newX.deepCopy();
    newX = x.deepCopy();
    x = tmpNewX;

    DoubleVector tmpNewGrad = newGrad.deepCopy();
    newGrad = grad.deepCopy();
    grad = tmpNewGrad;
  }

  /**
   * Enables gradient checking: every iteration the analytical directional
   * derivative is compared against a numerical estimate and both values are
   * logged.
   */
  public OWLQN doGradChecks() {
    this.gradCheck = true;
    return this;
  }

  /**
   * The number of direction/gradient pairs to keep; this is the "limited" part
   * of L-BFGS. Defaults to 10.
   */
  public OWLQN setM(int m) {
    this.m = m;
    return this;
  }

  /**
   * Sets the L1 regularization weight, which is applied on top of the cost
   * function (the cost function does not need to know about it). Turned off
   * (zero) by default.
   */
  public OWLQN setL1Weight(double l1weight) {
    this.l1weight = l1weight;
    return this;
  }

  /**
   * The convergence tolerance on the average relative improvement over a
   * window of five iterations. Defaults to 1e-4.
   */
  public OWLQN setTolerance(double tol) {
    this.tol = tol;
    return this;
  }

  /**
   * Minimizes the given cost function with OWL-QN (effectively plain L-BFGS,
   * since no L1 weight is set here).
   *
   * @param f the cost function to minimize.
   * @param theta the initial weights.
   * @param maxIterations the maximum number of iterations.
   * @param verbose true if progress output should be logged.
   * @return the optimized set of parameters for the cost function.
   */
  public static DoubleVector minimizeFunction(CostFunction f,
      DoubleVector theta, int maxIterations, boolean verbose) {
    return new OWLQN().minimize(f, theta, maxIterations, verbose);
  }
}