321 lines
17 KiB
Haskell
321 lines
17 KiB
Haskell
module TypeChecker where
|
|
|
|
import Debug.Trace
|
|
import AbsJavalette
|
|
import PrintJavalette
|
|
import ErrM
|
|
import Control.Monad.State
|
|
import Data.List
|
|
|
|
-- | Function definitions that have already been checked and annotated,
-- kept in the order in which they were checked.
type AnotTree = [TopDef]

-- | Error messages collected so far.  'addErr' prepends, so the most
-- recent message comes first.
type Errors = [String]

-- | Scoped symbol table: a stack of scopes, innermost scope first, each
-- scope associating identifiers with their types.
type Env = [[(Ident, Type)]]

-- | Full checker state: symbol table, annotated output and errors.
type MyState = (Env, AnotTree, Errors)

-- | The monad every checking function runs in.
type MyStateM = State MyState
|
|
|
|
|
|
-- | Type-check a whole program.
--
-- All function signatures are entered into the environment first (so a
-- function may be called before its definition), then every body is
-- checked and annotated.
--
-- On success the annotated program is returned.  Otherwise the result is
-- 'Bad' carrying the (partially) annotated program followed by every
-- collected error, one per line, in the order in which they were found.
-- 'addErr' stores errors newest-first, hence the 'reverse'.
typecheck :: Program -> Err Program
typecheck (Program fs)
  | null errs = Ok (Program at)
  | otherwise = Bad (printTree at ++ "\n" ++ unlines (reverse errs))
  where
    (at, errs) =
      evalState (checkFuncs fs >> checkProgram fs >> getAnotTree) emptyState
|
|
|
|
-- | Read the annotated tree and the accumulated errors out of the state,
-- leaving the state itself untouched.
getAnotTree :: MyStateM (AnotTree, Errors)
getAnotTree = gets (\(_, tree, errors) -> (tree, errors))
|
|
|
|
-- | Register the signature of every function before any body is checked,
-- so functions can refer to each other regardless of definition order.
--
-- Each signature is pushed as its own one-entry scope.  A name that is
-- already bound anywhere in the environment (a built-in or an earlier
-- function) is reported and the earlier binding is kept.  Parameter names
-- are not examined here.
checkFuncs :: [TopDef] -> MyStateM ()
checkFuncs = mapM_ declare
  where
    declare (FnDef retT name params _) = do
      taken <- gets (\(env, _, _) -> any ((== name) . fst) (concat env))
      if taken
        then addErr (show name ++ " already in scope")
        else modify (\(env, tree, errors) ->
               ([(name, Fun retT [argT | Arg argT _ <- params])] : env,
                tree, errors))
|
|
|
|
|
|
-- | Check every function definition in order; each one appends its
-- annotated version to the tree in the state.
checkProgram :: [TopDef] -> MyStateM ()
checkProgram = mapM_ checkFun
|
|
|
|
-- | Type-check one function definition and append its annotated version to
-- the annotated tree.
--
-- The parameters are entered (via 'addArgs', which also reports duplicate
-- parameter names) into a fresh scope that is opened before the body is
-- checked and closed afterwards.  Parameters are therefore visible only
-- inside this function.  Previously they were added to the enclosing scope
-- after the body had been checked, leaking into every later function.
--
-- The parameter list is stored in its original order.
checkFun :: TopDef -> MyStateM ()
checkFun (FnDef t i a (Block s)) = do
  pushBlockScope
  _ <- addArgs a []
  ns <- checkStmts s [] t
  popBlockScope
  modify (\(env, at, err) -> (env, at ++ [FnDef t i a (Block ns)], err))
