package org.maltparserx.parser.algorithm.stack;
import java.util.ArrayList;
import java.util.Stack;
import org.maltparserx.core.exception.MaltChainedException;
import org.maltparserx.core.syntaxgraph.DependencyStructure;
import org.maltparserx.core.syntaxgraph.node.DependencyNode;
import org.maltparserx.parser.DependencyParserConfig;
import org.maltparserx.parser.Oracle;
import org.maltparserx.parser.ParserConfiguration;
import org.maltparserx.parser.history.GuideUserHistory;
import org.maltparserx.parser.history.action.GuideUserAction;
/**
 * Static oracle for the lazy swap-based (non-projective) stack transition system.
 * <p>
 * For each sentence the oracle first computes, from the gold-standard tree, the position each
 * node occupies in an in-order traversal of that tree (left dependents, the node itself, right
 * dependents); this mapping is kept in {@code swapArray}. During parsing it predicts SWAP only
 * when the two topmost stack nodes are out of that order and {@link #necessarySwap} agrees;
 * otherwise it prefers LEFTARC/RIGHTARC once the dependent's gold dependents are attached, and
 * falls back to SHIFT.
 *
 * @author Johan Hall
 */
public class SwapLazyOracle extends Oracle {
    /** Maps a node index to its in-order position in the current sentence's gold tree. */
    private final ArrayList<Integer> swapArray;
    /**
     * Inverse of {@code swapArray}: maps an in-order position to the lowest node index holding
     * that position, or -1 if none does. Rebuilt together with {@code swapArray}.
     */
    private int[] positionToIndex = new int[0];
    /** Whether {@code swapArray} has already been computed for the current sentence. */
    private boolean swapArrayActive = false;

    public SwapLazyOracle(DependencyParserConfig manager, GuideUserHistory history) throws MaltChainedException {
        super(manager, history);
        setGuideName("swaplazy");
        swapArray = new ArrayList<Integer>();
    }

    /**
     * Predicts the gold transition for {@code configuration}.
     * <p>
     * With {@code left} the second-topmost and {@code right} the topmost stack node, the
     * preference order is:
     * <ol>
     * <li>SWAP, if {@code left} comes after {@code right} in gold in-order position and
     *     {@link #necessarySwap} returns {@code true};</li>
     * <li>LEFTARC, if {@code left} is not the root, its gold head is {@code right}, and
     *     {@link #nodeComplete} holds for it;</li>
     * <li>RIGHTARC, if {@code right}'s gold head is {@code left} and {@link #nodeComplete}
     *     holds for it;</li>
     * <li>SHIFT otherwise, and also whenever the stack holds fewer than two nodes.</li>
     * </ol>
     */
    public GuideUserAction predict(DependencyStructure gold, ParserConfiguration configuration) throws MaltChainedException {
        final StackConfig config = (StackConfig) configuration;
        final Stack<DependencyNode> stack = config.getStack();
        if (!swapArrayActive) {
            // First prediction of this sentence: compute the gold in-order positions once.
            createSwapArray(gold);
            swapArrayActive = true;
        }
        if (stack.size() < 2) {
            return updateActionContainers(NonProjective.SHIFT, null);
        }
        final DependencyNode left = stack.get(stack.size() - 2);
        final DependencyNode right = stack.get(stack.size() - 1);
        final int leftIndex = left.getIndex();
        final int rightIndex = right.getIndex();
        final DependencyStructure parse = config.getDependencyGraph();

        if (swapArray.get(leftIndex) > swapArray.get(rightIndex)
                && necessarySwap(gold, parse, right, config.getInput())) {
            return updateActionContainers(NonProjective.SWAP, null);
        }
        if (!left.isRoot()
                && gold.getTokenNode(leftIndex).getHead().getIndex() == rightIndex
                && nodeComplete(gold, parse, leftIndex)) {
            return updateActionContainers(NonProjective.LEFTARC, gold.getTokenNode(leftIndex).getHeadEdge().getLabelSet());
        }
        if (gold.getTokenNode(rightIndex).getHead().getIndex() == leftIndex
                && nodeComplete(gold, parse, rightIndex)) {
            return updateActionContainers(NonProjective.RIGHTARC, gold.getTokenNode(rightIndex).getHeadEdge().getLabelSet());
        }
        return updateActionContainers(NonProjective.SHIFT, null);
    }

