Package org.grouplens.lenskit.scored

Source Code of org.grouplens.lenskit.scored.SparseVectorScoredIdsTest

/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.grouplens.lenskit.scored;

import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import org.grouplens.lenskit.collections.LongUtils;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.symbols.TypedSymbol;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.junit.Test;

import java.util.HashSet;
import java.util.Set;

import static org.junit.Assert.*;

public class SparseVectorScoredIdsTest {
    private final Symbol fooSym = Symbol.of("foo");
    private final Symbol barSym = Symbol.of("bar");
    private final Symbol bazSym = Symbol.of("baz");
    private final TypedSymbol<String> tsym = TypedSymbol.of(String.class, "test.wombat");
   
    @Test
    public void testSparseVectorScoredIds() {
        MutableSparseVector sv = MutableSparseVector.create(LongUtils.packedSet(1,2,4));
        sv.set(1,1.0);
        sv.set(4,16.0);
       
        MutableSparseVector foo = sv.addChannelVector(fooSym);
        foo.set(1,2.0);
        foo.set(4,5.0);
       
        MutableSparseVector bar = sv.addChannelVector(barSym);
        bar.set(1,3.0);
       
        MutableSparseVector baz = sv.addChannelVector(bazSym);
        baz.set(2, 100.0);

        Long2ObjectMap<String> wombat = sv.addChannel(tsym);
        wombat.put(1, "hello");
        wombat.put(4, "goodbye");
       
        // check that the hasUnboxedChannel function is correct.
        for (ScoredId sid: ScoredIds.collectionFromVector(sv)) {
            assertTrue(sid.hasUnboxedChannel(fooSym));
            assertFalse(sid.hasUnboxedChannel(bazSym));
            assertTrue(sid.hasChannel(tsym));
        }
       
        ScoredIdBuilder builder = new ScoredIdBuilder();
        Set<ScoredId> expected = new HashSet<ScoredId>();
        expected.add(builder.setId(1)
                            .setScore(1.0)
                            .addChannel(fooSym, 2.0)
                            .addChannel(barSym, 3.0)
                            .addChannel(tsym, "hello")
                            .build());
        expected.add(builder.clearChannels()
                            .setId(4)
                            .setScore(16.0)
                            .addChannel(fooSym, 5.0)
                            .addChannel(tsym, "goodbye")
                            .build());
       
        // get the scored ids and put them in a hashset (for comparison).
        assertEquals(expected, new HashSet<ScoredId>(ScoredIds.collectionFromVector(sv)));
    }
}
TOP

Related Classes of org.grouplens.lenskit.scored.SparseVectorScoredIdsTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle. Contact coftware#gmail.com.