Package org.h2.test.utils

Source Code of org.h2.test.utils.ProxyCodeGenerator

/*
* Copyright 2004-2011 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.test.utils;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.TreeMap;
import java.util.TreeSet;
import org.h2.util.New;
import org.h2.util.SourceCompiler;

/**
* A code generator for class proxies.
*/
public class ProxyCodeGenerator {

    private static SourceCompiler compiler = new SourceCompiler();
    private static HashMap<Class<?>, Class<?>> proxyMap = New.hashMap();

    private TreeSet<String> imports = new TreeSet<String>();
    private TreeMap<String, Method> methods = new TreeMap<String, Method>();
    private String packageName;
    private String className;
    private Class<?> extendsClass;
    private Constructor<?> constructor;

    /**
     * Check whether there is already a proxy class generated.
     *
     * @param c the class
     * @return true if yes
     */
    public static boolean isGenerated(Class<?> c) {
        return proxyMap.containsKey(c);
    }

    /**
     * Generate a proxy class. The returned class extends the given class.
     *
     * @param c the class to extend
     * @return the proxy class
     */
    public static Class<?> getClassProxy(Class<?> c) throws ClassNotFoundException {
        Class<?> p = proxyMap.get(c);
        if (p != null) {
            return p;
        }
        // TODO how to extend a class with private constructor
        // TODO call right constructor
        // TODO use the right package
        ProxyCodeGenerator cg = new ProxyCodeGenerator();
        cg.setPackageName("bytecode");
        cg.generateClassProxy(c);
        StringWriter sw = new StringWriter();
        cg.write(new PrintWriter(sw));
        String code = sw.toString();
        String proxy = "bytecode."+ c.getSimpleName() + "Proxy";
        compiler.setSource(proxy, code);
        // System.out.println(code);
        Class<?> px = compiler.getClass(proxy);
        proxyMap.put(c, px);
        return px;
    }

    private void setPackageName(String packageName) {
        this.packageName = packageName;
    }

    /**
     * Generate a class that implements all static methods of the given class,
     * but as non-static.
     *
     * @param c the class to extend
     */
    void generateStaticProxy(Class<?> clazz) {
        imports.clear();
        addImport(InvocationHandler.class);
        addImport(Method.class);
        addImport(clazz);
        className = getClassName(clazz) + "Proxy";
        for (Method m : clazz.getDeclaredMethods()) {
            if (Modifier.isStatic(m.getModifiers())) {
                if (!Modifier.isPrivate(m.getModifiers())) {
                    addMethod(m);
                }
            }
        }
    }

    private void generateClassProxy(Class<?> clazz) {
        imports.clear();
        addImport(InvocationHandler.class);
        addImport(Method.class);
        addImport(clazz);
        className = getClassName(clazz) + "Proxy";
        extendsClass = clazz;
        int doNotOverride = Modifier.FINAL | Modifier.STATIC |
                Modifier.PRIVATE | Modifier.ABSTRACT | Modifier.VOLATILE;
        Class<?> dc = clazz;
        while (dc != null) {
            addImport(dc);
            for (Method m : dc.getDeclaredMethods()) {
                if ((m.getModifiers() & doNotOverride) == 0) {
                    addMethod(m);
                }
            }
            dc = dc.getSuperclass();
        }
        for (Constructor<?> c : clazz.getDeclaredConstructors()) {
            if (Modifier.isPrivate(c.getModifiers())) {
                continue;
            }
            if (constructor == null) {
                constructor = c;
            } else if (c.getParameterTypes().length < constructor.getParameterTypes().length) {
                constructor = c;
            }
        }
    }

    private void addMethod(Method m) {
        if (methods.containsKey(getMethodName(m))) {
            // already declared in a subclass
            return;
        }
        addImport(m.getReturnType());
        for (Class<?> c : m.getParameterTypes()) {
            addImport(c);
        }
        for (Class<?> c : m.getExceptionTypes()) {
            addImport(c);
        }
        methods.put(getMethodName(m), m);
    }

    private String getMethodName(Method m) {
        StringBuilder buff = new StringBuilder();
        buff.append(m.getReturnType()).append(' ');
        buff.append(m.getName());
        for (Class<?> p : m.getParameterTypes()) {
            buff.append(' ');
            buff.append(p.getName());
        }
        return buff.toString();
    }

    private void addImport(Class<?> c) {
        while (c.isArray()) {
            c = c.getComponentType();
        }
        if (!c.isPrimitive()) {
            if (!"java.lang".equals(c.getPackage().getName())) {
                imports.add(c.getName());
            }
        }
    }

    private static String getClassName(Class<?> c) {
        return getClassName(c, false);
    }

    private static String getClassName(Class<?> c, boolean varArg) {
        if (varArg) {
            c = c.getComponentType();
        }
        String s = c.getSimpleName();
        while (true) {
            c = c.getEnclosingClass();
            if (c == null) {
                break;
            }
            s = c.getSimpleName() + "." + s;
        }
        if (varArg) {
            return s + "...";
        }
        return s;
    }

