Package org.drools.beliefs.bayes

Source Code of org.drools.beliefs.bayes.JunctionTreeBuilder

package org.drools.beliefs.bayes;

import org.drools.beliefs.graph.Edge;
import org.drools.beliefs.graph.Graph;
import org.drools.beliefs.graph.GraphNode;
import org.drools.core.util.bitmask.OpenBitSet;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

public class JunctionTreeBuilder {
    private Graph<BayesVariable> graph;
    private boolean[][]          adjacencyMatrix;

    public Graph<BayesVariable> getGraph() {
        return graph;
    }

    public JunctionTreeBuilder(Graph<BayesVariable> graph) {
        this.graph = graph;
        adjacencyMatrix = new boolean[graph.size()][graph.size()];

        for (GraphNode<BayesVariable> v : graph ) {
            for (Edge e1 : v.getInEdges()) {
                GraphNode pV1 = graph.getNode(e1.getOutGraphNode().getId());
                connect(adjacencyMatrix, pV1.getId(), v.getId());
            }
        }
    }
    public JunctionTree build() {
        return build( true );
    }

    public JunctionTree build(boolean init) {
        moralize();
        List<OpenBitSet>  cliques = triangulate();
        return junctionTree(cliques, init);
    }


    public void moralize() {
        for (GraphNode<BayesVariable> v : graph ) {
            for ( Edge e1 : v.getInEdges() ) {
                GraphNode pV1 = graph.getNode(e1.getOutGraphNode().getId());
                moralize(v, pV1);
            }
        }
    }

    public void moralize(GraphNode<BayesVariable> v, GraphNode v1) {
        for ( Edge e2 : v.getInEdges() ) {
            // moralize, by connecting each parent with each other
            GraphNode v2 = graph.getNode(e2.getOutGraphNode().getId());
            if ( v1 == v2 ) {
                continue; // don't connect to itself
            }
            if ( adjacencyMatrix[v1.getId()][v2.getId()] ) {
                // already connected, continue
                continue;
            }
            connect(getAdjacencyMatrix(), v1.getId(), v2.getId());
        }
    }

    public static void connect(boolean[][] adjMatrix, int v1, int v2) {
        adjMatrix[v1][v2] = true;
        adjMatrix[v2][v1] = true;
    }

    public static void disconnect(boolean[][] adjMatrix, int v1, int v2) {
        adjMatrix[v1][v2] = false;
        adjMatrix[v2][v1] = false;
    }

    public List<OpenBitSet> triangulate() {
        // A PriorityQueue is used, as it needs resorting, as new edges are added due to elimination
        // Build up a Priority Queue of Vertices to be eliminated
        // As their edge counts increase, th PriorityQueue ensures they efficiently maintained in a sorted order.
        boolean[][] clonedAdjMatrix = cloneAdjacencyMarix( adjacencyMatrix );
        PriorityQueue<EliminationCandidate> p = new PriorityQueue<EliminationCandidate>(graph.size());
        Map<Integer, EliminationCandidate> elmVertMap = new HashMap<Integer, EliminationCandidate>();

        for (GraphNode<BayesVariable> v : graph ) {
            EliminationCandidate elmCandVert = new EliminationCandidate(graph, clonedAdjMatrix, v);
            p.add( elmCandVert );
            elmVertMap.put( v.getId(), elmCandVert );
        }

        // Iterate and eliminate each Vertex in turn
        List<OpenBitSet> cliques = new ArrayList<OpenBitSet>();
        while ( !p.isEmpty() ) {
            EliminationCandidate v = p.remove();
            updateCliques(cliques, v.getCliqueBitSit()); // keep track of the maximal cliques formed during elimination

            // Not all vertexes get updated, as they may already be connected. Only track those that have changed edges
            Set<Integer> verticesToUpdate = new HashSet<Integer>();
            boolean[] adjList = clonedAdjMatrix[ v.getV().getId() ];
            createClique(v.getV().getId(), clonedAdjMatrix, verticesToUpdate, adjList);
            eliminateVertex(p, elmVertMap, clonedAdjMatrix, adjList, verticesToUpdate, v);
        }
        return cliques;
    }

    public void eliminateVertex(PriorityQueue<EliminationCandidate> p, Map<Integer, EliminationCandidate> elmVertMap, boolean[][] clonedAdjMatrix, boolean[] adjList, Set<Integer> verticesToUpdate, EliminationCandidate v) {
        // remove the vertex, by disconnecting all it's edges
        int id = v.getV().getId();
        for ( int i = 0; i < adjList.length; i++ ) {
            disconnect(clonedAdjMatrix, id, i);
        }

        // iterate all vertices that had their edges updated, and update their score and re-add to the priority queue
        // must also update everything they touch too
        for (Integer i : verticesToUpdate) {
            EliminationCandidate vertexToUpdate = elmVertMap.get( i );
            p.remove( vertexToUpdate );
            vertexToUpdate.update();
            p.add(vertexToUpdate);
        }
    }

