// Package org.drools.pmml.pmml_4_1.predictive.models
//
// Source Code of org.drools.pmml.pmml_4_1.predictive.models.DecisionTreeTest

/*
* Copyright 2011 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.drools.pmml.pmml_4_1.predictive.models;


import org.dmg.pmml.pmml_4_1.descr.MISSINGVALUESTRATEGY;
import org.dmg.pmml.pmml_4_1.descr.PMML;
import org.dmg.pmml.pmml_4_1.descr.TreeModel;
import org.drools.pmml.pmml_4_1.DroolsAbstractPMMLTest;
import org.drools.pmml.pmml_4_1.PMML4Compiler;
import org.drools.pmml.pmml_4_1.PMML4Helper;
import org.junit.After;
import org.junit.Test;
import org.kie.api.definition.type.FactType;
import org.kie.api.runtime.ClassObjectFilter;
import org.kie.api.runtime.KieSession;
import org.kie.internal.io.ResourceFactory;

import java.io.InputStream;
import java.util.Collection;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class DecisionTreeTest extends DroolsAbstractPMMLTest {


    private static final boolean VERBOSE = false;
    private static final String source1 = "org/drools/pmml/pmml_4_1/test_tree_simple.xml";
    private static final String source2 = "org/drools/pmml/pmml_4_1/test_tree_missing.xml";
    private static final String packageName = "org.drools.pmml.pmml_4_1.test";



    @After
    public void tearDown() {
        getKSession().dispose();
    }

    @Test
    public void testSimpleTree() throws Exception {
        setKSession( getModelSession( source1, VERBOSE ) );
        setKbase( getKSession().getKieBase() );
        KieSession kSession = getKSession();

//        kSession.addEventListener( new org.drools.event.rule.DebugAgendaEventListener() );

        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld5" );
       
        kSession.getEntryPoint( "in_Fld1" ).insert( 30.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( 60.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "false" );
        kSession.getEntryPoint( "in_Fld4" ).insert( "optA" );

        kSession.fireAllRules();

        checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtY" );
    }
   
   
   
    protected Object getToken( KieSession kSession ) {
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );
        assertNotNull( tok );
        Collection c = kSession.getObjects( new ClassObjectFilter( tok.getFactClass() ) );
        assertEquals( 1, c.size() );
        return c.iterator().next();
    }


    @Test
    public void testMissingTree() throws Exception {
        setKSession( getModelSession( source2, VERBOSE ) );
        setKbase( getKSession().getKieBase() );
        KieSession kSession = getKSession();

        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( 45.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( 60.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "optA" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.6, tok.get( token, "confidence" ) );
        assertEquals( "null", tok.get( token, "current" ) );

        checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtZ" );
       
       
    }



    @Test
    public void testMissingTreeWeighted1() throws Exception {
        setKSession( getModelSession( source2, VERBOSE ) );
        setKbase( getKSession().getKieBase() );
        KieSession kSession = getKSession();

        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "optA" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.8, tok.get( token, "confidence" ) );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 50.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" );


    }




    @Test
    public void testMissingTreeWeighted2() throws Exception {
        setKSession( getModelSession( source2, VERBOSE ) );
        setKbase( getKSession().getKieBase() );
        KieSession kSession = getKSession();

        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "miss" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.6, tok.get( token, "confidence" ) );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 100.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" );
    }




    @Test
    public void testMissingTreeDefault() throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() );

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( o instanceof TreeModel ) {
                TreeModel tree = (TreeModel) o;
                tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.DEFAULT_CHILD );
            }
        }

        KieSession kSession = getSession( compiler.generateTheory( pmml ) );

        setKSession( kSession );
        setKbase( getKSession().getKieBase() );

       
       
        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( 70.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( 40.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "miss" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.72, (Double) tok.get( token, "confidence" ), 1e-6 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 40.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" );
    }


    @Test
    public void testMissingTreeAllMissingDefault() throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() );

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( o instanceof TreeModel ) {
                TreeModel tree = (TreeModel) o;
                tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.DEFAULT_CHILD );
            }
        }

        String theory = compiler.generateTheory( pmml );
        if ( VERBOSE ) {
            System.out.println( theory );
        }
        KieSession kSession = getSession( theory );
        setKSession( kSession );
        setKbase( getKSession().getKieBase() );



        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "miss" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 1.0, (Double) tok.get( token, "confidence" ), 1e-6 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 0.0, tok.get( token, "totalCount" ) );

//        checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" );
    }




    @Test
    public void testMissingTreeLastChoice() throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() );

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( o instanceof TreeModel ) {
                TreeModel tree = (TreeModel) o;
                tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.LAST_PREDICTION );
            }
        }

        String theory = compiler.generateTheory( pmml );
        if ( VERBOSE ) {
            System.out.println( theory );
        }
        KieSession kSession = getSession( theory );
        setKSession( kSession );
        setKbase( getKSession().getKieBase() );



        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "optA" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.8, (Double) tok.get( token, "confidence" ), 1e-6 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 50.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtX" );
    }




    @Test
    public void testMissingTreeNull() throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() );

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( o instanceof TreeModel ) {
                TreeModel tree = (TreeModel) o;
                tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.NULL_PREDICTION );
            }
        }

        String theory = compiler.generateTheory( pmml );
        if ( VERBOSE ) {
            System.out.println( theory );
        }
        KieSession kSession = getSession( theory );
        setKSession( kSession );
        setKbase( getKSession().getKieBase() );



        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "optA" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.0, (Double) tok.get( token, "confidence" ), 1e-6 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 0.0, tok.get( token, "totalCount" ) );

        assertEquals( 0, getKSession().getObjects( new ClassObjectFilter( tgt.getFactClass() ) ).size() );
    }



    @Test
    public void testMissingAggregate() throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() );

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( o instanceof TreeModel ) {
                TreeModel tree = (TreeModel) o;
                tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.AGGREGATE_NODES );
            }
        }

        String theory = compiler.generateTheory( pmml );
        if ( VERBOSE ) {
            System.out.println( theory );
        }
        KieSession kSession = getSession( theory );
        setKSession( kSession );
        setKbase( getKSession().getKieBase() );



        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( 45.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( 90.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "miss" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.47, (Double) tok.get( token, "confidence" ), 1e-2 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 60.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtY" );
    }



    @Test
    public void testMissingTreeNone() throws Exception {
        PMML4Compiler compiler = new PMML4Compiler();
        PMML pmml = compiler.loadModel( PMML, ResourceFactory.newClassPathResource( source2 ).getInputStream() );

        for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {
            if ( o instanceof TreeModel ) {
                TreeModel tree = (TreeModel) o;
                tree.setMissingValueStrategy( MISSINGVALUESTRATEGY.NONE );
            }
        }

        String theory = compiler.generateTheory( pmml );
        if ( VERBOSE ) {
            System.out.println( theory );
        }
        KieSession kSession = getSession( theory );
        setKSession( kSession );
        setKbase( getKSession().getKieBase() );



        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "miss" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.6, (Double) tok.get( token, "confidence" ), 1e-6 );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 100.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus( tgt, true, false, "Missing", "tgtX" );
    }


    @Test
    public void testSimpleTreeOutput() throws Exception {
        setKSession( getModelSession( source2, VERBOSE ) );
        setKbase( getKSession().getKieBase() );
        KieSession kSession = getKSession();

        kSession.fireAllRules()//init model

        FactType tgt = kSession.getKieBase().getFactType( packageName, "Fld9" );
        FactType tok = kSession.getKieBase().getFactType( PMML4Helper.pmmlDefaultPackageName(), "TreeToken" );

        kSession.getEntryPoint( "in_Fld1" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld2" ).insert( -1.0 );
        kSession.getEntryPoint( "in_Fld3" ).insert( "optA" );

        kSession.fireAllRules();

        Object token = getToken( kSession );
        assertEquals( 0.8, tok.get( token, "confidence" ) );
        assertEquals( "null", tok.get( token, "current" ) );
        assertEquals( 50.0, tok.get( token, "totalCount" ) );

        checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" );

        checkFirstDataFieldOfTypeStatus( kSession.getKieBase().getFactType( packageName, "OutClass" ),
                    true, false, "Missing", "tgtX" );
        checkFirstDataFieldOfTypeStatus( kSession.getKieBase().getFactType( packageName, "OutProb" ),
                    true, false, "Missing", 0.8 );


    }

}
// TOP
//
// Related Classes of org.drools.pmml.pmml_4_1.predictive.models.DecisionTreeTest
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.