Package com.nr.test.test_chapter16

Source Code of com.nr.test.test_chapter16.Test_Kmeans

package com.nr.test.test_chapter16;

import static com.nr.NRUtil.SQR;
import static com.nr.NRUtil.buildMatrix;
import static com.nr.NRUtil.buildVector;
import static com.nr.test.NRTestUtil.maxel;
import static java.lang.Math.sqrt;
import static org.junit.Assert.fail;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import com.nr.ci.Kmeans;
import com.nr.ran.Normaldev;

public class Test_Kmeans {

  @Before
  public void setUp() throws Exception {
  }

  @After
  public void tearDown() throws Exception {
  }

  @Test
  public void test() {
    int i,j,k,flag=0,NDIM=2,NMEANS=4,NPT=1000;
    double sqrt2=sqrt(2.0),d0,d1,sbeps;
    double ffrac[]={0.25,0.25,0.25,0.25};
    double mmeans[]={0.0,0.0,0.75,0.0,-0.25,-0.25,0.33,0.66};
    double gguess[]={0.1,0.1,0.7,0.1,-0.2,-0.3,0.3,0.5};
    double ssigma[]={0.1,0.1,0.02,0.2,0.01,0.1,0.1,0.05};
    double vvec1[]={1.0,0.0,1.0,0.0,sqrt2,sqrt2,sqrt2,sqrt2};
    double vvec2[]={0.0,1.0,0.0,1.0,-sqrt2,sqrt2,-sqrt2,sqrt2};
    int[] count=new int[NMEANS],error=new int[NMEANS];
    double[] frac=buildVector(ffrac),offset=new double[NMEANS];
    double[][] guess=buildMatrix(NMEANS,NDIM,gguess);
    double[][] means=buildMatrix(NMEANS,NDIM,mmeans),sigma=buildMatrix(NMEANS,NDIM,ssigma);
    double[][] vec1=buildMatrix(NMEANS,NDIM,vvec1),vec2=buildMatrix(NMEANS,NDIM,vvec2),x=new double[NPT][NDIM];
    boolean localflag, globalflag=false;

   

    // Test Kmeans
    System.out.println("Testing Kmeans");

    Normaldev ndev = new Normaldev(0.0,1.0,17);

    // Generate four groups of data
    k=0;
    for (i=0;i<4;i++) {
      count[i]=(int)(NPT*frac[i]);
      for (j=0;j<count[i];j++) {
        d0=sigma[i][0]*ndev.dev();
        d1=sigma[i][1]*ndev.dev();
        x[k][0]=means[i][0]+d0*vec1[i][0]+d1*vec2[i][0];
        x[k][1]=means[i][1]+d0*vec1[i][1]+d1*vec2[i][1];
        k++;
      }
    }

    Kmeans kmean=new Kmeans(x,guess)// *** put in some weird guesses

    for (i=0;i<100;i++) {
      flag=kmean.estep();
      if (flag == 0) break;
      kmean.mstep();
    }

    // check for convergence
//    System.out.println("  flag: %f\n", flag);
    localflag = flag > 0;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Kmeans: No solution with 100 iterations");
     
    }

    // Check for correct populations
//    for (i=0;i<NMEANS;i++) System.out.printf(kmean.count[i] << " ";
//    System.out.printf(endl;
    for (i=0;i<NMEANS;i++) {
      error[i]=kmean.count[i]-count[i];
      localflag = (error[i] > 0 ? error[i] : -error[i]) > 15;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Kmeans: Populations are not the correct approximate size");
       
      }
    }

    // Check group assignments
    k=0;
    for (i=0;i<NMEANS;i++) {
      error[i]=0;
      for (j=0;j<count[i];j++) {
        if (kmean.assign[k] != i) error[i]++;
        k++;
      }
//      System.out.printf(error[i] << " ";
    }
//    System.out.printf(endl;

    // Check for correct determination of means
    for (i=0;i<NMEANS;i++) offset[i]=sqrt(SQR(kmean.means[i][0]-means[i][0])
      +SQR(kmean.means[i][1]-means[i][1]));
//    System.out.printf(maxel(offset));
    sbeps=0.02;
    localflag = maxel(offset) > sbeps;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Kmeans: Means are incorrectly identified");
     
    }

    if (globalflag) System.out.println("Failed\n");
    else System.out.println("Passed\n");
  }

}
TOP

Related Classes of com.nr.test.test_chapter16.Test_Kmeans

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.