    public void createClique(int v, boolean[][] clonedAdjMatrix, Set<Integer> verticesToUpdate, boolean[] adjList) {
        for ( int i = 0; i < adjList.length; i++ ) {
            if ( !adjList[i] ) {
                // not connected to this vertex
                continue;
            }
            getRelatedVerticesToUpdate(v, clonedAdjMatrix, verticesToUpdate, i);

            boolean needsConnection = false;
            for ( int j = i+1; j < adjList.length; j++ ) {
                // i + 1, so it doesn't check if a node is connected with itself
                if ( !adjList[j] || clonedAdjMatrix[i][j] ) {
                    // edge already exists
                    continue;
                }

                connect(adjacencyMatrix, i, j);
                connect(clonedAdjMatrix, i, j );
                getRelatedVerticesToUpdate(v, clonedAdjMatrix, verticesToUpdate, j);

                needsConnection = true;
            }

            if ( needsConnection ) {
                verticesToUpdate.add( i );
            }
        }
    }

    private void getRelatedVerticesToUpdate(int v, boolean[][] clonedAdjMatrix, Set<Integer> verticesToUpdate, int i) {
        verticesToUpdate.add( i );
        // must also add anything i touches
        for ( Integer k : JunctionTreeBuilder.getAdjacentVertices(clonedAdjMatrix, i) ) {
            if ( k != v ) {
                verticesToUpdate.add( k ); // don't add the v
            }
        }
    }

    public static void updateCliques(List<OpenBitSet> cliques, OpenBitSet newClique) {
        // iterate all the existing cliques, checking for a superset.
        // if no superset is found, then add the cliques
        boolean superSetFound = false;
        for ( OpenBitSet existingCluster : cliques ) {
            // is existingCluster a supserset of newClique, visa-vis is newClique a subset of  existingCluster
            if ( OpenBitSet.andNotCount(newClique, existingCluster) == 0) {
                superSetFound = true;
                break; // superset found
            }
        }
        if ( !superSetFound ) {
            cliques.add(newClique);
        }
    }

    public boolean[][] getAdjacencyMatrix() {
        return adjacencyMatrix;
    }

    /**
     * Clones the provided array
     *
     * @param src
     * @return a new clone of the provided array
     */
    public static boolean[][] cloneAdjacencyMarix(boolean[][] src) {
        int length = src.length;
        boolean[][] target = new boolean[length][src[0].length];
        for (int i = 0; i < length; i++) {
            System.arraycopy(src[i], 0, target[i], 0, src[i].length);
        }
        return target;
    }

    public JunctionTree junctionTree(List<OpenBitSet> cliques, boolean init) {
        List<SeparatorSet> list = new ArrayList<SeparatorSet>();
        for ( int i = 0; i < cliques.size(); i++ ) {
            for ( int j = i+1; j < cliques.size(); j++ ) {
                SeparatorSet separatorSet = new SeparatorSet( cliques.get(i), i, cliques.get(j), j, graph );
                if ( separatorSet.getMass() > 0 ) {
                    list.add(separatorSet);
                }
            }
        }

        list.addAll( list );
        Collections.sort(list);

        SeparatorSet[][][] sepGraphs = new SeparatorSet[cliques.size()][][];
        JunctionTreeClique[] jtNodes = new JunctionTreeClique[cliques.size()];
        JunctionTreeSeparator[] jtSeps = new JunctionTreeSeparator[cliques.size()-1];

        OpenBitSet[] varNodeToCliques = new OpenBitSet[graph.size()];
        for ( int i = 0, length = cliques.size(); i < length; i++ ) {
            // assign each Clique to a JunctionNode and give it a
            JunctionTreeClique node = new JunctionTreeClique( i, graph, cliques.get(i));
            jtNodes[i] = node;
            sepGraphs[i] = new SeparatorSet[graph.size()][graph.size()];
            mapVarNodeToCliques(varNodeToCliques, i, cliques.get(i));
        }

        for ( int i = 0, j = 0, length = cliques.size()-1; i < length; ) {
            SeparatorSet separatorSet = list.get(j++);

            JunctionTreeClique node1 =  jtNodes[separatorSet.getId1()];
            JunctionTreeClique node2 =  jtNodes[separatorSet.getId2()];

            if ( sepGraphs[node1.getId()] == sepGraphs[node2.getId()]) {
                continue;
            }

            //mergeGraphs( graphs, node1, graphs[node2.getId()] );
            mergeGraphs( sepGraphs, separatorSet);
            i++;
        }

        createJunctionTreeGraph( sepGraphs[0], jtNodes[0], jtNodes, jtSeps, 0 );

        mapNodeToCliqueFamily(varNodeToCliques, jtNodes);

        return new JunctionTree(graph, jtNodes[0], jtNodes, jtSeps, init);
    }


