Package litil.tc

Source Code of litil.tc.HMTypeChecker

package litil.tc;

import litil.TypeScope;
import litil.ast.*;

import java.util.*;

public class HMTypeChecker {
    private static int _id = 0;

    public Type analyze(AstNode node, TypeScope env) {
        return analyze(node, env, new HashSet<Type>());
    }

    public Type analyze(AstNode node, TypeScope env, Set<Type> nonGen) {
        int id = _id++;
        //System.out.println("(" + id + ")" + "analyze: " + node + "\n\tenv=" + env + "\n\tnonGen=" + nonGen);
        if (node instanceof Program) {
            for (Instruction inst : ((Program) node).instructions) {
                if (inst instanceof LetBinding) {
                    LetBinding let = (LetBinding) inst;
                    if (let.args.size() > 0) {
                        env.define(let.name, Type.Function(new Type.Variable(), new Type.Variable()));
                    }
                }

            }
            for (Instruction inst : ((Program) node).instructions) {
                try {
                    Type type = analyze(inst, env, nonGen);
                    //System.out.println(inst.repr(1) + "\n\n\t=>" + type + "\n-----------------");
                } catch (TypeError e) {
                    System.err.println("" + env);
                    throw new RuntimeException("Cannot type " + inst, e);
                }
            }
            return null;
        } else if (node instanceof Expr.ENum) {
            return Type.INT;
        } else if (node instanceof Expr.EBool) {
            return Type.BOOL;
        } else if (node instanceof Expr.EStr) {
            return Type.STR;
        } else if (node instanceof Expr.EChar) {
            return Type.CHAR;
        } else if (node == Expr.EUnit) {
            return Type.UNIT;
        } else if (node instanceof Expr.EList) {
            Type valueType = new Type.Variable();
            for (Expr expr : ((Expr.EList) node).values) {
                unify(valueType, analyze(expr, env, nonGen));
            }
            return Type.List(valueType);
        } else if (node instanceof Expr.ETuple) {
            List<Type> valuesTypes = new ArrayList<Type>();
            for (Expr expr : ((Expr.ETuple) node).values) {
                valuesTypes.add(analyze(expr, env, nonGen));
            }

            return Type.Product(valuesTypes);
        } else if (node instanceof Expr.ERecord) {
            Expr.ERecord rec = (Expr.ERecord) node;
            Map<String, Type> valuesTypes = new HashMap<String, Type>();
            for (Map.Entry<String, Expr> entry : rec.values.entrySet()) {
                valuesTypes.put(entry.getKey(), analyze(entry.getValue(), env, nonGen));
            }


            return new Type.RecordType(valuesTypes);
        } else if (node instanceof Expr.EName) {
            return getType(((Expr.EName) node).name, env, nonGen);
        } else if (node instanceof Expr.EAccessor) {
            Expr.EAccessor acc = (Expr.EAccessor) node;
            Type objType = analyze(acc.obj, env, nonGen);
            Type.Variable res = new Type.Variable();
            Type.RecordType desiredType = new Type.RecordType(Collections.<String, Type>singletonMap(acc.field, res));
            unify(objType, desiredType);
            return res;
        } else if (node instanceof Expr.EAp) {
            Expr.EAp ap = (Expr.EAp) node;
            //System.out.println("(" + id + ")" + "ap=" + ap);
            Type funcType = analyze(ap.fn, env, nonGen);

            //System.out.println("(" + id + ")" + "::" + ap.fn + " > " + funcType);

            Type argType = analyze(ap.arg, env, nonGen);
            //System.out.println("(" + id + ")" + "ap.arg=" + ap.arg+" > "+argType);
            Type.Variable resType = new Type.Variable();
            try {
                unify(Type.Function(argType, resType), funcType);
            } catch (TypeError e) {
                throw new TypeError("Cannot type function application " + ap, Type.Function(argType, resType), funcType, e);
            }
            //System.out.println("(" + id + ")" + " resType of "+ap+"=> " + resType);
            return resType;
        } else if (node instanceof Expr.EIf) {
            Expr.EIf eif = (Expr.EIf) node;
            Type condType = analyze(eif.cond, env, nonGen);
            if (!isException(condType)) {
                unify(condType, Type.BOOL);
            }

            Type thenType = null;
            for (Instruction instr : eif.thenInstructions) {
                thenType = analyze(instr, env, nonGen);
            }

            Type elseType = null;
            for (Instruction instr : eif.elseInstructions) {
                elseType = analyze(instr, env, nonGen);
            }

            if (!isException(thenType) && !isException(elseType)) {
                unify(thenType, elseType);
            }
            return isException(thenType) ? elseType : thenType;
        } else if (node instanceof Expr.PatterMatching) {
            Expr.PatterMatching patmat = (Expr.PatterMatching) node;
            //System.out.println("analyzing " + patmat);
            Type inputType = analyze(patmat.input, env, nonGen);
            Type resType = new Type.Variable();
            boolean allExceptions = true;
            for (Expr.PatterMatching.Case pcase : patmat.cases) {
                //System.out.println("\thandling " + pcase);
                //unify pattern type and inputType
                //unify outcome types
                TypeScope newScope = env.child();
                Type patType = visit(pcase.pattern, newScope, nonGen);
                //System.out.println("scope after visiting pattern: " + newScope);
                //System.out.println("and patType=" + patType + " for pat " + pcase.pattern);
                //System.out.println("***************");
                //System.err.println("in " + inputType.getClass());
                if (!isException(inputType)) {
                    unify(patType, inputType);
                }
                //System.out.println("patType=" + patType + ", inputType=" + inputType);
                //System.out.println("$$$$$$$$$$$$$$$$");
                Type caseType = null;
                for (Instruction instruction : pcase.outcome) {
                    caseType = analyze(instruction, newScope, nonGen);
                }
                if (!isException(caseType)) {
                    unify(caseType, resType);
                    resType = caseType;
                    allExceptions = false;
                }

                //System.out.println("caseType=" + caseType + ", resType=" + resType);
                //System.out.println("%%%%%%%%%%%%%%%%%");

            }
            //System.out.println("resType=" + resType + " for " + patmat);
            return allExceptions ? Type.EXCEPTION : resType;
        } else if (node instanceof Expr.EThrow) {
            Expr.EThrow ethrow = (Expr.EThrow) node;
            Type exType = analyze(ethrow.exception, env, nonGen);
            unify(exType, Type.EXCEPTION);
            return Type.EXCEPTION;
        } else if (node instanceof Expr.TryCatch) {
            Expr.TryCatch tc = (Expr.TryCatch) node;


            Type tryType = null;
            for (Instruction instr : tc.tryBody) {
                tryType = analyze(instr, env, nonGen);
            }
            tryType = prune(tryType);
            System.err.println("" + tryType.getClass());
            System.err.println("TRY TYPE = " + (tryType == Type.EXCEPTION));
            Type.Variable resType = new Type.Variable();
            if (!isException(tryType)) {
                System.err.println("NOOOO");
                unify(tryType, resType);
            }


            for (Expr.PatterMatching.Case pcase : tc.catchCases) {
                TypeScope newScope = env.child();
                Type patType = visit(pcase.pattern, newScope, nonGen);
                unify(patType, Type.EXCEPTION);
                Type caseType = null;
                for (Instruction instruction : pcase.outcome) {
                    caseType = analyze(instruction, newScope, nonGen);
                }

                unify(caseType, resType);

            }

            return resType;
        } else if (node instanceof Expr.ELam) {
            Expr.ELam lam = (Expr.ELam) node;

            TypeScope newEnv = env.child();
            Set<Type> newNonGen = new HashSet<Type>(nonGen);
            List<Type> argTypes = new ArrayList<Type>();

            for (Named arg : lam.args) {
                Type argType = new Type.Variable();
                argTypes.add(argType);
                newEnv.define(arg.name, argType);
                newNonGen.add(argType);
            }
            Type resultType = null;
            for (Instruction instr : lam.instructions) {
                resultType = analyze(instr, newEnv, newNonGen);
            }
            if (lam.type != null) {
                unify(resultType, lam.type);
            }
            return Type.Function(argTypes, resultType);
        } else if (node instanceof LetBinding) {
            LetBinding let = (LetBinding) node;
            TypeScope newEnv = env.child();
            Set<Type> newNonGen = new HashSet<Type>(nonGen);
            List<Type> argTypes = new ArrayList<Type>();
            Type.Variable letTypeVar = new Type.Variable();

            //should a function be polymorphic to itself ?
            newEnv.define(let.name, letTypeVar);
            newNonGen.add(letTypeVar);
            for (Named arg : let.args) {
                Type argType = new Type.Variable();
                argTypes.add(argType);
                newEnv.define(arg.name, argType);
                newNonGen.add(argType);
            }
            Type resultType = null;
            for (Instruction instr : let.instructions) {
                resultType = analyze(instr, newEnv, newNonGen);
            }
            if (let.type != null) {
                unify(resultType, let.type);
            }
            Type letType = Type.Function(argTypes, resultType);
            env.define(let.name, letType);
            return letType;
        } else if (node instanceof DestructuringLetBinding) {
            DestructuringLetBinding let = (DestructuringLetBinding) node;
            TypeScope newEnv = env.child();
            Set<Type> newNonGen = new HashSet<Type>(nonGen);
            List<Type> argTypes = new ArrayList<Type>();
            Type letTypeVar = visit(let.main, newEnv, newNonGen);
            newNonGen.add(letTypeVar);
            for (Pattern arg : let.args) {
                Type argType = visit(arg, newEnv, newNonGen);
                argTypes.add(argType);
            }
            Type resultType = null;
            for (Instruction instr : let.instructions) {
                resultType = analyze(instr, newEnv, newNonGen);
            }

            Type letType = Type.Function(argTypes, resultType);
            define(let.main, letType, env);
            return letType;
        } else if (node instanceof DataDecl) {
            DataDecl dataDecl = (DataDecl) node;
            for (DataDecl.TypeConstructor tyCon : dataDecl.typeConstructors) {
                env.define(tyCon.name, Type.TyCon(tyCon, dataDecl));
            }
            return dataDecl.type;
        } else if (node instanceof ExceptionDecl) {
            ExceptionDecl dataDecl = (ExceptionDecl) node;
            env.define(dataDecl.name, new Type.ExceptionCon(dataDecl.name, dataDecl.types));
            return Type.EXCEPTION;
        } else {
            throw new RuntimeException("Unhandled node type " + node.getClass());
        }
    }