|
|
|
|
-- | Check a sequence of statements left to right against the return type
-- @rt@.  The annotated statements are appended after the already-checked
-- prefix @done@.
checkStmts :: [Stmt] -> [Stmt] -> Type -> MyStateM [Stmt]
checkStmts pending done rt = (done ++) <$> mapM (`checkStmt` rt) pending
|
|
|
|
-- | Type-check one statement against the enclosing function's return type
-- @rt@ and return it with its expressions annotated.
--
-- Problems are recorded with 'addErr'.  Checking never aborts: the returned
-- statement is always usable, though it may be only partially annotated.
-- Conditions of @if@/@while@ must be Bool, and a value-less @return@ is
-- only allowed in a function whose return type is Void.
checkStmt :: Stmt -> Type -> MyStateM Stmt
checkStmt s rt = case s of

    BStmt (Block ss) -> do
      pushBlockScope
      nss <- checkStmts ss [] rt
      popBlockScope
      return (BStmt (Block nss))

    Decl t vars -> do
      nvars <- addVars t vars []
      return (Decl t nvars)

    Ass i e -> do
      vt <- findVarType i
      ne <- infer e
      -- Match on the inferred expression instead of a failable pattern
      -- bind: an untypable RHS used to crash the checker.
      case (vt, ne) of
        (Just t, TAnot et _) | t == et -> return ()
        (Just t, _) -> addErr (show e ++ " is not of the type " ++
                               show t ++ " (" ++ show i ++ ")")
        -- Unknown variable: findVarType has already reported it.
        (Nothing, _) -> return ()
      return (Ass i ne)

    Incr i -> checkStep "incrementing" Incr i

    Decr i -> checkStep "decrementing" Decr i

    Ret e -> do
      ne <- infer e
      case ne of
        TAnot t _
          | t == rt -> return ()
          | otherwise -> addErr ("wrong return type " ++ show t ++
                                 " should be " ++ show rt)
        _ -> addErr (show e ++ " return not annotated")
      return (Ret ne)

    VRet -> do
      if rt == Void
        then return ()
        else addErr ("return without a value in a function of type " ++
                     show rt)
      return VRet

    Cond e stm -> do
      ne <- checkCondition e
      nstm <- checkStmt stm rt
      return (Cond ne nstm)

    CondElse e s1 s2 -> do
      ne <- checkCondition e
      ns1 <- checkStmt s1 rt
      ns2 <- checkStmt s2 rt
      return (CondElse ne ns1 ns2)

    While e stm -> do
      ne <- checkCondition e
      nstm <- checkStmt stm rt
      return (While ne nstm)

    SExp e -> do
      ne <- infer e
      return (SExp ne)

    other -> do
      addErr "Not an valid statement"
      return other
  where
    -- Shared check for ++ and --: the variable must be Int or Doub.
    -- @verb@ is only used to build the error message.
    checkStep verb ctor i = do
      m <- findVarType i
      case m of
        Just Int  -> return ()
        Just Doub -> return ()
        Just t    -> addErr ("(" ++ show i ++ ") " ++ verb ++
                             " is only allowed on Int and Doub, not on " ++
                             show t)
        Nothing   -> addErr ("(" ++ show i ++ ") " ++ verb ++
                             " is only allowed on Int and Doub")
      return (ctor i)

    -- Infer a condition expression and require it to be of type Bool.
    checkCondition e = do
      ne <- infer e
      case ne of
        TAnot Bool _ -> return ()
        _ -> addErr (show e ++ " is not of type Bool")
      return ne
|
|
|
|
-- | Enter function parameters into the innermost scope, reporting any name
-- already bound in that scope.  The first binding is kept.
--
-- Processed parameters are appended to @nass@, so the result preserves the
-- declaration order.  The old version prepended, which reversed the list,
-- and it dumped the environment to stderr with 'trace'.
addArgs :: [Arg] -> [Arg] -> MyStateM [Arg]
addArgs [] nass = return nass
addArgs (arg@(Arg t i) : ass) nass = do
  (env, at, err) <- get
  -- Avoid a partial pattern: with no open scope, start one.
  let (scope, outer) = case env of
        (sc : rest) -> (sc, rest)
        []          -> ([], [])
  case find (\(ident, _) -> ident == i) scope of
    Nothing    -> put (((i, t) : scope) : outer, at, err)
    Just entry -> addErr (show entry ++ " already in scope")
  addArgs ass (nass ++ [arg])
|
|
|
|
|
|
-- | Declare the items of a variable declaration of type @t@ in the
-- innermost scope, returning them (with annotated initialisers) appended
-- to @nis@.
--
-- An initialiser must have type @t@.  It is inferred before the variable
-- is declared, so a use of the same name inside it resolves to an outer
-- binding, or is reported as not in scope.  Redeclaring a name already in
-- the innermost scope is an error, whether or not the initialiser
-- type-checks.  An untypable initialiser is reported instead of crashing
-- on a failed pattern match.
addVars :: Type -> [Item] -> [Item] -> MyStateM [Item]
addVars _ [] nis = return nis
addVars t (item : is) nis = do
  nitem <- case item of
    NoInit i -> do
      declare i
      return (NoInit i)
    Init i ex -> do
      nex <- infer ex
      case nex of
        TAnot nt _ | nt == t -> return ()
        _ -> addErr (show ex ++ " is not of type " ++ show t)
      declare i
      return (Init i nex)
  addVars t is (nis ++ [nitem])
  where
    -- Bind @i@ with type @t@ in the innermost scope unless it is already
    -- bound there.
    declare i = do
      (env, at, err) <- get
      let (scope, outer) = case env of
            (sc : rest) -> (sc, rest)
            []          -> ([], [])
      case find (\(ident, _) -> ident == i) scope of
        Nothing    -> put (((i, t) : scope) : outer, at, err)
        Just entry -> addErr (show entry ++ " already initialized")
