Package org.neo4j.nlp.impl.util

Source Code of org.neo4j.nlp.impl.util.VectorUtilTest

package org.neo4j.nlp.impl.util;

import junit.framework.Assert;
import org.junit.Test;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Transaction;
import org.neo4j.nlp.helpers.GraphManager;
import org.neo4j.nlp.impl.cache.ClassRelationshipCache;
import org.neo4j.nlp.impl.cache.PatternRelationshipCache;
import org.neo4j.nlp.impl.manager.ClassNodeManager;
import org.neo4j.nlp.impl.manager.DataNodeManager;
import org.neo4j.nlp.impl.manager.DataRelationshipManager;
import org.neo4j.nlp.impl.manager.NodeManager;
import org.neo4j.test.TestGraphDatabaseFactory;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

public class VectorUtilTest
{

    private static GraphDatabaseService setUpDb()
    {
        return new TestGraphDatabaseFactory().newImpermanentDatabaseBuilder().newGraphDatabase();
    }

    final public static Charset ENCODING = StandardCharsets.UTF_8;

    final public static String negativeSentimentDirectory = "../examples/graphify-examples-sentiment-analysis/src/main/resources/txt_sentoken/neg";
    final public static String positiveSentimentDirectory = "../examples/graphify-examples-sentiment-analysis/src/main/resources/txt_sentoken/pos";


    @Test
    public void sentimentAnalysisTest() throws IOException {
        GraphDatabaseService db = setUpDb();
        GraphManager graphManager = new GraphManager("Pattern");

        // Invalidate all caches
        NodeManager.globalNodeCache.invalidateAll();
        DataNodeManager.dataCache.invalidateAll();
        ClassNodeManager.classCache.invalidateAll();
        GraphManager.edgeCache.invalidateAll();
        GraphManager.inversePatternCache.invalidateAll();
        GraphManager.patternCache.invalidateAll();
        DataRelationshipManager.relationshipCache.invalidateAll();
        ClassRelationshipCache.relationshipCache.invalidateAll();
        PatternRelationshipCache.relationshipCache.invalidateAll();
        VectorUtil.vectorSpaceModelCache.invalidateAll();

        // Train the model on 200 examples each of positive and negative reviews
        train(db, graphManager);

        // Test the model on the next 200 examples of positive and negative reviews
        Map<String, Double> errorMap = test(db, graphManager);

        // To ensure the validity of the classifier, assert success ratio is greater than 50%
        Assert.assertTrue(errorMap.get("positive") > .5 && errorMap.get("negative") > .5);

        System.out.println(errorMap);
    }

    final static Integer trainCount = 200;

    private static Map<String, Double> test(GraphDatabaseService db, GraphManager graphManager) throws IOException {
        Integer negativeError = 0;
        Integer positiveError = 0;

        List<String> negativeText = readLargerTextFile(negativeSentimentDirectory);
        List<String> positiveText = readLargerTextFile(positiveSentimentDirectory);

        for (int i = 1; i < trainCount * 2; i++) {
            if(i % 2 == 0)
            {
                negativeError += testOnText(negativeText.get(trainCount + (i / 2)), "negative", db, graphManager) ? 0 : 1;
            } else {
                positiveError += testOnText(positiveText.get(trainCount + (i / 2)), "positive", db, graphManager) ? 0 : 1;
            }
        }

        Map<String, Double> errorMap = new HashMap<>();

        errorMap.put("negative", 1.0 - (negativeError.doubleValue() / trainCount.doubleValue()));
        errorMap.put("positive", 1.0 - (positiveError.doubleValue() / trainCount.doubleValue()));

        // Return success ratio
        return errorMap;
    }

    private static void train(GraphDatabaseService db, GraphManager graphManager) throws IOException {
        List<String> negativeText = readLargerTextFile(negativeSentimentDirectory);
        List<String> positiveText = readLargerTextFile(positiveSentimentDirectory);

        for (int i = 0; i < trainCount * 2; i++) {
            if(i % 2 == 0)
            {
                trainOnText(new String[] { negativeText.get(i / 2) }, new String[]{"negative"}, db, graphManager);
            } else {
                trainOnText(new String[] { positiveText.get(i / 2) }, new String[]{"positive"}, db, graphManager);
            }
        }
    }

    private static void trainOnText(String[] text, String[] label, GraphDatabaseService db, GraphManager graphManager) {
        Transaction tx = db.beginTx();
        getRootPatternNode(db, graphManager);
        LearningManager.trainInput(Arrays.asList(text), Arrays.asList(label), graphManager, db);
        tx.success();
        tx.close();
    }

    private static Node getRootPatternNode(GraphDatabaseService db, GraphManager graphManager) {
        Node patternNode;
        patternNode = new NodeManager().getOrCreateNode(graphManager, GraphManager.ROOT_TEMPLATE, db);
        if(!patternNode.hasProperty("matches")) {
            patternNode.setProperty("matches", 0);
            patternNode.setProperty("threshold", GraphManager.MIN_THRESHOLD);
            patternNode.setProperty("root", 1);
            patternNode.setProperty("phrase", "{0} {1}");
        }
        return patternNode;
    }

    private static String cleanText(String t) {
        t = t.replaceAll("(,|:|;)", " ");
        t = t.replaceAll("  ", " ");
        t = t.replaceAll("\\n", " ");
        t = t.replaceAll("([\\.]\\s)", " .1. .1. ");
        t = t.replaceAll("  ", " ");
        t = t.trim();

        return t;
    }

    private static boolean testOnText(String text, String label, GraphDatabaseService db, GraphManager graphManager) {
        Map<String, List<LinkedHashMap<String, Object>>> hashMap = VectorUtil.similarDocumentMapForVector(db, graphManager, cleanText(text));

        // Validate guess
        ArrayList classes = (ArrayList) hashMap.get("classes");

        if (classes.size() > 0) {
            LinkedHashMap className = (LinkedHashMap) classes.stream().findFirst().get();

            return className.get("class").equals(label);
        }
        else
        {
            return false;
        }
    }

    public static List<String> readLargerTextFile(String aFileName) throws IOException {

        List<String> files = new ArrayList<>();
        File fileDir = new File(aFileName);

        for (final File fileEntry : fileDir.listFiles()) {
            StringBuilder sb = new StringBuilder();

            Path path = Paths.get(aFileName + "/" + fileEntry.getName());
            try (Scanner scanner = new Scanner(path, ENCODING.name())) {
                while (scanner.hasNextLine()) {
                    sb.append(scanner.nextLine()).append("\n");
                }
            }

            String result = sb.toString();
            files.add(result);
        }

        return files;
    }

}
TOP

Related Classes of org.neo4j.nlp.impl.util.VectorUtilTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.