    private void write(PrintWriter writer) {
        if (packageName != null) {
            writer.println("package " + packageName + ";");
        }
        for (String imp : imports) {
            writer.println("import " + imp + ";");
        }
        writer.print("public class " + className);
        if (extendsClass != null) {
            writer.print(" extends " + getClassName(extendsClass));
        }
        writer.println(" {");
        writer.println("    private final InvocationHandler ih;");
        writer.println("    public " + className + "() {");
        writer.println("        this(new InvocationHandler() {");
        writer.println("            public Object invoke(Object proxy,");
        writer.println("                    Method method, Object[] args) throws Throwable {");
        writer.println("                return method.invoke(proxy, args);");
        writer.println("            }});");
        writer.println("    }");
        writer.println("    public " + className + "(InvocationHandler ih) {");
        if (constructor != null) {
            writer.print("        super(");
            int i = 0;
            for (Class<?> p : constructor.getParameterTypes()) {
                if (i > 0) {
                    writer.print(", ");
                }
                if (p.isPrimitive()) {
                    if (p == boolean.class) {
                        writer.print("false");
                    } else if (p == byte.class) {
                        writer.print("(byte) 0");
                    } else if (p == char.class) {
                        writer.print("(char) 0");
                    } else if (p == short.class) {
                        writer.print("(short) 0");
                    } else if (p == int.class) {
                        writer.print("0");
                    } else if (p == long.class) {
                        writer.print("0L");
                    } else if (p == float.class) {
                        writer.print("0F");
                    } else if (p == double.class) {
                        writer.print("0D");
                    }
                } else {
                    writer.print("null");
                }
                i++;
            }
            writer.println(");");
        }
        writer.println("        this.ih = ih;");
        writer.println("    }");
        writer.println("    @SuppressWarnings(\"unchecked\")");
        writer.println("    private static <T extends RuntimeException> T convertException(Throwable e) {");
        writer.println("        if (e instanceof Error) {");
        writer.println("            throw (Error) e;");
        writer.println("        }");
        writer.println("        return (T) e;");
        writer.println("    }");
        for (Method m : methods.values()) {
            Class<?> retClass = m.getReturnType();
            writer.print("    ");
            if (Modifier.isProtected(m.getModifiers())) {
                // 'public' would also work
                writer.print("protected ");
            } else {
                writer.print("public ");
            }
            writer.print(getClassName(retClass) +
                " " + m.getName() + "(");
            Class<?>[] pc = m.getParameterTypes();
            for (int i = 0; i < pc.length; i++) {
                Class<?> p = pc[i];
                if (i > 0) {
                    writer.print(", ");
                }
                boolean varArg = i == pc.length - 1 && m.isVarArgs();
                writer.print(getClassName(p, varArg) + " p" + i);
            }
            writer.print(")");
            Class<?>[] ec = m.getExceptionTypes();
            writer.print(" throws RuntimeException");
            if (ec.length > 0) {
                for (Class<?> e : ec) {
                    writer.print(", ");
                    writer.print(getClassName(e));
                }
            }
            writer.println(" {");
            writer.println("        try {");
            writer.print("            ");
            if (retClass != void.class) {
                writer.print("return (");
                if (retClass == boolean.class) {
                    writer.print("Boolean");
                } else if (retClass == byte.class) {
                    writer.print("Byte");
                } else if (retClass == char.class) {
                    writer.print("Character");
                } else if (retClass == short.class) {
                    writer.print("Short");
                } else if (retClass == int.class) {
                    writer.print("Integer");
                } else if (retClass == long.class) {
                    writer.print("Long");
                } else if (retClass == float.class) {
                    writer.print("Float");
                } else if (retClass == double.class) {
                    writer.print("Double");
                } else {
                    writer.print(getClassName(retClass));
                }
                writer.print(") ");
            }
            writer.print("ih.invoke(this, ");
            writer.println(getClassName(m.getDeclaringClass()) +
                    ".class.getDeclaredMethod(\"" + m.getName() +
                    "\",");
            writer.print("                new Class[] {");
            int i = 0;
            for (Class<?> p : m.getParameterTypes()) {
                if (i > 0) {
                    writer.print(", ");
                }
                writer.print(getClassName(p) + ".class");
                i++;
            }
            writer.println("}),");
            writer.print("                new Object[] {");
            for (i = 0; i < m.getParameterTypes().length; i++) {
                if (i > 0) {
                    writer.print(", ");
                }
                writer.print("p" + i);
            }
            writer.println("});");
            writer.println("        } catch (Throwable e) {");
            writer.println("            throw convertException(e);");
            writer.println("        }");
            writer.println("    }");
        }
        writer.println("}");
        writer.flush();
    }

    public static String methodCallFormatter(Method m, Object... args) {
        StringBuilder buff = new StringBuilder();
        buff.append(m.getName()).append('(');
        for (int i = 0; i < args.length; i++) {
            Object a = args[i];
            if (i > 0) {
                buff.append(", ");
            }
            buff.append(a == null ? "null" : a.toString());
        }
        buff.append(")");
        return buff.toString();
    }

}
TOP

Related Classes of org.h2.test.utils.ProxyCodeGenerator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.