    private boolean isException(Type type) {
        type = prune(type);
        return ((type instanceof Type.Oper) && ((Type.Oper) type).name.equals(((Type.Oper) Type.EXCEPTION).name));
    }

    private Type visit(Pattern pat, TypeScope scope, Set<Type> nonGen) {
        if (pat instanceof Pattern.IdPattern) {
            Type.Variable res = new Type.Variable();
            String name = ((Pattern.IdPattern) pat).name;
            scope.define(name, res);
            nonGen.add(res);
            return res;
        } else if (pat instanceof Pattern.WildcardPattern) {
            return new Type.Variable();
        } else if (pat instanceof Pattern.TuplePattern) {
            Pattern.TuplePattern tup = (Pattern.TuplePattern) pat;
            List<Type> types = new ArrayList<Type>();
            for (Pattern tpat : tup.items) {
                types.add(visit(tpat, scope, nonGen));
            }
            return Type.Product(types);
        } else if (pat instanceof Pattern.TyConPattern) {
            Pattern.TyConPattern tyConPattern = (Pattern.TyConPattern) pat;
            List<Type> patTypes = new ArrayList<Type>();
            for (Pattern pattern : tyConPattern.patterns) {
                Type patType = visit(pattern, scope, nonGen);
                patTypes.add(patType);
            }
            Type refConsType = scope.get(tyConPattern.name);
            if (refConsType == null) {
                throw new RuntimeException("Unknow type constructor " + tyConPattern.name);
            } else {
                if (refConsType instanceof Type.TyCon) {
                    Type.TyCon refCons = (Type.TyCon) fresh((Type.TyCon) refConsType, nonGen);


                    if (tyConPattern.patterns.size() != refCons.types.size()) {
                        throw new RuntimeException("Bad arity: case " + pat + " provided " + tyConPattern.patterns.size() + " arguments whereas " + tyConPattern.name + " takes " + (refCons.types.size()));
                    }
                    for (int i = 0; i < patTypes.size(); i++) {
                        unify(patTypes.get(i), refCons.types.get(i));
                    }

                    //System.out.println("data type of pattern " + pat + " = " + refCons.dataType);
                    return refCons.dataType;//FIXME: fresh ?
                } else if (refConsType instanceof Type.ExceptionCon) {
                    Type.ExceptionCon refCons = (Type.ExceptionCon) refConsType;


                    if (tyConPattern.patterns.size() != refCons.types.size()) {
                        throw new RuntimeException("Bad arity: case " + pat + " provided " + tyConPattern.patterns.size() + " arguments whereas " + tyConPattern.name + " takes " + (refCons.types.size()));
                    }
                    for (int i = 0; i < patTypes.size(); i++) {
                        unify(patTypes.get(i), refCons.types.get(i));
                    }

                    //System.out.println("data type of pattern " + pat + " = " + refCons.dataType);
                    return Type.EXCEPTION;//FIXME: fresh ?
                } else {
                    throw new RuntimeException("Invalid type constructor " + tyConPattern.name);
                }
            }
        } else {
            throw new RuntimeException("Unknow pattern type " + pat);
        }
    }