|
|
|
|
|
|
|
|
-- | Infer the type of an expression and return it wrapped in 'TAnot', with
-- sub-expressions annotated recursively.
--
-- When no type can be determined, the expression is returned without a
-- top-level 'TAnot' and, in most cases, an error is recorded.  String
-- literals and calls to the built-in @printString@ are deliberately left
-- unannotated.
infer :: Expr -> MyStateM Expr
infer expr = case expr of

    EVar i -> do
      m <- findVarType i
      case m of
        Just t  -> return (TAnot t (EVar i))
        Nothing -> return (EVar i)

    ELitInt e -> return (TAnot Int (ELitInt e))

    ELitDoub e -> return (TAnot Doub (ELitDoub e))

    ELitTrue -> return (TAnot Bool ELitTrue)

    ELitFalse -> return (TAnot Bool ELitFalse)

    EApp i exs -> do
      m <- findVarType i
      nexs <- inferList exs []
      case m of
        Just (Fun t a)
          | length nexs == length a -> do
              if and (zipWith hasType nexs a)
                then return ()
                else addErr (show i ++ ":s arguments (" ++
                             show nexs ++ ") are not equal to " ++
                             show a)
              return (TAnot t (EApp i nexs))
          | otherwise -> case i of
              -- printString is registered with no parameters and is
              -- special-cased here: it takes exactly one string literal.
              Ident "printString" -> do
                case exs of
                  [EString _] -> return ()
                  _ -> addErr ("printString only takes one literal string as a argument")
                return (EApp i exs)
              _ -> do
                addErr ("wrong number of arguments for " ++ show i)
                return (TAnot t (EApp i nexs))
        -- A bound name that is not a function (previously a crash from a
        -- non-exhaustive case).
        Just t -> do
          addErr (show i ++ " is not a function, it has type " ++ show t)
          return (EApp i nexs)
        Nothing -> do
          addErr ("wrong arguments in " ++ show i)
          return (EApp i exs)

    EString e -> return (EString e)

    Neg e -> do
      m <- infer e
      case m of
        TAnot Int _  -> return (TAnot Int (Neg m))
        TAnot Doub _ -> return (TAnot Doub (Neg m))
        _ -> do
          addErr (show e ++ " is not of type Int or Doub")
          return (Neg e)

    Not e -> do
      m <- infer e
      case m of
        TAnot Bool _ -> return (TAnot Bool (Not m))
        _ -> do
          addErr (show e ++ " is not of type Bool")
          return (Not e)

    EMul e1 op e2 -> do
      t <- findType e1 e2 [Int, Doub]
      return $ case t of
        Just (nt, en1, en2) -> TAnot nt (EMul en1 op en2)
        Nothing             -> EMul e1 op e2

    EAdd e1 op e2 -> do
      t <- findType e1 e2 [Int, Doub]
      return $ case t of
        Just (nt, en1, en2) -> TAnot nt (EAdd en1 op en2)
        Nothing             -> EAdd e1 op e2

    ERel e1 op e2 -> do
      -- Ordering comparisons are numeric only; equality also allows Bool.
      let allowed
            | op `elem` [LTH, LE, GTH, GE] = [Int, Doub]
            | otherwise                    = [Int, Doub, Bool]
      t <- findType e1 e2 allowed
      return $ case t of
        Just (_, en1, en2) -> TAnot Bool (ERel en1 op en2)
        Nothing            -> ERel e1 op e2

    EAnd e1 e2 -> do
      t <- findType e1 e2 [Bool]
      return $ case t of
        Just (_, en1, en2) -> TAnot Bool (EAnd en1 en2)
        Nothing            -> EAnd e1 e2

    EOr e1 e2 -> do
      t <- findType e1 e2 [Bool]
      return $ case t of
        Just (_, en1, en2) -> TAnot Bool (EOr en1 en2)
        Nothing            -> EOr e1 e2
  where
    -- An argument matches a parameter only if it was annotated with
    -- exactly the parameter's type; an unannotated argument never matches.
    hasType (TAnot argT _) paramT = argT == paramT
    hasType _ _                   = False
