Package hivemall.classifier

Source Code of hivemall.classifier.PerceptronUDTFTest

/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2013
*   National Institute of Advanced Industrial Science and Technology (AIST)
*   Registration Number: H25PRO-1520
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
*/
package hivemall.classifier;

import static org.junit.Assert.assertEquals;
import hivemall.io.FeatureValue;

import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Test;

public class PerceptronUDTFTest {

    @Test
    public void testInitialize() throws UDFArgumentException {
        PerceptronUDTF udtf = new PerceptronUDTF();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ListObjectInspector intListOI = ObjectInspectorFactory.getStandardListObjectInspector(intOI);

        /* test for INT_TYPE_NAME feature */
        StructObjectInspector intListSOI = udtf.initialize(new ObjectInspector[] { intListOI, intOI });
        assertEquals("struct<feature:int,weight:float>", intListSOI.getTypeName());

        /* test for STRING_TYPE_NAME feature */
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
        StructObjectInspector stringListSOI = udtf.initialize(new ObjectInspector[] { stringListOI,
                intOI });
        assertEquals("struct<feature:string,weight:float>", stringListSOI.getTypeName());

        /* test for BIGINT_TYPE_NAME feature */
        ObjectInspector longOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
        ListObjectInspector longListOI = ObjectInspectorFactory.getStandardListObjectInspector(longOI);
        StructObjectInspector longListSOI = udtf.initialize(new ObjectInspector[] { longListOI,
                intOI });
        assertEquals("struct<feature:bigint,weight:float>", longListSOI.getTypeName());
    }

    @Test
    public void testUpdate() throws UDFArgumentException {
        PerceptronUDTF udtf = new PerceptronUDTF();
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
        udtf.initialize(new ObjectInspector[] { stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector });

        /* update weights by List<Object> */
        List<String> features1 = new ArrayList<String>();
        features1.add("good");
        features1.add("opinion");
        udtf.update(features1, 1, 0.f);

        /* check weights */
        FeatureValue word1 = FeatureValue.parse(new String("good"));
        assertEquals(1.f, udtf.model.get(word1.getFeature()).get(), 1e-5f);

        FeatureValue word2 = FeatureValue.parse(new String("opinion"));
        assertEquals(1.f, udtf.model.get(word2.getFeature()).get(), 1e-5f);

        /* update weights by List<Object> */
        List<String> features2 = new ArrayList<String>();
        features2.add("bad");
        features2.add("opinion");
        udtf.update(features2, -1, 0.f);

        /* check weights */
        assertEquals(1.f, udtf.model.get(word1.getFeature()).get(), 1e-5f);

        FeatureValue word3 = FeatureValue.parse(new String("bad"));
        assertEquals(-1.f, udtf.model.get(word3.getFeature()).get(), 1e-5f);

        FeatureValue word4 = FeatureValue.parse(new String("opinion"));
        assertEquals(0.f, udtf.model.get(word4.getFeature()).get(), 1e-5f);
    }

    @Test
    public void testUpdateStringTypeDisableBias() throws UDFArgumentException {
        PerceptronUDTF udtf = new PerceptronUDTF();
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector, new String(""));
        udtf.initialize(new ObjectInspector[] { stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, param });

        /* update weights by List<Object> */
        List<String> features1 = new ArrayList<String>();
        features1.add("good");
        features1.add("opinion");
        udtf.update(features1, 1, 0.f);

        /* check weights */
        FeatureValue word1 = FeatureValue.parse(new String("good"));
        assertEquals(1.f, udtf.model.get(word1.getFeature()).get(), 1e-5f);

        FeatureValue word2 = FeatureValue.parse(new String("opinion"));
        assertEquals(1.f, udtf.model.get(word2.getFeature()).get(), 1e-5f);

        /* update weights by List<Object> */
        List<String> features2 = new ArrayList<String>();
        features2.add("bad");
        features2.add("opinion");
        udtf.update(features2, -1, 0.f);

        /* check weights */
        assertEquals(1.f, udtf.model.get(word1.getFeature()).get(), 1e-5f);

        FeatureValue word3 = FeatureValue.parse(new String("bad"));
        assertEquals(-1.f, udtf.model.get(word3.getFeature()).get(), 1e-5f);

        FeatureValue word4 = FeatureValue.parse(new String("opinion"));
        assertEquals(0.f, udtf.model.get(word4.getFeature()).get(), 1e-5f);
    }
}
TOP

Related Classes of hivemall.classifier.PerceptronUDTFTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.