    public void mergeGraphs(SeparatorSet[][][] graphs, SeparatorSet separatorSet) {
        SeparatorSet[][] srcGraph = graphs[separatorSet.getId1()];
        SeparatorSet[][] trgGraph = graphs[separatorSet.getId2()];

        // merge the src graph into the trg graph
        for ( int i = 0; i < srcGraph.length; i++ ) {
            SeparatorSet[] row = srcGraph[i];
            for ( int j = 0; j < row.length; j++ ) {
                if ( row[j] != null ) {
                    trgGraph[i][j] = row[j];
                    // map the cliques to the new map
                    graphs[j] = trgGraph;
                }
            }
        }

        // add new connection
        graphs[separatorSet.getId1()] = trgGraph;
        trgGraph[separatorSet.getId1()][separatorSet.getId2()] = separatorSet;
        trgGraph[separatorSet.getId2()][separatorSet.getId1()] = separatorSet;
    }

    public int createJunctionTreeGraph(SeparatorSet[][] sepGraph, JunctionTreeClique parent, JunctionTreeClique[] jtNodes, JunctionTreeSeparator[] jtSeps, int i) {
        SeparatorSet[] row = sepGraph[parent.getId()];
        for ( int j = 0; j < row.length; j++ ) {
            if ( row[j] != null ) {
                SeparatorSet separatorSet = row[j];
                JunctionTreeClique node1 =  jtNodes[separatorSet.getId1()];
                JunctionTreeClique node2 =  jtNodes[separatorSet.getId2()];
                JunctionTreeClique child = ( node1 != parent ) ? node1 : node2; // this ensures we build it in the current parent/child order, based on recursive iteration
                JunctionTreeSeparator sepNode = new JunctionTreeSeparator(i++, parent, child, separatorSet.getIntersection(), graph );
                jtSeps[sepNode.getId()] = sepNode;

                // connection made, remove from the graph, before recursion
                sepGraph[separatorSet.getId1()][separatorSet.getId2()] = null;
                sepGraph[separatorSet.getId2()][separatorSet.getId1()] = null;
                createJunctionTreeGraph( sepGraph, child, jtNodes, jtSeps, i );
            }
        }
        return i;
    }


    /**
     * Given the set of cliques, mapped via ID in a Bitset, for a given bayes node,
     * Find the best clique. Where best clique is one that contains all it's parents
     * with the smallest number of nodes in that clique. When there are no parents
     * then simply pick the clique with the smallest number nodes.
     * @param varNodeToCliques
     * @param jtNodes
     */
    public void mapNodeToCliqueFamily(OpenBitSet[] varNodeToCliques, JunctionTreeClique[] jtNodes) {
        for ( int i = 0; i < varNodeToCliques.length; i++ ) {
            GraphNode<BayesVariable> varNode = graph.getNode( i );

            // Get OpenBitSet for parents
            OpenBitSet parents = new OpenBitSet();
            int count = 0;
            for ( Edge edge : varNode.getInEdges() ) {
                parents.set( edge.getOutGraphNode().getId() );
                count++;
            }

            if ( count == 0 ) {
                // node has no parents, so simply find the smallest clique it's in.
            }

            OpenBitSet cliques = varNodeToCliques[i];
            if  ( cliques == null ) {
                throw new IllegalStateException("Node exists, that is not part of a clique. " + varNode.toString());
            }
            int bestWeight = -1;
            int clique = -1;
            // finds the smallest node, that contains all the parents
            for ( int j = cliques.nextSetBit(0); j >= 0; j = cliques.nextSetBit( j+ 1 ) ) {
                JunctionTreeClique jtNode = jtNodes[j];

                // if the node has parents, we find the small clique it's in.
                // If it has parents then is jtNode a supserset of parents, visa-vis is parents a subset of jtNode
                if ( (count == 0 || OpenBitSet.andNotCount(parents, jtNode.getBitSet()) == 0 ) && ( clique == -1 || jtNode.getBitSet().cardinality() < bestWeight ) ) {
                    bestWeight = (int) jtNode.getBitSet().cardinality();
                    clique = j;
                }
            }

            if ( clique == -1 ) {
                throw new IllegalStateException("No clique for node found." + varNode.toString());
            }
            varNode.getContent().setFamily( clique );
            jtNodes[clique].addToFamily( varNode.getContent() );
        }

    }

    /**
     * Maps each Bayes node to cliques it's in.
     * It uses a BitSet to map the ID of the cliques
     * @param nodeToCliques
     * @param id
     * @param clique
     */
    public void mapVarNodeToCliques(OpenBitSet[] nodeToCliques, int id, OpenBitSet clique) {
        for ( int i = clique.nextSetBit(0); i >= 0; i = clique.nextSetBit( i + 1 ) ) {
             OpenBitSet cliques = nodeToCliques[i];
            if ( cliques == null ) {
                cliques = new OpenBitSet();
                nodeToCliques[i] = cliques;
            }
            cliques.set(id);
        }
    }

    public static List<Integer> getAdjacentVertices(boolean[][] adjacencyMatrix, int i) {
        List<Integer> list = new ArrayList<Integer>(adjacencyMatrix.length);
        for ( int j = 0; j < adjacencyMatrix[i].length; j++ ) {
            if ( adjacencyMatrix[i][j] ) {
                list.add( j );
            }
        }
        return list;
    }


}
TOP

Related Classes of org.drools.beliefs.bayes.JunctionTreeBuilder

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.