Package ca.uwo.csd.ai.nlp.kernel

Source Code of ca.uwo.csd.ai.nlp.kernel.TreeKernel

package ca.uwo.csd.ai.nlp.kernel;

//import edu.stanford.nlp.trees.Tree;
import ca.uwo.csd.ai.nlp.common.Tree;
import java.util.ArrayList;
import java.util.List;
import ca.uwo.csd.ai.nlp.libsvm.svm_node;

/**
* <code>TreeKernel</code> provides a naive implementation of the kernel function described in
* 'Parsing with a single neuron: Convolution kernels for NLP problems'.
* @author Syeed Ibn Faiz
*/
public class TreeKernel implements CustomKernel {
    private final static int MAX_NODE = 300;
    double mem[][] = new double[MAX_NODE][MAX_NODE];
    private double lambda; //penalizing factor
   
    public TreeKernel() {
        this(0.5);
    }
   
    public TreeKernel(double lambda) {
        this.lambda = lambda;
    }
   
    @Override
    public double evaluate(svm_node x, svm_node y) {
        Object k1 = x.data;
        Object k2 = y.data;
       
        if (!(k1 instanceof Tree) || !(k2 instanceof Tree)) {
            throw new IllegalArgumentException("svm_node does not contain tree data.");
        }
       
        Tree t1 = (Tree) k1;
        Tree t2 = (Tree) k2;
       
        List<Tree> nodes1 = getNodes(t1);
        List<Tree> nodes2 = getNodes(t2);
               
        int N1 = Math.min(MAX_NODE, nodes1.size());
        int N2 = Math.min(MAX_NODE, nodes2.size());
       
        //fill mem with -1.0
        initMem(mem, N1, N2);
       
        double result = 0.0;
        for (int i = 0; i < N1; i++) {
            for (int j = 0; j < N2; j++) {
                result += compute(i, j, nodes1, nodes2, mem);
            }
        }
       
        return result;
    }
   
    /**
     * Efficient computation avoiding costly equals method of Tree
     * @param trees
     * @param t
     * @return
     */
    private int indexOf(List<Tree> trees, Tree t) {
        for (int i = 0; i < trees.size(); i++) {
            if (t == trees.get(i)) {
                return i;
            }
        }
        return -1;
    }
   
    private double compute(int i, int j, List<Tree> nodes1, List<Tree> nodes2, double[][] mem) {
        if (mem[i][j] >= 0) {
            return mem[i][j];
        }
        //if (sameProduction(nodes1.get(i), nodes2.get(j))) {
        if (nodes1.get(i).value().equals(nodes2.get(j).value()) &&
                nodes1.get(i).hashCode() == nodes2.get(j).hashCode()) {     //similar hashCode -> same production
           
            mem[i][j] = lambda * lambda;
            if (!nodes1.get(i).isLeaf() && !nodes2.get(j).isLeaf()) {
                List<Tree> childList1 = nodes1.get(i).getChildrenAsList();
                List<Tree> childList2 = nodes2.get(j).getChildrenAsList();
                for (int k = 0; k < childList1.size(); k++) {
                    //mem[i][j] *= 1 + compute(nodes1.indexOf(childList1.get(k)), nodes2.indexOf(childList2.get(k)), nodes1, nodes2, mem);
                    mem[i][j] *= 1 + compute(indexOf(nodes1, childList1.get(k)), indexOf(nodes2, childList2.get(k)), nodes1, nodes2, mem);
                }
            }
        } else {
            mem[i][j] = 0.0;           
        }
               
        return mem[i][j];
    }
   
    private boolean sameProduction(Tree t1, Tree t2) {
        if (t1.value().equals(t2.value())) {
            List<Tree> childList1 = t1.getChildrenAsList();
            List<Tree> childList2 = t2.getChildrenAsList();
            if (childList1.size() == childList2.size()) {
                for (int i = 0; i < childList1.size(); i++) {
                    if (!childList1.get(i).value().equals(childList2.get(i).value())) {
                        return false;
                    }
                }
                return true;
            }
        }
        return false;
    }
    private void initMem(double[][] mem, int N1, int N2) {
        for (int i = 0; i < N1; i++) {
            for (int j = 0; j < N2; j++) {
                mem[i][j] = -1.0;
            }
        }
    }
    private List<Tree> getNodes(Tree t) {
        ArrayList<Tree> nodes = new ArrayList<Tree>();
        addNodes(t, nodes);
        return nodes;
    }
   
    private void addNodes(Tree t, List<Tree> nodes) {
        nodes.add(t);
        List<Tree> childList = t.getChildrenAsList();
        for (Tree child : childList) {
            addNodes(child, nodes);
        }
    }     
}
TOP

Related Classes of ca.uwo.csd.ai.nlp.kernel.TreeKernel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.