Package org.apache.commons.math.optimization

Source Code of org.apache.commons.math.optimization.NelderMeadTest$ValueChecker

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.commons.math.optimization;

import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.CostException;
import org.apache.commons.math.optimization.CostFunction;
import org.apache.commons.math.optimization.NelderMead;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.PointCostPair;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.commons.math.random.RandomVectorGenerator;
import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
import org.apache.commons.math.random.UniformRandomGenerator;

import junit.framework.*;

public class NelderMeadTest
  extends TestCase {

  public NelderMeadTest(String name) {
    super(name);
  }

  public void testCostExceptions() throws ConvergenceException {
      CostFunction wrong =
          new CostFunction() {
            public double cost(double[] x) throws CostException {
                if (x[0] < 0) {
                    throw new CostException("{0}", new Object[] { "oops"});
                } else if (x[0] > 1) {
                    throw new CostException(new RuntimeException("oops"));
                } else {
                    return x[0] * (1 - x[0]);
                }
            }
      };
      try {
          new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
                                                      new double[] { -0.5 }, new double[] { 0.5 });
          fail("an exception should have been thrown");
      } catch (CostException ce) {
          // expected behavior
          assertNull(ce.getCause());
      } catch (Exception e) {
          fail("wrong exception caught: " + e.getMessage());
      }
      try {
          new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
                                                      new double[] { 0.5 }, new double[] { 1.5 });
          fail("an exception should have been thrown");
      } catch (CostException ce) {
          // expected behavior
          assertNotNull(ce.getCause());
      } catch (Exception e) {
          fail("wrong exception caught: " + e.getMessage());
      }
  }

  public void testRosenbrock()
    throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {

    CostFunction rosenbrock =
      new CostFunction() {
        public double cost(double[] x) {
          ++count;
          double a = x[1] - x[0] * x[0];
          double b = 1.0 - x[0];
          return 100 * a * a + b * b;
        }
      };

    count = 0;
    NelderMead nm = new NelderMead();
    try {
      nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
                  new double[][] {
                    { -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
                  }, 1, 5384353l);
      fail("an exception should have been thrown");
    } catch (ConvergenceException ce) {
        // expected behavior
    } catch (Exception e) {
        fail("wrong exception caught: " + e.getMessage());
    }

    count = 0;
    PointCostPair optimum =
        nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
                    new double[][] {
                      { -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
                    }, 10, 1642738l);

    assertTrue(count > 700);
    assertTrue(count < 800);
    assertEquals(0.0, optimum.getCost(), 5.0e-5);
    assertEquals(1.0, optimum.getPoint()[0], 0.01);
    assertEquals(1.0, optimum.getPoint()[1], 0.01);

    PointCostPair[] minima = nm.getMinima();
    assertEquals(10, minima.length);
    assertNotNull(minima[0]);
    assertNull(minima[minima.length - 1]);
    for (int i = 0; i < minima.length; ++i) {
        if (minima[i] == null) {
            if ((i + 1) < minima.length) {
                assertTrue(minima[i+1] == null);
            }
        } else {
            if (i > 0) {
                assertTrue(minima[i-1].getCost() <= minima[i].getCost());
            }
        }
    }

    RandomGenerator rg = new JDKRandomGenerator();
    rg.setSeed(64453353l);
    RandomVectorGenerator rvg =
        new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
                                              new double[] { 0.2, 0.2 },
                                              new UniformRandomGenerator(rg));
    optimum =
        nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
    assertEquals(0.0, optimum.getCost(), 2.0e-4);
    optimum =
        nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
    assertEquals(0.0, optimum.getCost(), 3.0e-5);

  }

  public void testPowell()
    throws CostException, ConvergenceException {

    CostFunction powell =
      new CostFunction() {
        public double cost(double[] x) {
          ++count;
          double a = x[0] + 10 * x[1];
          double b = x[2] - x[3];
          double c = x[1] - 2 * x[2];
          double d = x[0] - x[3];
          return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
        }
      };

    count = 0;
    NelderMead nm = new NelderMead();
    PointCostPair optimum =
      nm.minimize(powell, 200, new ValueChecker(1.0e-3),
                  new double[] {  3.0, -1.0, 0.0, 1.0 },
                  new double[] {  4.00.0, 1.0, 2.0 },
                  1, 1642738l);
    assertTrue(count < 150);
    assertEquals(0.0, optimum.getCost(), 6.0e-4);
    assertEquals(0.0, optimum.getPoint()[0], 0.07);
    assertEquals(0.0, optimum.getPoint()[1], 0.07);
    assertEquals(0.0, optimum.getPoint()[2], 0.07);
    assertEquals(0.0, optimum.getPoint()[3], 0.07);

  }

  private static class ValueChecker implements ConvergenceChecker {

    public ValueChecker(double threshold) {
      this.threshold = threshold;
    }

    public boolean converged(PointCostPair[] simplex) {
      PointCostPair smallest = simplex[0];
      PointCostPair largest  = simplex[simplex.length - 1];
      return (largest.getCost() - smallest.getCost()) < threshold;
    }

    private double threshold;

  };

  public static Test suite() {
    return new TestSuite(NelderMeadTest.class);
  }

  private int count;

}
TOP

Related Classes of org.apache.commons.math.optimization.NelderMeadTest$ValueChecker

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.