package aima.test.core.unit.probability.bayes.approx;
import org.junit.Assert;
import org.junit.Test;
import aima.core.probability.ProbabilityModel;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.BayesianNetwork;
import aima.core.probability.bayes.approx.PriorSample;
import aima.core.probability.bayes.approx.RejectionSampling;
import aima.core.probability.example.BayesNetExampleFactory;
import aima.core.probability.example.ExampleRV;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.util.MockRandomizer;
/**
 * Deterministic unit tests for {@link RejectionSampling}.
 * <p>
 * Each test hands {@link PriorSample} a {@link MockRandomizer} built from a
 * fixed array of values instead of a real random source. The samples that get
 * generated, and therefore the estimated distribution, are thus predictable.
 *
 * @author Ciaran O'Reilly
 * @author Ravi Mohan
 */
public class RejectionSamplingTest {
    public static final double DELTA_THRESHOLD = ProbabilityModel.DEFAULT_ROUNDING_THRESHOLD;

    // NOTE(review): despite its name, this test exercises RejectionSampling
    // (PriorSample is only the sample generator underneath it).
    @Test
    public void testPriorSample_basic() {
        // The mock randomizer is given the single value 0.1, which is
        // presumably reused for every draw. The expected estimate <1.0, 0.0>
        // means every accepted sample has Rain = true.
        double[] estimate = estimateRainGivenSprinkler(
                new MockRandomizer(new double[] { 0.1 }));
        Assert.assertArrayEquals(new double[] { 1.0, 0.0 }, estimate,
                DELTA_THRESHOLD);
    }

    @Test
    public void testRejectionSampling_AIMA3e_pg532() {
        // AIMA3e pg. 532: of the 100 samples generated, 73 have
        // Sprinkler = false and are rejected. Of the remaining 27 with
        // Sprinkler = true, 8 have Rain = true and 19 have Rain = false.
        final int variablesPerSample = 4; // Cloudy, Sprinkler, Rain, WetGrass
        final int totalSamples = 100;
        final int rejected = 73;
        final int acceptedWithRain = 8;

        // One mocked draw per variable per sample.
        double[] draws = new double[variablesPerSample * totalSamples];
        for (int s = 0; s < totalSamples; s++) {
            int offset = s * variablesPerSample;
            boolean sprinkler = s >= rejected;
            boolean rain = s < rejected + acceptedWithRain;
            draws[offset] = 0.5; // i.e. Cloudy=true
            // 0.09 => Sprinkler=true, 0.2 => Sprinkler=false
            draws[offset + 1] = sprinkler ? 0.09 : 0.2;
            // 0.5 => Rain=true, 0.9 => Rain=false
            draws[offset + 2] = rain ? 0.5 : 0.9;
            draws[offset + 3] = 0.1; // i.e. WetGrass=true
        }

        double[] estimate = estimateRainGivenSprinkler(new MockRandomizer(draws));
        // Normalizing the 8 : 19 counts of accepted samples gives <8/27, 19/27>.
        Assert.assertArrayEquals(new double[] { 8.0 / 27.0, 19.0 / 27.0 },
                estimate, DELTA_THRESHOLD);
    }

    /**
     * Estimates P(Rain | Sprinkler = true) on the cloudy/sprinkler/rain/
     * wet-grass example network. It runs rejection sampling with 100 samples,
     * using {@code r} as the random source of the underlying
     * {@link PriorSample}.
     *
     * @param r the mocked random source to sample with
     * @return the values of the estimated distribution over Rain
     */
    private static double[] estimateRainGivenSprinkler(MockRandomizer r) {
        BayesianNetwork bn = BayesNetExampleFactory
                .constructCloudySprinklerRainWetGrassNetwork();
        AssignmentProposition[] evidence = { new AssignmentProposition(
                ExampleRV.SPRINKLER_RV, Boolean.TRUE) };
        RejectionSampling sampler = new RejectionSampling(new PriorSample(r));
        return sampler.rejectionSampling(
                new RandomVariable[] { ExampleRV.RAIN_RV }, evidence, bn, 100)
                .getValues();
    }
}