    private void define(Pattern pat, Type t, TypeScope scope) {
        Type type = prune(t);
        if (pat instanceof Pattern.IdPattern) {
            scope.define(((Pattern.IdPattern) pat).name, type);
        } else if (pat instanceof Pattern.WildcardPattern) {
            // do nothing
        } else if (pat instanceof Pattern.TuplePattern) {
            Pattern.TuplePattern tup = (Pattern.TuplePattern) pat;
            if (type instanceof Type.Oper) {
                List<? extends Type> types = ((Type.Oper) type).args;
                List<Pattern> items = tup.items;
                for (int i = 0, itemsSize = items.size(); i < itemsSize; i++) {
                    Pattern tpat = items.get(i);
                    Type tpatType = types.get(i);
                    define(tpat, tpatType, scope);
                }
            } else {
                throw new RuntimeException("Invalid type " + type + "::" + type.getClass() + " for pattern " + pat);
            }
        } else if (pat instanceof Pattern.TyConPattern) {
            Pattern.TyConPattern tyConPattern = (Pattern.TyConPattern) pat;
            if (type instanceof Type.TyCon) {
                Type.TyCon tyConType = (Type.TyCon) type;
                List<Type> types = tyConType.types;
                for (int i = 0, typesSize = types.size(); i < typesSize; i++) {
                    Type tpatType = types.get(i);
                    Pattern tpat = tyConPattern.patterns.get(i);
                    define(tpat, tpatType, scope);
                }
            } else {
                throw new RuntimeException("Invalid type " + type + " for pattern " + pat);
            }
        } else {
            throw new RuntimeException("Unknow pattern type " + pat);
        }
    }

