String word = tree.children()[0].label().value();
word = dvModel.getVocabWord(word);
// SimpleMatrix currentVector = nodeVectors.get(tree);
// SimpleMatrix currentVectorDerivative = nonlinearityVectorToDerivative(currentVector);
// SimpleMatrix derivative = deltaUp.elementMult(currentVectorDerivative);
SimpleMatrix derivative = deltaUp;
wordVectorDerivatives.put(word, wordVectorDerivatives.get(word).plus(derivative));
}
return;
}
SimpleMatrix currentVector = nodeVectors.get(tree);
SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
SimpleMatrix scoreW = dvModel.getScoreWForNode(tree);
currentVectorDerivative = currentVectorDerivative.elementMult(scoreW.transpose());
// the delta used at the current node: the incoming deltaUp plus the score-derived term computed above
SimpleMatrix deltaCurrent = deltaUp.plus(currentVectorDerivative);
SimpleMatrix W = dvModel.getWForNode(tree);
SimpleMatrix WTdelta = W.transpose().mult(deltaCurrent);
if (tree.children().length == 2) {
//TODO: RS: Change to the nice "getWForNode" setup?
String leftLabel = dvModel.basicCategory(tree.children()[0].label().value());
String rightLabel = dvModel.basicCategory(tree.children()[1].label().value());
binaryScoreDerivatives.put(leftLabel, rightLabel,
binaryScoreDerivatives.get(leftLabel, rightLabel).plus(currentVector.transpose()));
SimpleMatrix leftVector = nodeVectors.get(tree.children()[0]);
SimpleMatrix rightVector = nodeVectors.get(tree.children()[1]);
SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
if (op.trainOptions.useContextWords) {
childrenVector = concatenateContextWords(childrenVector, tree.getSpan(), words);
}
SimpleMatrix W_df = deltaCurrent.mult(childrenVector.transpose());
binaryW_dfs.put(leftLabel, rightLabel, binaryW_dfs.get(leftLabel, rightLabel).plus(W_df));
// and then recurse
SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
SimpleMatrix leftWTDelta = WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1);
SimpleMatrix rightWTDelta = WTdelta.extractMatrix(deltaCurrent.numRows(), deltaCurrent.numRows() * 2, 0, 1);
backpropDerivative(tree.children()[0], words, nodeVectors,
binaryW_dfs, unaryW_dfs,
binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives,
leftDerivative.elementMult(leftWTDelta));
backpropDerivative(tree.children()[1], words, nodeVectors,
binaryW_dfs, unaryW_dfs,
binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives,
rightDerivative.elementMult(rightWTDelta));
} else if (tree.children().length == 1) {
String childLabel = dvModel.basicCategory(tree.children()[0].label().value());
unaryScoreDerivatives.put(childLabel,unaryScoreDerivatives.get(childLabel).plus(currentVector.transpose()));
SimpleMatrix childVector = nodeVectors.get(tree.children()[0]);
SimpleMatrix childVectorWithBias = NeuralUtils.concatenateWithBias(childVector);
if (op.trainOptions.useContextWords) {
childVectorWithBias = concatenateContextWords(childVectorWithBias, tree.getSpan(), words);
}
SimpleMatrix W_df = deltaCurrent.mult(childVectorWithBias.transpose());
// System.out.println("unary backprop derivative for " + childLabel);
// System.out.println("Old transform:");
// System.out.println(unaryW_dfs.get(childLabel));
// System.out.println(" Delta:");
// System.out.println(W_df.scale(scale));
unaryW_dfs.put(childLabel,unaryW_dfs.get(childLabel).plus(W_df));
// and then recurse
SimpleMatrix childDerivative = NeuralUtils.elementwiseApplyTanhDerivative(childVector);
//SimpleMatrix childDerivative = childVector;
SimpleMatrix childWTDelta = WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1);
backpropDerivative(tree.children()[0], words, nodeVectors,
binaryW_dfs, unaryW_dfs,
binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives,
childDerivative.elementMult(childWTDelta));
}