{-# OPTIONS -fglasgow-exts -fallow-overlapping-instances -fallow-undecidable-instances -fallow-incoherent-instances #-}

module Shrink where
  import Data.Typeable

  --------- Library code -----------
  -- Phantom "type argument" carrier.  'Proxy' is declared without any
  -- constructors, so a value of type 'Proxy ctx' carries no data; it exists
  -- only so that a caller can name the context constructor 'ctx' (of kind
  -- * -> *) explicitly.  gmapQ and gmapM take it as their first argument.
  data Proxy (a :: * -> *)
    -- allows explicit type argument, for type inference

  -- 'Sat a' says that a value of type 'a' can be produced out of thin air
  -- via 'dict'.  In this file 'a' is always a record of operations (for
  -- example 'ShrinkD t'), so a 'Sat (ctx t)' constraint plays the role of
  -- "type t supports the operations described by ctx".
  class Sat a where { dict :: a }
    -- abstracted type class, passed in as an explicit dictionary

  -- Generic one-layer traversal, parameterised by a context 'ctx'.
  -- Every 'Data ctx a' instance also provides 'Typeable a' and the
  -- dictionary 'Sat (ctx a)', so the functions passed to gmapQ/gmapM may use
  -- 'cast' and 'dict' on whatever subterm they are handed.
  class (Typeable a, Sat (ctx a))
     => Data ctx a where
      -- Query: apply the given function to subterms of the argument and
      -- collect the results in a list.  In the instances in this file the
      -- function is applied to each immediate constructor field, in order.
      gmapQ :: Proxy ctx
            -> (forall b. Data ctx b => b -> r)
            -> a -> [r]

      -- Monadic transformation: run the given action on subterms of the
      -- argument and return a value of the same type.  In the instances in
      -- this file the actions run left to right over the immediate fields and
      -- the result is rebuilt with the same constructor.
      gmapM :: Monad m => Proxy ctx
            -> (forall b. Data ctx b => b -> m b)
            -> a -> m a

  -- Characters are treated as atoms: there are no subterms to query, and the
  -- monadic map returns the character untouched.
  instance Sat (ctx Char) => Data ctx Char where
    gmapQ _ _ _ = []
    gmapM _ _   = return

  -- Ints are treated as atoms: querying yields no results and the monadic
  -- map hands the number back unchanged.
  instance Sat (ctx Int) => Data ctx Int where
    gmapQ _ _ _ = []
    gmapM _ _   = return

  -- Lists are traversed one cons cell at a time: the immediate subterms of
  -- (hd : tl) are the head element and the tail list (itself a list, so the
  -- function is used at two different types).  The empty list has none.
  instance (Sat (ctx [a]), Data ctx a)
        => Data ctx [a] where
    gmapQ _ q xs = case xs of
      []        -> []
      (hd : tl) -> [q hd, q tl]

    gmapM _ k xs = case xs of
      []        -> return []
      (hd : tl) -> k hd >>= \hd' ->
                   k tl >>= \tl' ->
                   return (hd' : tl')

  -- Pairs have two immediate subterms, visited first-component first.
  instance (Sat (ctx (a,b)), Data ctx a, Data ctx b)
        => Data ctx (a,b) where
    gmapQ _ q (u, v) = [q u, q v]
    gmapM _ k (u, v) =
      k u >>= \u' ->
      k v >>= \v' ->
      return (u', v')

  --------- shrink-specific code -----------
  -- Explicit dictionary for the 'Shrink' class: a record holding one function
  -- per class method ('shrinkD' for 'shrink', 'childrenD' for 'children').
  -- It is the context constructor used with 'Data' in this module, i.e.
  -- 'Data ShrinkD a' gives access to a 'ShrinkD a' through 'dict'.
  data ShrinkD a = ShrinkD { shrinkD   :: a -> [a],
 			     childrenD :: a -> [a] }

  -- Shrinking interface.  Both methods map a value to a list of candidate
  -- replacement values of the same type.
  class Shrink a where
     -- All shrink candidates for a value.  The generic instance in this
     -- module returns 'children' first, followed by one-step shrinks of the
     -- immediate subterms.
     shrink   :: a -> [a]
     -- Candidates obtained directly from the value's own structure.  The
     -- generic instance returns the immediate subterms that have the same
     -- type as the value itself.
     children :: a -> [a]

  -- Package the 'Shrink' methods of a type into an explicit 'ShrinkD'
  -- record.  The constructor arguments are positional: first the function
  -- for the 'shrinkD' field, then the one for 'childrenD'.
  instance Shrink a => Sat (ShrinkD a) where
    dict = ShrinkD shrink children

  -- Generic fallback: any type with a 'Data ShrinkD' instance can be shrunk.
  -- More specific 'Shrink' instances elsewhere in the module overlap with
  -- this one.
  instance Data ShrinkD a => Shrink a where
    -- Candidates are the 'childrenD' of the value (taken from its
    -- dictionary), followed by the results of one generic shrinking step.
    shrink v = childrenD dict v ++ shrinkStep v

    -- Immediate subterms whose type equals the type of the value itself;
    -- 'cast' yields Nothing for subterms of any other type and those are
    -- dropped.
    children v = foldr keepSame [] (gmapQ shrinkProxy cast v)
      where
        keepSame (Just sub) acc = sub : acc
        keepSame Nothing    acc = acc

  -- Proxy selecting the 'ShrinkD' context for gmapQ/gmapM.  'Proxy' has no
  -- constructors, so no real value exists and an 'error' call stands in for
  -- one.  The gmapQ/gmapM instances in this file all ignore their proxy
  -- argument, so they never evaluate this error.
  shrinkProxy :: Proxy ShrinkD 
  shrinkProxy = error "urk"

  -- One generic shrinking step.  Every immediate subterm is wrapped in the
  -- 'M' monad together with its own shrink candidates (looked up through the
  -- subterm's 'ShrinkD' dictionary); gmapM then rebuilds the value, and the
  -- alternatives accumulated by 'M' are returned.  The rebuilt original
  -- itself (the first field of 'M') is discarded.
  shrinkStep :: Data ShrinkD t => t -> [t]
  shrinkStep t =
    case gmapM shrinkProxy (\sub -> M sub (shrinkD dict sub)) t of
      M _ variants -> variants

  --------- The M monad -----------
  -- 'M x xs' pairs a principal value x with a list xs of alternative values
  -- of the same type.  shrinkStep uses it to carry a subterm together with
  -- that subterm's shrink candidates.
  data M a = M a [a]	

  -- Sequencing for 'M'.
  --   * 'return' has a principal value and no alternatives.
  --   * 'M cur alts >>= k' continues with k applied to the principal value;
  --     its principal result becomes the overall principal result.  The
  --     overall alternatives are the alternatives produced by 'k cur',
  --     followed by the principal result of k on each element of alts (the
  --     alternatives of those runs are dropped).
  -- The result of 'k cur' is matched lazily, so the outer 'M' constructor is
  -- available before the continuation is evaluated.
  instance Monad M where
    return x = M x []
    M cur alts >>= k =
        let M cur' fromCur = k cur
            fromAlts       = map (principal . k) alts
        in  M cur' (fromCur ++ fromAlts)
      where
        principal (M v _) = v

  --------- Special Cases -------------------
  -- a special case for pairs!!

  -- Pairs shrink only in their first component: each candidate replaces the
  -- first component with one of its shrinks and keeps the second as is.
  -- Pairs report no children.
  instance Shrink a => Shrink (a,b) where
     shrink (u, v) = map (\u' -> (u', v)) (shrink u)
     children _    = []

  -- An Int has exactly one shrink candidate, its predecessor, and no
  -- children.
  instance Shrink Int where
     shrink n   = [subtract 1 n]
     children _ = []

  ------------- List with Length -----------------

  -- A list stored together with an Int that is meant to record its length.
  -- The constructor does not check this: 'LWL' accepts any list/Int pair, so
  -- code that builds values must keep the two fields in step itself.
  data ListWithLength a = LWL [a] Int deriving (Typeable, Show)
  -- Invariant: the Int is the length of the list

  -- Hand-written 'Data' instance for ListWithLength; in principle it could
  -- be derived mechanically.  The two constructor fields (the list, then the
  -- Int) are the immediate subterms, visited in that order.
  instance ( Sat (ctx (ListWithLength a))
           , Data ctx Int
           , Data ctx [a]
           , Typeable a
           )
        => Data ctx (ListWithLength a)
    where
      gmapQ _ q (LWL xs len) = [q xs, q len]
      gmapM _ k (LWL xs len) =
        k xs  >>= \xs'  ->
        k len >>= \len' ->
        return (LWL xs' len')

  -- Special-case 'Shrink' instance for ListWithLength that keeps the
  -- "Int equals the length of the list" invariant intact.
  --
  -- shrink:   the children of the value, followed by one candidate for each
  --           result of a generic shrinking step on the list.  The length
  --           field is recomputed for every candidate: shrinkStep shrinks
  --           the tail with the generic list 'shrink', whose candidates
  --           include shorter lists, so reusing the old length would pair a
  --           shortened list with a stale count.
  -- children: for a non-empty list, the single value obtained by dropping
  --           the head and decrementing the stored length; none for the
  --           empty list.
  instance (Data ShrinkD a) => Shrink (ListWithLength a) where
     shrink t@(LWL l _) = children t ++ map relength (shrinkStep l)
       where
         -- Rebuild with a freshly computed length so the invariant holds
         -- whether or not the candidate list got shorter.
         relength l' = LWL l' (length l')
     children (LWL (_:tl) i) = [LWL tl (i-1)]
     children (LWL [] _)     = []

  -- Print the shrink candidates of a few sample values: a pair of strings,
  -- a pair of characters, a string, a list of Ints and a ListWithLength.
  main :: IO ()
  main = do
    print (shrink ("ab", "cd"))
    print (shrink ('a','b'))
    print (shrink "abc")
    print (shrink [1::Int,2,3])
    print (shrink (LWL [1::Int,2,3] 3))