    private Type getType(String name, TypeScope env, Set<Type> nonGen) {
        if (env.get(name) != null) {
            Type fresh = fresh(env.get(name), nonGen);
            //System.out.println("getType " + name + " :: " + env.get(name) + " => " + fresh);
            return fresh;
        } else {
            throw new RuntimeException("Unknown entity " + name);
        }
    }

    private Type fresh(Type type, Set<Type> nonGen) {
        return freshRec(type, nonGen, new HashMap<Type.Variable, Type.Variable>());
    }

    private Type freshRec(Type type, Set<Type> nonGen, Map<Type.Variable, Type.Variable> mappings) {
        Type t = prune(type);
        if (t instanceof Type.Variable) {
            Type.Variable v = (Type.Variable) t;
            if (isGeneric(v, nonGen)) {
                //System.out.println("freshrec " + type + " nonGenric=" + nonGen + " ; " + mappings);
                if (!mappings.containsKey(v)) {
                    Type.Variable variable = new Type.Variable();
                    //System.out.println("new var ! " + v + " => " + variable);
                    mappings.put(v, variable);
                }
                return mappings.get(v);
            } else {
                //System.out.println("freshrec " + type + " => self");
                return v;
            }
        } else if (t instanceof Type.Oper) {
            Type.Oper oper = (Type.Oper) t;
            List<Type> args = new ArrayList<Type>(oper.args.size());
            for (Type arg : oper.args) {
                args.add(freshRec(arg, nonGen, mappings));
            }
            return new Type.Oper(oper.name, args, oper.repr);
        } else if (t instanceof Type.TyCon) {
            Type.TyCon tycon = (Type.TyCon) t;
            List<Type> types = new ArrayList<Type>(tycon.types.size());
            for (Type tpe : tycon.types) {
                types.add(freshRec(tpe, nonGen, mappings));
            }
            Type dataType = freshRec(tycon.dataType, nonGen, mappings);
            return new Type.TyCon(tycon.name, types, dataType);
        } else if (t instanceof Type.RecordType) {
            Type.RecordType recType = (Type.RecordType) t;
            Map<String, Type> types = new HashMap<String, Type>(recType.types.size());
            for (Map.Entry<String, Type> entry : recType.types.entrySet()) {
                types.put(entry.getKey(), freshRec(entry.getValue(), nonGen, mappings));
            }
            return new Type.RecordType(types);
        } else if (t instanceof Type.ExceptionCon) {
            return t;
        } else {
            throw new RuntimeException("Wat ? " + type);
        }

    }

