{-# LANGUAGE GADTs, TypeFamilies, EmptyDataDecls, ScopedTypeVariables #-}

module CPS where

-- Demonstration of type families with CPS translation
-- uses de Bruijn representation of binding

-- HOMEWORK assignment: Add Callcc and Throw to source language, 
-- define semantics through CPS translation.

-- | Uninhabited answer type: every CPS-converted term ('cps', 'cps_prog')
-- has type 'Void', i.e. it never returns a value to its caller.
data Void   -- Empty type
-- | Type-level marker for "a continuation expecting a @t@". It has no
-- constructors; its meaning is given by the 'CPS' translation
-- (@CPS (Cont t) = CPS t -> Void@).
data Cont t -- type of continuation

-- | Typed de Bruijn index. @Var g t@ selects a variable of type @t@ from the
-- context @g@. Contexts are left-nested pairs, @(((), t1), t2)@, with the
-- most recently bound variable on the right (see 'Lam' and 'ECons').
data Var g t where
    -- the innermost (most recently bound) variable
    Z      :: Var (g, t) t
    -- skip the innermost binder and look further out
    S      :: Var g t -> Var (g,t') t

-- | Well-typed source expressions, indexed by the typing context @g@ and the
-- result type @t@, so ill-scoped or ill-typed terms cannot be built.
data Exp g t where 
    Var    :: Var g t -> Exp g t
    Lit    :: Int ->  Exp g Int
    -- the body sees the bound argument as index 'Z'
    Lam    :: (Exp (g, t1) t2) -> Exp g (t1 -> t2)
    App    :: Exp g (t1 -> t2) -> Exp g t1 -> Exp g t2
    -- stop the machine with an Int; since it never returns, it may be used
    -- at any result type
    Halt   :: Exp g Int -> Exp g t
    -- new constructors
    -- call the given function with the current continuation
    Callcc :: Exp g (Cont t1 -> t1) -> Exp g t1
    -- jump to a continuation with a value; never returns here, hence any t2
    Throw  :: Exp g (Cont t1) -> Exp g t1 -> Exp g t2

-- | Runtime environment whose shape mirrors a typing context: 'ENil' is the
-- empty context @()@ and 'ECons' pushes a value on the right. It is also
-- used purely as a shape witness by 'weak', with 'undefined' values.
data Env g where
    ENil   :: Env ()
    ECons  :: Env g -> t -> Env (g, t)

------------------------------------------------------------------------
-- show instances

instance Show Void where
  -- 'Void' has no values, so this can only ever be handed bottom; forcing
  -- the argument first lets that divergence surface before our own error.
  show v = v `seq` error "impossible"

instance Show (Var g n) where
  -- Renders an index as a run of 'S' prefixes ending in "Z",
  -- e.g. @S (S Z)@ is shown as "SSZ".
  show v = case v of
    Z    -> "Z"
    S v' -> 'S' : show v'

instance Show (Exp g n) where
  -- Fully parenthesised prefix rendering: "(Con arg1 arg2 ...)".
  show e = case e of
      Var v     -> node "Var"    [show v]
      Lit i     -> node "Lit"    [show i]
      Lam t     -> node "Lam"    [show t]
      App t u   -> node "App"    [show t, show u]
      Halt t    -> node "Halt"   [show t]
      Callcc t  -> node "Callcc" [show t]
      Throw t u -> node "Throw"  [show t, show u]
    where
      -- constructor name and rendered children, space-separated, in parens
      node con args = "(" ++ unwords (con : args) ++ ")"
-----------------------------------------------------------------------
-- An interpreter

-- | Fetch the value bound to a de Bruijn index. The shared @env@ index
-- guarantees the environment is non-empty whenever an index exists.
sLookup :: Var env t -> Env env -> t
sLookup idx env = case (idx, env) of
  (Z,     ECons _    val) -> val
  (S idx', ECons rest _)  -> sLookup idx' rest

-- | Direct (meta-circular) interpreter: source functions become Haskell
-- functions and variables are looked up in the environment. 'Halt' aborts
-- with an 'error' carrying the halting value; control operators are
-- rejected because a pure direct-style interpreter cannot capture the
-- current continuation.
interp :: Env env -> Exp env t -> t
interp env expr = case expr of
  Var x     -> sLookup x env
  Lit n     -> n
  Lam body  -> \arg -> interp (ECons env arg) body
  App f a   -> interp env f (interp env a)
  Halt h    -> error ("Machine halted with value " ++ show (interp env h))
  -- Haskell cannot interpret these (purely)
  Callcc _  -> error "Can't interpret directly"
  Throw _ _ -> error "Can't interpret directly"
-----------------------------------------------------------------------
-- interpreter examples

-- | Halt in function position: the machine stops with 3 and the argument
-- 4 is never used.
t1 :: Exp g Int
t1 = App (Halt (Lit 3)) (Lit 4)

-- | Interpreting 't1'; forcing it raises "Machine halted with value 3".
v1 :: Int
v1 = interp ENil t1

-- | A function that, given a continuation, throws 3 to it; the surrounding
-- application to 4 is never reached.
t2 :: Exp () (Cont Int -> Int)
t2 = Lam (App (Throw (Var Z) (Lit 3)) (Lit 4))

-- | @callcc t2@: should evaluate to 3 via the thrown value.
t3 :: Exp () Int
t3 = Callcc t2

-- an error: the direct interpreter rejects Callcc/Throw
err1 :: Int
err1 = interp ENil t3

-- the goal! (@CPS ()@ reduces to @()@)
cps1 :: Exp () Void
cps1 = cps_prog t3


-----------------------------------------------------------------------
-- weakening...

-- | Context concatenation: @Append g g'@ lists the bindings of @g@ followed
-- by (i.e. further out than) those of @g'@. It is defined by recursion on
-- @g'@, which is what lets 'weak' insert a new binder between @g@ and @g'@.
type family (Append g g')
type instance (Append g ()) = g
type instance (Append g (g', t)) = ((Append g g'), t)

-- | Re-index a variable across a new binder of type @t'@ inserted between
-- the outer context @g@ and the inner context @g'@. Indices that land inside
-- @g'@ are unchanged; those reaching past it into @g@ gain one 'S'. Only the
-- shape of the @Env g'@ argument is inspected.
weakvar :: t' -> Env g -> Env g' -> Var (Append g g') t -> Var (Append (g,t') g') t
weakvar _   _     ENil            idx = S idx
weakvar new outer (ECons inner _) idx = case idx of
  Z      -> Z
  S idx' -> S (weakvar new outer inner idx')
               
-- | Generalised weakening: insert an unused binder of type @t'@ between the
-- outer context @g@ and the inner context @g'@ throughout an expression.
-- The @t'@ and @Env g@ arguments are only passed along (type witnesses);
-- the @Env g'@ argument tracks how many binders we are underneath.
weak :: t' -> Env g -> Env g' -> Exp (Append g g') t -> Exp (Append (g,t') g') t
weak new outer inner expr = case expr of
  Var x     -> Var (weakvar new outer inner x)
  Lit n     -> Lit n
  -- Going under a binder grows the inner context by one; the placeholder
  -- value is never looked at, only its position.
  Lam body  -> Lam (weak new outer (ECons inner undefined) body)
  App f a   -> App (weak new outer inner f) (weak new outer inner a)
  Halt h    -> Halt (weak new outer inner h)
  Callcc f  -> Callcc (weak new outer inner f)
  Throw k v -> Throw (weak new outer inner k) (weak new outer inner v)

-- | Weaken by one binder: the result lives in a context extended with an
-- unused variable of type @t'@, so every free index is bumped by one 'S'.
shift1 :: forall g t t'. Exp g t -> Exp (g, t') t
shift1 e = weak newBinder outerEnv ENil e
  where
    -- 'weak' never pattern-matches on these two arguments, so they act
    -- purely as type witnesses.
    newBinder = undefined :: t'
    outerEnv  = undefined :: Env g

---------------------------------
-- CPS conversion
-- | Type translation for CPS conversion. 'Int' is unchanged. A function
-- @a -> b@ becomes a function taking the converted argument and a
-- continuation for @b@, returning 'Void'. A continuation for @b@ is a
-- function from the converted @b@ to 'Void'. Contexts (nested pairs, ending
-- in @()@) are translated pointwise.
type family CPS t
type instance CPS Int = Int
type instance CPS (a -> b) = CPS a -> (CPS (Cont b)) -> Void
type instance CPS (a,b) = (CPS a, CPS b)
type instance CPS () = ()
type instance CPS (Cont b) = CPS b -> Void

-- CPS conversion

-- [[  x  ]]  c = App c x
-- [[ Lam x.t ]] c = App c (Lam x. Lam c'.[[t]]c')
-- [[ App t1 t2]] c = [[t1]] (Lam v1. [[t2]] (Lam v2. (App (App v1 v2) c)))
-- [[ halt t]] c = P [[ t ]] 

-- P[[ t ]] = [[ t ]] (Lam x.halt x)

-- | CPS-convert a whole program: @P[[ t ]] = [[ t ]] (Lam x. halt x)@, i.e.
-- run the translation against an initial continuation that halts the
-- machine with the final 'Int'.
cps_prog :: Exp g Int -> Exp (CPS g) Void
cps_prog prog = cps prog haltCont
  where
    haltCont = Lam (Halt (Var Z))

-- | Translate an index into the CPS-converted context. Contexts are
-- converted pointwise, so the index keeps its shape; only its type changes.
cpsvar :: Var g t -> Var (CPS g) (CPS t)
cpsvar idx = case idx of
  Z      -> Z
  S idx' -> S (cpsvar idx')

-- | Typed CPS conversion. @cps e k@ yields a term that evaluates @e@ and
-- passes its (converted) value to the continuation term @k@. An expression
-- of type @t@ in context @g@ becomes a 'Void' term in context @CPS g@, and
-- @k@ must accept a @CPS t@. Whenever a term is placed under a freshly
-- introduced binder it is weakened with 'shift1' so its indices stay valid.
cps :: Exp g t -> (Exp (CPS g) (CPS t -> Void)) -> Exp (CPS g) Void

-- [[ x ]] k = App k x
cps (Var x) k = App k (Var (cpsvar x))

-- [[ i ]] k = App k i
cps (Lit i) k = App k (Lit i)

-- [[ Lam x.t ]] k = App k (Lam x. Lam c'. [[t]] c')
cps (Lam (t :: Exp (g,t1) t2)) k = App k (Lam (Lam v))
  where
    -- the body now sits under the extra continuation binder c'
    u :: Exp ((g, t1), Cont t2) t2
    u = shift1 t

    v :: Exp ((CPS g, CPS t1), CPS (Cont t2)) Void
    v = cps u (Var Z)

-- [[ App t1 t2 ]] k = [[t1]] (Lam v1. [[t2]] (Lam v2. App (App v1 v2) k))
cps (App (t1 :: Exp g (t1 -> t2))
         (t2 :: Exp g t1)) k = cps t1 k0
  where
    -- k is used under both v1 and v2
    k2 :: Exp ((CPS g, CPS (t1 -> t2)), CPS t1) (CPS (Cont t2))
    k2 = shift1 (shift1 k)

    k1 :: Exp (CPS g, CPS (t1 -> t2)) (CPS (Cont t1))
    k1 = Lam (App (App (Var (S Z)) (Var Z)) k2)

    -- the argument is evaluated under v1
    t2' :: Exp (g, t1 -> t2) t1
    t2' = shift1 t2

    k0 :: Exp (CPS g) (CPS (Cont (t1 -> t2)))
    k0 = Lam (cps t2' k1)

-- [[ halt t ]] k = P[[ t ]]   (the current continuation is discarded)
cps (Halt t) _ = cps_prog t

-- [[ callcc t ]] k = [[t]] (Lam v. App (App v k) k)
--   t evaluates to a function expecting the current continuation. Its CPS
--   form takes two arguments: the first-class continuation value and its own
--   return continuation, and both are the current continuation k.
cps (Callcc (t :: Exp g (Cont t1 -> t1))) k = cps t kf
  where
    -- k is used under the new binder v
    kv :: Exp (CPS g, CPS (Cont t1 -> t1)) (CPS (Cont t1))
    kv = shift1 k

    kf :: Exp (CPS g) (CPS (Cont (Cont t1 -> t1)))
    kf = Lam (App (App (Var Z) kv) kv)

-- [[ throw t u ]] k = [[t]] (Lam v1. [[u]] (Lam v2. App v1 v2))
--   t should evaluate to a continuation and u to a value acceptable to it.
--   After evaluating both, the current continuation k is ignored and control
--   jumps to v1 with v2.
cps (Throw (t :: Exp g (Cont t1)) (u :: Exp g t1)) _ = cps t k1
  where
    -- u is evaluated under the binder v1
    u' :: Exp (g, Cont t1) t1
    u' = shift1 u

    -- Lam v2. App v1 v2, where v1 is now index S Z and v2 is Z
    k2 :: Exp (CPS g, CPS (Cont t1)) (CPS (Cont t1))
    k2 = Lam (App (Var (S Z)) (Var Z))

    k1 :: Exp (CPS g) (CPS (Cont (Cont t1)))
    k1 = Lam (cps u' k2)


