package org.apache.lucene.expressions.js;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.io.Reader;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Properties;
import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.CharStream;
import org.antlr.runtime.CommonTokenStream;
import org.antlr.runtime.RecognitionException;
import org.antlr.runtime.tree.Tree;
import org.apache.lucene.expressions.Expression;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.util.IOUtils;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.GeneratorAdapter;
/**
* An expression compiler for javascript expressions.
* <p>
* Example:
* <pre class="prettyprint">
* Expression foo = JavascriptCompiler.compile("((0.3*popularity)/10.0)+(0.7*score)");
* </pre>
* <p>
* See the {@link org.apache.lucene.expressions.js package documentation} for
* the supported syntax and default functions.
* <p>
* You can compile with an alternate set of functions via {@link #compile(String, Map, ClassLoader)}.
* For example:
* <pre class="prettyprint">
* Map<String,Method> functions = new HashMap<String,Method>();
* // add all the default functions
* functions.putAll(JavascriptCompiler.DEFAULT_FUNCTIONS);
* // add cbrt()
* functions.put("cbrt", Math.class.getMethod("cbrt", double.class));
* // call compile with customized function map
* Expression foo = JavascriptCompiler.compile("cbrt(score)+ln(popularity)",
* functions,
* getClass().getClassLoader());
* </pre>
*
* @lucene.experimental
*/
public class JavascriptCompiler {
static final class Loader extends ClassLoader {
Loader(ClassLoader parent) {
super(parent);
}
public Class<? extends Expression> define(String className, byte[] bytecode) {
return defineClass(className, bytecode, 0, bytecode.length).asSubclass(Expression.class);
}
}
private static final int CLASSFILE_VERSION = Opcodes.V1_6;
// We use the same class name for all generated classes as they all have their own class loader.
// The source code is displayed as "source file name" in stack trace.
private static final String COMPILED_EXPRESSION_CLASS = JavascriptCompiler.class.getName() + "$CompiledExpression";
private static final String COMPILED_EXPRESSION_INTERNAL = COMPILED_EXPRESSION_CLASS.replace('.', '/');
private static final Type EXPRESSION_TYPE = Type.getType(Expression.class);
private static final Type FUNCTION_VALUES_TYPE = Type.getType(FunctionValues.class);
private static final org.objectweb.asm.commons.Method
EXPRESSION_CTOR = getMethod("void <init>(String, String[])"),
EVALUATE_METHOD = getMethod("double evaluate(int, " + FunctionValues.class.getName() + "[])"),
DOUBLE_VAL_METHOD = getMethod("double doubleVal(int)");
// to work around import clash:
private static org.objectweb.asm.commons.Method getMethod(String method) {
return org.objectweb.asm.commons.Method.getMethod(method);
}
// This maximum length is theoretically 65535 bytes, but as its CESU-8 encoded we dont know how large it is in bytes, so be safe
// rcmuir: "If your ranking function is that large you need to check yourself into a mental institution!"
private static final int MAX_SOURCE_LENGTH = 16384;
private final String sourceText;
private final Map<String, Integer> externalsMap = new LinkedHashMap<String, Integer>();
private final ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
private GeneratorAdapter gen;
private final Map<String,Method> functions;
/**
* Compiles the given expression.
*
* @param sourceText The expression to compile
* @return A new compiled expression
* @throws ParseException on failure to compile
*/
public static Expression compile(String sourceText) throws ParseException {
return new JavascriptCompiler(sourceText).compileExpression(JavascriptCompiler.class.getClassLoader());
}
/**
* Compiles the given expression with the supplied custom functions.
* <p>
* Functions must be {@code public static}, return {@code double} and
* can take from zero to 256 {@code double} parameters.
*
* @param sourceText The expression to compile
* @param functions map of String names to functions
* @param parent a {@code ClassLoader} that should be used as the parent of the loaded class.
* It must contain all classes referred to by the given {@code functions}.
* @return A new compiled expression
* @throws ParseException on failure to compile
*/
public static Expression compile(String sourceText, Map<String,Method> functions, ClassLoader parent) throws ParseException {
if (parent == null) {
throw new NullPointerException("A parent ClassLoader must be given.");
}
for (Method m : functions.values()) {
checkFunction(m, parent);
}
return new JavascriptCompiler(sourceText, functions).compileExpression(parent);
}
/**
* This method is unused, it is just here to make sure that the function signatures don't change.
* If this method fails to compile, you also have to change the byte code generator to correctly
* use the FunctionValues class.
*/
@SuppressWarnings({"unused", "null"})
private static void unusedTestCompile() {
FunctionValues f = null;
double ret = f.doubleVal(2);
}
/**
* Constructs a compiler for expressions.
* @param sourceText The expression to compile
*/
private JavascriptCompiler(String sourceText) {
this(sourceText, DEFAULT_FUNCTIONS);
}
/**
* Constructs a compiler for expressions with specific set of functions
* @param sourceText The expression to compile
*/
private JavascriptCompiler(String sourceText, Map<String,Method> functions) {
if (sourceText == null) {
throw new NullPointerException();
}
this.sourceText = sourceText;
this.functions = functions;
}
/**
* Compiles the given expression with the specified parent classloader
*
* @return A new compiled expression
* @throws ParseException on failure to compile
*/
private Expression compileExpression(ClassLoader parent) throws ParseException {
try {
Tree antlrTree = getAntlrComputedExpressionTree();
beginCompile();
recursiveCompile(antlrTree, Type.DOUBLE_TYPE);
endCompile();
Class<? extends Expression> evaluatorClass = new Loader(parent)
.define(COMPILED_EXPRESSION_CLASS, classWriter.toByteArray());
Constructor<? extends Expression> constructor = evaluatorClass.getConstructor(String.class, String[].class);
return constructor.newInstance(sourceText, externalsMap.keySet().toArray(new String[externalsMap.size()]));
} catch (InstantiationException exception) {
throw new IllegalStateException("An internal error occurred attempting to compile the expression (" + sourceText + ").", exception);
} catch (IllegalAccessException exception) {
throw new IllegalStateException("An internal error occurred attempting to compile the expression (" + sourceText + ").", exception);
} catch (NoSuchMethodException exception) {
throw new IllegalStateException("An internal error occurred attempting to compile the expression (" + sourceText + ").", exception);
} catch (InvocationTargetException exception) {
throw new IllegalStateException("An internal error occurred attempting to compile the expression (" + sourceText + ").", exception);
}
}
private void beginCompile() {
classWriter.visit(CLASSFILE_VERSION,
Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
COMPILED_EXPRESSION_INTERNAL,
null, EXPRESSION_TYPE.getInternalName(), null);
String clippedSourceText = (sourceText.length() <= MAX_SOURCE_LENGTH) ?
sourceText : (sourceText.substring(0, MAX_SOURCE_LENGTH - 3) + "...");
classWriter.visitSource(clippedSourceText, null);
GeneratorAdapter constructor = new GeneratorAdapter(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
EXPRESSION_CTOR, null, null, classWriter);
constructor.loadThis();
constructor.loadArgs();
constructor.invokeConstructor(EXPRESSION_TYPE, EXPRESSION_CTOR);
constructor.returnValue();
constructor.endMethod();
gen = new GeneratorAdapter(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
EVALUATE_METHOD, null, null, classWriter);
}
private void recursiveCompile(Tree current, Type expected) {
int type = current.getType();
String text = current.getText();
switch (type) {
case JavascriptParser.AT_CALL:
Tree identifier = current.getChild(0);
String call = identifier.getText();
int arguments = current.getChildCount() - 1;
Method method = functions.get(call);
if (method == null) {
throw new IllegalArgumentException("Unrecognized method call (" + call + ").");
}
int arity = method.getParameterTypes().length;
if (arguments != arity) {
throw new IllegalArgumentException("Expected (" + arity + ") arguments for method call (" +
call + "), but found (" + arguments + ").");
}
for (int argument = 1; argument <= arguments; ++argument) {
recursiveCompile(current.getChild(argument), Type.DOUBLE_TYPE);
}
gen.invokeStatic(Type.getType(method.getDeclaringClass()),
org.objectweb.asm.commons.Method.getMethod(method));
gen.cast(Type.DOUBLE_TYPE, expected);
break;
case JavascriptParser.NAMESPACE_ID:
int index;
if (externalsMap.containsKey(text)) {
index = externalsMap.get(text);
} else {
index = externalsMap.size();
externalsMap.put(text, index);
}
gen.loadArg(1);
gen.push(index);
gen.arrayLoad(FUNCTION_VALUES_TYPE);
gen.loadArg(0);
gen.invokeVirtual(FUNCTION_VALUES_TYPE, DOUBLE_VAL_METHOD);
gen.cast(Type.DOUBLE_TYPE, expected);
break;
case JavascriptParser.HEX:
pushLong(expected, Long.parseLong(text.substring(2), 16));
break;
case JavascriptParser.OCTAL:
pushLong(expected, Long.parseLong(text.substring(1), 8));
break;
case JavascriptParser.DECIMAL:
gen.push(Double.parseDouble(text));
gen.cast(Type.DOUBLE_TYPE, expected);
break;
case JavascriptParser.AT_NEGATE:
recursiveCompile(current.getChild(0), Type.DOUBLE_TYPE);
gen.visitInsn(Opcodes.DNEG);
gen.cast(Type.DOUBLE_TYPE, expected);
break;
case JavascriptParser.AT_ADD:
pushArith(Opcodes.DADD, current, expected);
break;
case JavascriptParser.AT_SUBTRACT:
pushArith(Opcodes.DSUB, current, expected);
break;
case JavascriptParser.AT_MULTIPLY:
pushArith(Opcodes.DMUL, current, expected);
break;
case JavascriptParser.AT_DIVIDE:
pushArith(Opcodes.DDIV, current, expected);
break;
case JavascriptParser.AT_MODULO:
pushArith(Opcodes.DREM, current, expected);
break;
case JavascriptParser.AT_BIT_SHL:
pushShift(Opcodes.LSHL, current, expected);
break;
case JavascriptParser.AT_BIT_SHR:
pushShift(Opcodes.LSHR, current, expected);
break;
case JavascriptParser.AT_BIT_SHU:
pushShift(Opcodes.LUSHR, current, expected);
break;
case JavascriptParser.AT_BIT_AND:
pushBitwise(Opcodes.LAND, current, expected);
break;
case JavascriptParser.AT_BIT_OR:
pushBitwise(Opcodes.LOR, current, expected);
break;
case JavascriptParser.AT_BIT_XOR:
pushBitwise(Opcodes.LXOR, current, expected);
break;
case JavascriptParser.AT_BIT_NOT:
recursiveCompile(current.getChild(0), Type.LONG_TYPE);
gen.push(-1L);
gen.visitInsn(Opcodes.LXOR);
gen.cast(Type.LONG_TYPE, expected);
break;
case JavascriptParser.AT_COMP_EQ:
pushCond(GeneratorAdapter.EQ, current, expected);
break;
case JavascriptParser.AT_COMP_NEQ:
pushCond(GeneratorAdapter.NE, current, expected);
break;
case JavascriptParser.AT_COMP_LT:
pushCond(GeneratorAdapter.LT, current, expected);
break;
case JavascriptParser.AT_COMP_GT:
pushCond(GeneratorAdapter.GT, current, expected);
break;
case JavascriptParser.AT_COMP_LTE:
pushCond(GeneratorAdapter.LE, current, expected);
break;
case JavascriptParser.AT_COMP_GTE:
pushCond(GeneratorAdapter.GE, current, expected);
break;
case JavascriptParser.AT_BOOL_NOT:
Label labelNotTrue = new Label();
Label labelNotReturn = new Label();
recursiveCompile(current.getChild(0), Type.INT_TYPE);
gen.visitJumpInsn(Opcodes.IFEQ, labelNotTrue);
pushBoolean(expected, false);
gen.goTo(labelNotReturn);
gen.visitLabel(labelNotTrue);
pushBoolean(expected, true);
gen.visitLabel(labelNotReturn);
break;
case JavascriptParser.AT_BOOL_AND:
Label andFalse = new Label();
Label andEnd = new Label();
recursiveCompile(current.getChild(0), Type.INT_TYPE);
gen.visitJumpInsn(Opcodes.IFEQ, andFalse);
recursiveCompile(current.getChild(1), Type.INT_TYPE);
gen.visitJumpInsn(Opcodes.IFEQ, andFalse);
pushBoolean(expected, true);
gen.goTo(andEnd);
gen.visitLabel(andFalse);
pushBoolean(expected, false);
gen.visitLabel(andEnd);
break;
case JavascriptParser.AT_BOOL_OR:
Label orTrue = new Label();
Label orEnd = new Label();
recursiveCompile(current.getChild(0), Type.INT_TYPE);
gen.visitJumpInsn(Opcodes.IFNE, orTrue);
recursiveCompile(current.getChild(1), Type.INT_TYPE);
gen.visitJumpInsn(Opcodes.IFNE, orTrue);
pushBoolean(expected, false);
gen.goTo(orEnd);
gen.visitLabel(orTrue);
pushBoolean(expected, true);
gen.visitLabel(orEnd);
break;
case JavascriptParser.AT_COND_QUE:
Label condFalse = new Label();
Label condEnd = new Label();
recursiveCompile(current.getChild(0), Type.INT_TYPE);
gen.visitJumpInsn(Opcodes.IFEQ, condFalse);
recursiveCompile(current.getChild(1), expected);
gen.goTo(condEnd);
gen.visitLabel(condFalse);
recursiveCompile(current.getChild(2), expected);
gen.visitLabel(condEnd);
break;
default:
throw new IllegalStateException("Unknown operation specified: (" + current.getText() + ").");
}
}
private void pushArith(int operator, Tree current, Type expected) {
pushBinaryOp(operator, current, expected, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE, Type.DOUBLE_TYPE);
}
private void pushShift(int operator, Tree current, Type expected) {
pushBinaryOp(operator, current, expected, Type.LONG_TYPE, Type.INT_TYPE, Type.LONG_TYPE);
}
private void pushBitwise(int operator, Tree current, Type expected) {
pushBinaryOp(operator, current, expected, Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE);
}
private void pushBinaryOp(int operator, Tree current, Type expected, Type arg1, Type arg2, Type returnType) {
recursiveCompile(current.getChild(0), arg1);
recursiveCompile(current.getChild(1), arg2);
gen.visitInsn(operator);
gen.cast(returnType, expected);
}
private void pushCond(int operator, Tree current, Type expected) {
Label labelTrue = new Label();
Label labelReturn = new Label();
recursiveCompile(current.getChild(0), Type.DOUBLE_TYPE);
recursiveCompile(current.getChild(1), Type.DOUBLE_TYPE);
gen.ifCmp(Type.DOUBLE_TYPE, operator, labelTrue);
pushBoolean(expected, false);
gen.goTo(labelReturn);
gen.visitLabel(labelTrue);
pushBoolean(expected, true);
gen.visitLabel(labelReturn);
}
private void pushBoolean(Type expected, boolean truth) {
switch (expected.getSort()) {
case Type.INT:
gen.push(truth);
break;
case Type.LONG:
gen.push(truth ? 1L : 0L);
break;
case Type.DOUBLE:
gen.push(truth ? 1. : 0.);
break;
default:
throw new IllegalStateException("Invalid expected type: " + expected);
}
}
private void pushLong(Type expected, long i) {
switch (expected.getSort()) {
case Type.INT:
gen.push((int) i);
break;
case Type.LONG:
gen.push(i);
break;
case Type.DOUBLE:
gen.push((double) i);
break;
default:
throw new IllegalStateException("Invalid expected type: " + expected);
}
}
private void endCompile() {
gen.returnValue();
gen.endMethod();
classWriter.visitEnd();
}
private Tree getAntlrComputedExpressionTree() throws ParseException {
CharStream input = new ANTLRStringStream(sourceText);
JavascriptLexer lexer = new JavascriptLexer(input);
CommonTokenStream tokens = new CommonTokenStream(lexer);
JavascriptParser parser = new JavascriptParser(tokens);
try {
return parser.expression().tree;
} catch (RecognitionException exception) {
throw new IllegalArgumentException(exception);
} catch (RuntimeException exception) {
if (exception.getCause() instanceof ParseException) {
throw (ParseException)exception.getCause();
}
throw exception;
}
}
/**
* The default set of functions available to expressions.
* <p>
* See the {@link org.apache.lucene.expressions.js package documentation}
* for a list.
*/
public static final Map<String,Method> DEFAULT_FUNCTIONS;
static {
Map<String,Method> map = new HashMap<String,Method>();
Reader in = null;
try {
final Properties props = new Properties();
in = IOUtils.getDecodingReader(JavascriptCompiler.class,
JavascriptCompiler.class.getSimpleName() + ".properties", IOUtils.CHARSET_UTF_8);
props.load(in);
for (final String call : props.stringPropertyNames()) {
final String[] vals = props.getProperty(call).split(",");
if (vals.length != 3) {
throw new Error("Syntax error while reading Javascript functions from resource");
}
final Class<?> clazz = Class.forName(vals[0].trim());
final String methodName = vals[1].trim();
final int arity = Integer.parseInt(vals[2].trim());
@SuppressWarnings({"rawtypes", "unchecked"}) Class[] args = new Class[arity];
Arrays.fill(args, double.class);
Method method = clazz.getMethod(methodName, args);
checkFunction(method, JavascriptCompiler.class.getClassLoader());
map.put(call, method);
}
} catch (NoSuchMethodException e) {
throw new Error("Cannot resolve function", e);
} catch (ClassNotFoundException e) {
throw new Error("Cannot resolve function", e);
} catch (IOException e) {
throw new Error("Cannot resolve function", e);
} finally {
IOUtils.closeWhileHandlingException(in);
}
DEFAULT_FUNCTIONS = Collections.unmodifiableMap(map);
}
private static void checkFunction(Method method, ClassLoader parent) {
// We can only call the function if the given parent class loader of our compiled class has access to the method:
final ClassLoader functionClassloader = method.getDeclaringClass().getClassLoader();
if (functionClassloader != null) { // it is a system class iff null!
boolean found = false;
while (parent != null) {
if (parent == functionClassloader) {
found = true;
break;
}
parent = parent.getParent();
}
if (!found) {
throw new IllegalArgumentException(method + " is not declared by a class which is accessible by the given parent ClassLoader.");
}
}
// do some checks if the signature is "compatible":
if (!Modifier.isStatic(method.getModifiers())) {
throw new IllegalArgumentException(method + " is not static.");
}
if (!Modifier.isPublic(method.getModifiers())) {
throw new IllegalArgumentException(method + " is not public.");
}
if (!Modifier.isPublic(method.getDeclaringClass().getModifiers())) {
throw new IllegalArgumentException(method.getDeclaringClass().getName() + " is not public.");
}
for (Class<?> clazz : method.getParameterTypes()) {
if (!clazz.equals(double.class)) {
throw new IllegalArgumentException(method + " must take only double parameters");
}
}
if (method.getReturnType() != double.class) {
throw new IllegalArgumentException(method + " does not return a double.");
}
}
}