{-# OPTIONS -fglasgow-exts #-}

module ModInterp2

where

import System.IO

-- Monad transformers: type constructors t that wrap an underlying monad m
-- to form the type t m.
class MonadTrans t where
	 -- Embed a computation of the underlying monad m into t m.
	 lift :: Monad m => m a -> t m a
-- Monads into which IO actions can be embedded.
class Monad m => MonadIO m where
	 -- Turn an IO action into a computation of the monad m.
	 liftIO :: IO a -> m a
-- Base case: an IO action already lives in IO, so it is returned unchanged.
instance MonadIO IO where
    liftIO action = action

{- 
NOTE: we cannot write the generic instance below, because it would
require undecidable instances:
instance (MonadTrans t, Monad (t m)) => MonadIO (t m) where 
    liftIO = lift . liftIO
-}

-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
-- From Control.Monad.State source code

-- Monads m that carry a piece of state of type s. The functional
-- dependency m -> s says that the monad determines its state type.
class (Monad m) => MonadState s m |  m -> s 
 where
	-- Yield the current state.
	get :: m s
	-- Replace the current state with the given one.
	put :: s -> m ()

-- State transformer: a function from the incoming state to a computation in
-- the underlying monad m that yields a result paired with the outgoing state.
newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }

-- Sequencing for StateT threads the state through each step; fail is
-- delegated to the underlying monad.
instance (Monad m) => Monad (StateT s m) where
    -- Produce the value and leave the state untouched.
    return x = StateT (\st -> return (x, st))
    -- Run the first computation, then feed its result and the state it
    -- produced into the continuation.
    act >>= next = StateT step
      where
        step st = runStateT act st >>= \(x, st') -> runStateT (next x) st'
    -- The incoming state is ignored; the failure is raised in m.
    fail msg = StateT (const (fail msg))

-- StateT provides the state operations directly.
instance (Monad m) => MonadState s (StateT s m) where
    -- The current state is both the result and the (unchanged) new state.
    get = StateT (\st -> return (st, st))
    -- Discard the old state and install the given one.
    put st = StateT (const (return ((), st)))

-- Lifting runs the underlying computation and passes the state through
-- unchanged.
instance MonadTrans (StateT s) where
    lift act = StateT (\st -> act >>= \x -> return (x, st))

-- IO actions are first embedded in the underlying monad and then lifted
-- through the state layer.
instance (MonadIO m) => MonadIO (StateT s m) where
    liftIO io = lift (liftIO io)

-- Run a stateful computation from the given initial state and keep only its
-- result, dropping the final state.
evalStateT :: (Monad m) => StateT s m a -> s -> m a
evalStateT m s = runStateT m s >>= \(result, _) -> return result

-- Run a stateful computation from the given initial state and keep only the
-- final state, dropping the result.
execStateT :: (Monad m) => StateT s m a -> s -> m s
execStateT m s = runStateT m s >>= \(_, final) -> return final

-- Transform the underlying computation (result paired with final state) of a
-- StateT with the given function; this may change both the underlying monad
-- and the result type.
mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT (\st -> f (runStateT m st))

-- Run a StateT computation on a modified incoming state: f is applied to the
-- state before the computation sees it.
withStateT :: (s -> s) -> StateT s m a -> StateT s m a
withStateT f m = StateT (\st -> runStateT m (f st))


-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
-- ---------------------------------------------------------------------------
-- class MonadError
--
--    throwError: throws an exception inside the monad and thus interrupts
--    the normal execution order until an error handler is reached.
--
--    catchError: catches an exception inside the monad (one that was
--    previously thrown by throwError).

-- Types that can be used as error values.
--
-- The two default methods are defined in terms of each other, so an instance
-- must define at least one of them; an instance defining neither would loop
-- when either method is evaluated.
class Error a where
	-- An error value carrying no message.
	noMsg  :: a
	-- An error value built from a message string.
	strMsg :: String -> a

	-- Default: an error built from the empty message.
	noMsg    = strMsg ""
	-- Default: ignore the message and use noMsg.
	strMsg _ = noMsg

-- Plain strings are errors: the message itself is the error value (noMsg
-- falls back to the class default, i.e. the empty string).
instance Error String where
    strMsg msg = msg

-- Monads m in which errors of type e can be thrown and caught. The
-- functional dependency m -> e says that the monad determines its error type.
class (Monad m) => MonadError e m | m -> e 
  where
	-- Raise the given error inside the monad.
	throwError :: e -> m a
	-- Run a computation, passing an error it raises to the given handler.
	catchError :: m a -> (e -> m a) -> m a

-- Error transformer: a computation in the underlying monad m that yields
-- either an error of type e (Left) or a result of type a (Right).
newtype ErrorT e m a = ErrorT { runErrorT :: m (Either e a) }

-- Sequencing for ErrorT stops at the first Left; fail turns its message into
-- an error value via strMsg.
instance (Monad m, Error e) => Monad (ErrorT e m) where
    -- A successful result.
    return = ErrorT . return . Right
    -- Run the first computation; propagate its error unchanged, or continue
    -- with its result.
    act >>= next =
        ErrorT (runErrorT act >>= either (return . Left) (runErrorT . next))
    -- A failure becomes an error inside ErrorT rather than a failure of m.
    fail = ErrorT . return . Left . strMsg

-- ErrorT provides throwing and catching directly.
instance (Monad m, Error e) => MonadError e (ErrorT e m) where
    -- Throwing is producing a Left.
    throwError = ErrorT . return . Left
    -- Run the computation; on a Left hand the error to the handler, on a
    -- Right pass the result through.
    catchError act handler =
        ErrorT (runErrorT act >>= either (runErrorT . handler) (return . Right))

-- Lifting runs the underlying computation and marks its result as a success.
instance (Error e) => MonadTrans (ErrorT e) where
    lift act = ErrorT (act >>= \x -> return (Right x))

-- IO actions are first embedded in the underlying monad and then lifted
-- through the error layer.
instance (Error e, MonadIO m) => MonadIO (ErrorT e m) where
    liftIO io = lift (liftIO io)

-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------

-- Environment monad transformer

-- Computations m that have access to an environment of type r. The
-- functional dependency m -> r says that m determines the environment type.
class MonadEnv r m | m -> r 
  where
	 -- Run a computation with the given environment.
	 inEnv :: r -> m a -> m a
	 -- Yield the current environment.
	 rdEnv :: m r 

-- Environment transformer: a function from an environment of type r to a
-- computation in the underlying monad m.
data EnvT r m a = EnvT (r -> m a)

-- Unwrap an EnvT, giving the function from environment to underlying
-- computation.
runEnvT :: EnvT r m a -> r -> m a 
runEnvT wrapped = case wrapped of
    EnvT body -> body

-- Sequencing for EnvT hands the same environment to every step.
instance Monad m => Monad (EnvT r m) where
    -- The environment is not consulted.
    return x = EnvT (const (return x))
    -- Both the first computation and the continuation see the same
    -- environment.
    act >>= next = EnvT step
      where
        step env = do
            x <- runEnvT act env
            runEnvT (next x) env

-- A lifted computation ignores the environment.
instance MonadTrans (EnvT r) where
    lift act = EnvT (const act)

-- IO actions are first embedded in the underlying monad and then lifted
-- through the environment layer.
instance MonadIO m => MonadIO (EnvT r m) where
    liftIO io = lift (liftIO io)

-- EnvT provides the environment operations directly.
instance (Monad m) => MonadEnv r (EnvT r m) where
    -- Run the wrapped computation with the supplied environment, whatever
    -- the surrounding environment is.
    inEnv env (EnvT body) = EnvT (const (body env))
    -- The result is the environment itself.
    rdEnv = EnvT return

-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------

-- Interactions between the monads

-- Errors of the underlying monad can be thrown and caught through StateT.
instance (MonadError e m) => MonadError e (StateT s m) where
    throwError err = lift (throwError err)
    -- The handler is run from the state that was current when catchError
    -- was entered, not from any state reached by the failed computation.
    catchError act handler = StateT attempt
      where
        attempt st =
            catchError (runStateT act st) (\err -> runStateT (handler err) st)

-- Errors of the underlying monad can be thrown and caught through EnvT.
instance (MonadError e m) => MonadError e (EnvT r m) where
    throwError err = lift (throwError err)
    -- The guarded computation and the handler run in the same environment.
    catchError act handler = EnvT attempt
      where
        attempt env =
            catchError (runEnvT act env) (\err -> runEnvT (handler err) env)

-- The state of the underlying monad is accessible through ErrorT by lifting.
instance (Error e, MonadState s m) => MonadState s (ErrorT e m) where
    get = lift get
    put st = lift (put st)

-- The state of the underlying monad is accessible through EnvT by lifting.
instance (MonadState s m) => MonadState s (EnvT e m) where
    get = lift get
    put st = lift (put st)

-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------


-- Binary integer operators of the object language.
data BinOp = Add | Sub | Mul | Div deriving (Eq,Show)

-- Abstract syntax of the interpreted language.
data Term = 
    -- Basic operations (could produce errors): integer literal, binary
    -- arithmetic, and a conditional that tests its first term against zero.
    ConstInt Int | BinOp BinOp Term Term | IfZero Term Term Term	
    -- Imperative operations: unit literal, sequencing, printing a value.
  | ConstUnit | Seq Term Term | Print Term
    -- State operations: increment the integer counter / read it.
  | Inc | Get 
    -- Functions (de Bruijn Notation): variables are indices into the
    -- environment, so Lam binds no name.
  | Lam Term | Var Int | App Term Term
    -- Exceptions: raise an error / evaluate the first term and, if it raises
    -- an error equal to the given one, evaluate the last term instead.
  | Throw Err | Catch Term Err Term
  deriving (Eq,Show)

-- Results of evaluation: unit, an integer, or a function value holding the
-- environment captured when the Lam was evaluated together with its body.
data Value = UnitVal | IntVal Int | FnVal [Value] Term
			  deriving (Eq, Show)

-- Errors that evaluation can raise.
--   DivByZero      : division by zero
--   InvalidArg v t : while evaluating term t, value v had the wrong kind
--   VarUnbound i   : variable index i is not bound in the environment
--   User s         : error carrying a message (also used for strMsg/noMsg)
data Err = DivByZero 
         | InvalidArg Value Term 
         | VarUnbound Int 
         | User String
				 deriving (Eq,Show)

-- Err can serve as the error type of ErrorT: messages become User errors.
instance Error Err where
    -- Error without a message.
    noMsg = User "none"
    -- Error carrying the given message.
    strMsg = User

-- Interpreter state: the integer counter read by Get and bumped by Inc.
type St = Int

-- The interpreter monad stack, outermost layer first: an environment of
-- values (indexed by de Bruijn variables), the St counter, Err errors, over IO.
type InterpM = EnvT [Value] (StateT St (ErrorT Err IO))


-- run :: InterpM Value -> IO ()
-- Run an interpreter computation and print its outcome.
--
-- The layers are peeled off from the outside in: the computation is started
-- with an empty environment and the counter at 0, and the resulting
-- Either is printed. Note that both branches pass an already-shown String
-- to print, so the output appears as a quoted string literal.
run c = runErrorT (evalStateT (runEnvT c []) (0::Int)) >>= report
  where
    report (Left e)  = print ("INTERP ERROR:" ++ (show e))
    report (Right v) = print (show v)

-- The monadic interpreter 

-- interp :: Term -> InterpM Value

-- Evaluate a Term to a Value.
--
-- The computation runs in a monad that offers IO (for Print), the integer
-- counter (get/put, for Get and Inc), Err errors (throwError/catchError) and
-- a [Value] environment (rdEnv/inEnv) addressed by de Bruijn indices.
--
-- Errors raised through throwError:
--   DivByZero      - a Div whose right operand evaluates to 0
--   InvalidArg v t - while evaluating t, the value v had the wrong kind
--   VarUnbound i   - index i is negative or beyond the current environment
--   any Err        - the one carried by a Throw term

interp ConstUnit = return UnitVal

-- Evaluate t1 only for its effects; the result is the value of t2.
interp (Seq t1 t2) = do
    _ <- interp t1
    interp t2

-- Print the shown value (print re-shows the String, so it appears quoted).
interp (Print t) = do
    v <- interp t
    liftIO (print (show v))
    return UnitVal

interp (ConstInt i) = return (IntVal i)

-- Both operands are evaluated, left first, before either is checked.
interp t@(BinOp op t1 t2) = do
    v1 <- interp t1
    v2 <- interp t2
    case (v1, v2) of
        (IntVal i1, IntVal i2) ->
            case op of
                Add -> return (IntVal (i1 + i2))
                Sub -> return (IntVal (i1 - i2))
                Mul -> return (IntVal (i1 * i2))
                Div | i2 == 0   -> throwError DivByZero
                    | otherwise -> return (IntVal (div i1 i2))
        -- Only the left operand is not an integer.
        (_, IntVal _) -> throwError (InvalidArg v1 t)
        -- The right operand is not an integer (the left one may not be
        -- either; the right one is the one reported).
        (_, _) -> throwError (InvalidArg v2 t)

-- Only the selected branch is evaluated.
interp t@(IfZero t1 t2 t3) = do
    v1 <- interp t1
    case v1 of
        IntVal i1
            | i1 == 0   -> interp t2
            | otherwise -> interp t3
        _ -> throwError (InvalidArg v1 t)

interp Get = do
    i <- get
    return (IntVal i)

interp Inc = do
    i <- get
    put (i + 1)
    return UnitVal

-- A lambda captures the environment in which it is evaluated.
interp (Lam body) = do
    env <- rdEnv
    return (FnVal env body)

-- Total lookup: a negative index used to slip past the length check and
-- make (!!) raise a pure exception outside the Err channel; it is now
-- reported as VarUnbound like any other out-of-range index.
interp (Var i) = do
    env <- rdEnv
    case drop i env of
        (v : _) | i >= 0 -> return v
        _                -> throwError (VarUnbound i)

-- Function and argument are both evaluated (function first) before the
-- function value is inspected; the body then runs in the captured
-- environment extended with the argument at index 0.
interp t@(App t1 t2) = do
    v1 <- interp t1
    v2 <- interp t2
    case v1 of
        FnVal closureEnv body -> inEnv (v2 : closureEnv) (interp body)
        _ -> throwError (InvalidArg v1 t)

interp (Throw e) = throwError e

-- Handle only errors equal to e1; any other error is re-thrown.
interp (Catch t1 e1 t2) =
    catchError (interp t1)
               (\e2 -> if e1 == e2
                          then interp t2
                          else throwError e2)

-- Sample program: bump the counter, then apply a function that adds its
-- argument (Var 0) to the counter value to the constant 2, and print the
-- outcome via run.
test = run (interp program)
  where
    program = Seq Inc (App addCounter (ConstInt 2))
    addCounter = Lam (BinOp Add (Var 0) Get)