    /**
     * Checks whether the parse node at {@code nodeIndex} already has the same outermost
     * dependents as in the gold tree.
     *
     * @return {@code true} if, on each side where the gold node has dependents, the parse node
     *         also has a dependent on that side and its leftmost (resp. rightmost) dependent has
     *         the same index as the gold one
     */
    private boolean nodeComplete(DependencyStructure gold, DependencyStructure parseDependencyGraph, int nodeIndex) {
        final DependencyNode goldNode = gold.getTokenNode(nodeIndex);
        final DependencyNode parseNode = parseDependencyGraph.getTokenNode(nodeIndex);
        if (goldNode.hasLeftDependent()) {
            if (!parseNode.hasLeftDependent()
                    || goldNode.getLeftmostDependent().getIndex() != parseNode.getLeftmostDependent().getIndex()) {
                return false;
            }
        }
        if (goldNode.hasRightDependent()) {
            if (!parseNode.hasRightDependent()
                    || goldNode.getRightmostDependent().getIndex() != parseNode.getRightmostDependent().getIndex()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Decides whether a SWAP of the two topmost stack nodes should be performed now.
     * <p>
     * Starting with the pair ({@code node}, {@code input.peek()}) and then moving pairwise
     * through {@code input} from its top downward, and as long as {@link #projectiveInterval}
     * holds for the current pair, it checks the following for the current input node:
     * <ul>
     * <li>if it is the index recorded in {@code rc}, returns {@code false};</li>
     * <li>if it is the gold head of {@code node}, returns {@code !leftComplete(gold, node)};</li>
     * <li>if {@code node} is its gold head, records its gold rightmost proper descendant as
     *     {@code rc} when it has gold right dependents, and otherwise returns {@code false}.</li>
     * </ul>
     * Returns {@code true} if {@code input} is empty or the walk ends without an early decision.
     * NOTE(review): this appears to implement the "lazy" swap condition of Nivre, Kuhlmann and
     * Hall (2009); confirm against the paper before changing it.
     */
    private boolean necessarySwap(DependencyStructure gold, DependencyStructure parse, DependencyNode node, Stack<DependencyNode> input) throws MaltChainedException {
        DependencyNode left = node;
        int index = input.size() - 1;
        if (index < 0) {
            return true;
        }
        DependencyNode right = input.peek();
        int rc = -1;
        while (projectiveInterval(parse, left, right)) {
            if (rc == right.getIndex()) {
                return false;
            }
            if (gold.getDependencyNode(node.getIndex()).getHead().getIndex() == right.getIndex()) {
                return !leftComplete(gold, node);
            }
            if (gold.getDependencyNode(right.getIndex()).getHead().getIndex() == node.getIndex()) {
                if (gold.getDependencyNode(right.getIndex()).hasRightDependent()) {
                    rc = gold.getDependencyNode(right.getIndex()).getRightmostProperDescendantIndex();
                } else {
                    return false;
                }
            }
            if (index > 0) {
                left = right;
                right = input.get(--index);
            } else {
                break;
            }
        }
        return true;
    }

    /**
     * Checks whether every node whose gold in-order position lies strictly between those of
     * {@code left} and {@code right} currently belongs to the partial parse tree topped by
     * {@code left} or by {@code right}. The top is found by following {@code getHead()} until
     * {@code hasHead()} is false and is compared by identity.
     *
     * @return {@code false} if {@code left}'s position is after {@code right}'s, or if some
     *         intermediate node has a different top; {@code true} otherwise
     * @throws IllegalStateException if no node occupies one of the intermediate positions
     *         (cannot happen when every gold node is reachable from the gold root)
     */
    private boolean projectiveInterval(DependencyStructure parse, DependencyNode left, DependencyNode right) throws MaltChainedException {
        final int l = swapArray.get(left.getIndex());
        final int r = swapArray.get(right.getIndex());
        if (l > r) {
            return false;
        }
        for (int position = l + 1; position < r; position++) {
            // O(1) lookup instead of scanning swapArray for the node holding this position.
            final int nodeIndex = positionToIndex[position];
            if (nodeIndex < 0) {
                throw new IllegalStateException("No node at swap position " + position);
            }
            DependencyNode top = parse.getDependencyNode(nodeIndex);
            while (top.hasHead()) {
                top = top.getHead();
            }
            if (top != left && top != right) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether {@code node} (a parse node) already has its gold leftmost dependent.
     *
     * @return {@code true} if the gold node has no left dependents, or if the parse node has a
     *         left dependent whose leftmost index equals the gold leftmost dependent's index
     */
    private boolean leftComplete(DependencyStructure gold, DependencyNode node) throws MaltChainedException {
        final DependencyNode goldNode = gold.getDependencyNode(node.getIndex());
        if (!goldNode.hasLeftDependent()) {
            return true;
        }
        return node.hasLeftDependent()
                && goldNode.getLeftmostDependent().getIndex() == node.getLeftmostDependent().getIndex();
    }

    /** Marks the per-sentence swap positions as stale so the next sentence recomputes them. */
    public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
        swapArrayActive = false;
    }

    public void terminate() throws MaltChainedException {
    }

    /**
     * Recomputes {@code swapArray} (node index to gold in-order position) and its inverse
     * {@code positionToIndex} for {@code goldDependencyGraph}.
     */
    private void createSwapArray(DependencyStructure goldDependencyGraph) throws MaltChainedException {
        swapArray.clear();
        final int n = goldDependencyGraph.getHighestDependencyNodeIndex();
        // Identity start: a node not reached from the root keeps its own index as position.
        for (int i = 0; i <= n; i++) {
            swapArray.add(i);
        }
        createSwapArray(goldDependencyGraph.getDependencyRoot(), 0);

        positionToIndex = new int[n + 1];
        for (int position = 0; position <= n; position++) {
            positionToIndex[position] = -1;
        }
        // Fill from the highest index down so the lowest node index wins if positions collide.
        for (int j = n; j >= 0; j--) {
            positionToIndex[swapArray.get(j)] = j;
        }
    }

    /**
     * Assigns in-order positions to the subtree rooted at {@code node}, starting at
     * {@code order}: left dependents first, then {@code node}, then right dependents.
     *
     * @return the next unused position
     */
    private int createSwapArray(DependencyNode node, int order) {
        int next = order;
        if (node != null) {
            for (int i = 0; i < node.getLeftDependentCount(); i++) {
                next = createSwapArray(node.getLeftDependent(i), next);
            }
            swapArray.set(node.getIndex(), next++);
            // Valid indices are 0..count-1; the previous start at `count` was out of range and
            // only harmless because of the null guard above.
            // NOTE(review): the descending loop assumes getRightDependent(0) is the rightmost
            // dependent, so this visits right dependents nearest-first; confirm in DependencyNode.
            for (int i = node.getRightDependentCount() - 1; i >= 0; i--) {
                next = createSwapArray(node.getRightDependent(i), next);
            }
        }
        return next;
    }
}