    private boolean isGeneric(Type.Variable v, Set<Type> nonGen) {
        return !occursIn(v, nonGen);
    }


    private Type prune(Type type) {
        if (type instanceof Type.Variable) {
            Type.Variable v = (Type.Variable) type;
            if (v.instance != null) {
                v.instance = prune(v.instance);
                ////System.out.println("prune "+type+" => "+v.instance);
                return v.instance;

            }
        }
        ////System.out.println("prune "+type+" => "+type);
        return type;
    }

    private void unify(Type t1, Type t2) {
        //System.out.println("unify " + t1 + ", " + t2 + " : " + Thread.currentThread().getStackTrace()[2].toString());
        Type type1 = prune(t1);
        Type type2 = prune(t2);
        //System.out.println("unify(p) " + type1 + ", " + type2);
        if (type1 instanceof Type.TyCon) {
            type1 = ((Type.TyCon) type1).makeFn();
        } else if (type1 instanceof Type.ExceptionCon) {
            type1 = ((Type.ExceptionCon) type1).makeFn();
        }

        if (type2 instanceof Type.TyCon) {
            type2 = ((Type.TyCon) type2).makeFn();
        } else if (type2 instanceof Type.ExceptionCon) {
            type2 = ((Type.ExceptionCon) type2).makeFn();
        }
        //System.out.println("unify(c) " + type1 + ", " + type2);

        if (type1 instanceof Type.Variable) {
            if (!type1.equals(type2)) {
                Type.Variable v = (Type.Variable) type1;
                if (occursInType(v, type2)) {
                    throw new TypeError("Recursive unification", t1, t2);
                }
                //System.err.println("!!");
                v.instance = type2;
            }
        } else if (type2 instanceof Type.Variable) {
            unify(type2, type1);
        } else if (type1 instanceof Type.Oper && type2 instanceof Type.Oper) {
            Type.Oper o1 = (Type.Oper) type1;
            Type.Oper o2 = (Type.Oper) type2;

            if (!o1.name.equals(o2.name) || o1.args.size() != o2.args.size()) {
                throw new TypeError("Type mismatch: ", t1, t2);
            } else {
                for (int i = 0; i < o1.args.size(); i++) {
                    unify(o1.args.get(i), o2.args.get(i));
                }
            }
        } else if (type1 instanceof Type.RecordType && type2 instanceof Type.RecordType) {
            Type.RecordType recType1 = (Type.RecordType) type1;
            Type.RecordType recType2 = (Type.RecordType) type2;
            for (Map.Entry<String, Type> entry : recType2.types.entrySet()) {
                if (!recType1.types.containsKey(entry.getKey())) {
                    throw new TypeError("Cannot unify, no field " + entry.getKey(), type1, type2);
                }
                unify(entry.getValue(), recType1.types.get(entry.getKey()));
            }
        } else {
            throw new RuntimeException("Wat ? " + t1 + ", " + t2);
        }

    }

    private boolean occursIn(Type.Variable v, Set<Type> nonGen) {
        for (Type type : nonGen) {
            if (occursInType(v, type)) {
                return true;
            }
        }
        return false;
    }

    private boolean occursInType(Type.Variable v, Type type2) {
        ////System.out.println("occurs? "+v+" in "+type2);
        Type pruned = prune(type2);
        if (pruned.equals(v)) {
            //System.out.println(v + " in " + type2);
            return true;
        } else if (pruned instanceof Type.Oper) {
            Type.Oper o = (Type.Oper) pruned;
            for (Type arg : o.args) {
                if (occursInType(v, arg)) {
                    //System.out.println(v + " in " + type2 + " args");
                    return true;
                }
            }
        }
        return false;
    }

}
TOP

Related Classes of litil.tc.HMTypeChecker

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.