|
|
|
|
-- | Infer each expression left to right and append the results after the
-- already-inferred prefix @nes@.
inferList :: [Expr] -> [Expr] -> MyStateM [Expr]
inferList pending nes = (nes ++) <$> mapM infer pending
|
|
|
|
-- | Infer both operands of a binary operator and check that they share a
-- type that is a member of @allowed@.
--
-- On success, returns the common type with both annotated operands.  On
-- failure, records an error and returns 'Nothing'.  An operand that could
-- not be typed (e.g. an undeclared variable or a string literal) is now
-- reported.  Previously the lazy @let (TAnot ..) = ..@ pattern crashed
-- the checker in that case.
findType :: Expr -> Expr -> [Type] -> MyStateM (Maybe (Type, Expr, Expr))
findType e1 e2 allowed = do
  t1 <- infer e1
  t2 <- infer e2
  case (t1, t2) of
    (TAnot tt1 _, TAnot tt2 _)
      | tt1 /= tt2 -> do
          addErr (show e1 ++ " and " ++ show e2 ++
                  " are not of the same type (" ++ show tt1 ++
                  " and " ++ show tt2 ++ ")")
          return Nothing
      | tt1 `elem` allowed -> return (Just (tt1, t1, t2))
      | otherwise -> do
          addErr (show tt1 ++ " is not allowed here")
          return Nothing
    _ -> do
      addErr (show e1 ++ " and " ++ show e2 ++ " cannot both be typed")
      return Nothing
|
|
|
|
-- | Look up an identifier's type, searching the scopes innermost first.
-- An unbound identifier is reported as "not in scope" and yields
-- 'Nothing'.
findVarType :: Ident -> MyStateM (Maybe Type)
findVarType var = do
  scopes <- gets (\(env, _, _) -> env)
  let found = lookup var (concat scopes)
  case found of
    Nothing -> addErr (show var ++ " not in scope")
    Just _  -> return ()
  return found
|
|
|
|
|
|
-- initializing functions
|
|
-- | Starting state: only the built-ins are in scope, nothing is annotated
-- yet and no errors have been recorded.
emptyState :: MyState
emptyState = (emptyEnv, mempty, mempty)

-- | The outermost scope, holding the signatures of the built-in functions.
-- @printString@ is listed without parameters; 'infer' special-cases calls
-- to it.
emptyEnv :: Env
emptyEnv = [builtins]
  where
    builtins =
      [ (Ident "printString", Fun Void [])
      , (Ident "printDouble", Fun Void [Doub])
      , (Ident "printInt",    Fun Void [Int])
      , (Ident "readDouble",  Fun Doub [])
      , (Ident "readInt",     Fun Int [])
      ]
|
|
|
|
-- helper functions
|
|
-- | Open a new innermost scope pre-populated with the given parameters.
-- Duplicate names are not checked.
pushFunScope :: [Arg] -> MyStateM ()
pushFunScope a = modify open
  where
    open (env, tree, errors) = ([(i, t) | Arg t i <- a] : env, tree, errors)
|
|
|
|
-- | Close the innermost scope, discarding its bindings.
--
-- Uses 'drop' instead of the partial 'tail', so an unbalanced pop on an
-- empty environment is a no-op rather than a runtime crash.
popFunScope :: MyStateM ()
popFunScope = modify (\(env, at, err) -> (drop 1 env, at, err))
|
|
|
|
-- | Open a new, empty innermost scope (used for nested blocks).
pushBlockScope :: MyStateM ()
pushBlockScope = modify (\(env, tree, errors) -> ([] : env, tree, errors))
|
|
|
|
-- | Close the scope opened by 'pushBlockScope'; identical to closing a
-- function scope.
popBlockScope :: MyStateM ()
popBlockScope = () <$ popFunScope
|
|
|
|
-- | Record an error message.  Messages are prepended, so the error list
-- holds the most recent message first.
addErr :: String -> MyStateM ()
addErr msg = modify (\(env, tree, errors) -> (env, tree, msg : errors))
|
|
|