{-# OPTIONS -fglasgow-exts -fallow-overlapping-instances -fallow-undecidable-instances -fallow-incoherent-instances #-}

module Shrink where
  import Data.Typeable

  --------- Library code -----------
  -- | Generic one-layer traversal, parameterised over a class @ctx@.
  --
  -- @ctx a@ is a superclass constraint, and the callbacks passed to the
  -- methods may rely on @Data ctx b@ (and hence @ctx b@) for every
  -- immediate subterm @b@. This lets a generic function such as 'shrink'
  -- be called on subterms.
  --
  -- NOTE(review): both method signatures below are incomplete.
  -- @gmapQ ::@ and @gmapM :: (Monad m) =>@ are followed directly by @->@,
  -- so a leading argument type is missing and the declaration does not
  -- parse. Call sites pass a context argument first
  -- (@gmapQ {|Shrink|} cast t@), and the 'ListWithLength' instance takes an
  -- extra leading @_@. Both suggest a context/proxy parameter for @ctx@ was
  -- intended here. However, the Char/Int/list/pair instances take no such
  -- argument, so confirm the intended arity before fixing.
  class (Typeable a, (ctx a))
     => Data ctx a where
      -- | Apply a query to each immediate subterm and collect the results.
      gmapQ :: 
            -> (forall b. Data ctx b => b -> r)
            -> a -> [r]

      -- | Apply a monadic transformation to each immediate subterm and
      -- rebuild the value from the results.
      gmapM :: (Monad m) =>
            -> (forall b. Data ctx b => b -> m b)
            -> a -> m a

  -- | Leaf instance for 'Char'. A character has no immediate subterms, so a
  -- query yields no results and a monadic map returns the value unchanged.
  instance (ctx Char) => Data ctx Char where
    gmapQ _ _ = []
    gmapM _   = return

  -- | Leaf instance for 'Int'. There are no immediate subterms to visit:
  -- queries return nothing and monadic maps return the number unchanged.
  instance (ctx Int) => Data ctx Int where
    gmapM _   = return
    gmapQ _ _ = []

  -- | Lists. A cons cell has exactly two immediate subterms: its head and its
  -- tail (the tail is itself a list). The empty list has none.
  instance (ctx [a], Data ctx a)
        => Data ctx [a] where
    gmapQ _ []        = []
    gmapQ q (hd : tl) = [q hd, q tl]

    gmapM _ []        = return []
    gmapM k (hd : tl) =
      k hd >>= \hd' ->
      k tl >>= \tl' ->
      return (hd' : tl')

  -- | Pairs. Both components are immediate subterms and are visited left to
  -- right.
  instance (ctx (a, b), Data ctx a, Data ctx b)
        => Data ctx (a, b) where
    gmapQ q (l, r) = q l : q r : []
    gmapM k (l, r) =
      k l >>= \l' ->
      k r >>= \r' ->
      return (l', r')

  --------- shrink-specific code -----------
  -- | Values that can propose smaller candidate versions of themselves.
  class Shrink a where
    -- | All candidate shrinks of a value.
    shrink :: a -> [a]
    -- | Same-typed values obtained directly from the value's structure
    -- (e.g. its immediate subterms of the same type); may be empty.
    children :: a -> [a]

  -- | Generic fallback: any type with a @Data Shrink@ instance can be shrunk.
  -- 'shrink' proposes the value's same-typed children, followed by the
  -- one-step shrinks from 'shrinkStep'. The more specific instances in this
  -- module (pairs, 'Int', 'ListWithLength') overlap this one, which is why
  -- the module enables overlapping/incoherent instances.
  instance Data Shrink a => Shrink a where
    shrink t   = children t ++ shrinkStep t
    -- Keep those immediate subterms that 'cast' to the type of @t@ itself.
    -- NOTE(review): @{|Shrink|}@ is not standard Haskell. It appears to be an
    -- explicit argument selecting the 'Shrink' context for 'gmapQ'.
    children t = [y | Just y <- gmapQ {|Shrink|} cast t]

  -- | One-step shrinks of a value. Each candidate shrinks exactly one
  -- immediate subterm and leaves the others unchanged.
  --
  -- Each subterm @x@ is wrapped as @M x (shrink x)@, i.e. the unchanged
  -- subterm plus its alternatives. Threading these through 'gmapM' with the
  -- 'M' monad collects every way of replacing a single subterm by one of its
  -- shrinks. The rebuilt unchanged value is discarded.
  shrinkStep :: Shrink t => t -> [t]
  shrinkStep t = alternatives
    where
      -- The context argument names the class 'Shrink' (as 'children' does),
      -- not the method 'shrink'.
      M _ alternatives = gmapM {|Shrink|} (\x -> M x (shrink x)) t

  --------- The M monad -----------
  -- | A value paired with a list of alternative values of the same type.
  -- In the 'Monad' instance, the first field carries the result of the
  -- computation and the second collects alternative results.
  data M a = M a [a]	

  -- | Sequencing keeps the main result in the first field and gathers
  -- alternatives in the second. The alternatives are:
  --   1. those of the continuation applied to the main value, then
  --   2. the main results of the continuation applied to each alternative of
  --      the first computation.
  instance Monad M where
    return v = M v []
    M a alts >>= k = M b (bAlts ++ map (valueOf . k) alts)
      where
        -- Lazy pattern binding: 'k a' is evaluated only when 'b' or 'bAlts'
        -- is demanded.
        M b bAlts = k a
        valueOf (M v _) = v

  --------- Special Cases -------------------

  -- | Special case for pairs: only the first component is shrunk. Each of its
  -- shrinks is paired with the unchanged second component. A pair reports
  -- no same-typed children.
  instance Shrink a => Shrink (a,b) where
    shrink (x, y) = map (\x' -> (x', y)) (shrink x)
    children      = const []

  -- | Integers shrink by decrementing: the single candidate for @i@ is
  -- @i - 1@. An 'Int' has no same-typed children.
  --
  -- NOTE(review): there is no lower bound, so repeated shrinking keeps going
  -- below zero. Confirm that this is intended.
  instance Shrink Int where
    shrink   = (: []) . subtract 1
    children = const []

  ------------- List with Length -----------------

  -- | A list paired with a cached count.
  --
  -- Invariant: the 'Int' field equals the length of the list field.
  data ListWithLength a = LWL [a] Int deriving (Typeable, Show)

  -- | 'Data' instance for 'ListWithLength'. It is purely mechanical (the
  -- list and the count are the two immediate subterms) and should ideally
  -- be derived automatically.
  --
  -- NOTE(review): unlike the Char/Int/list/pair instances, these methods
  -- take an extra leading argument, ignored as @_@. That matches the
  -- three-argument call sites such as @gmapQ {|Shrink|} cast t@, but it
  -- disagrees with the other instances. The class signatures need to settle
  -- on one arity.
  instance (ctx (ListWithLength a)
	 , Data ctx Int
         , Data ctx [a]
         , Typeable a
         )
      => Data ctx (ListWithLength a)
   where
     gmapQ _ f (LWL l i) = [f l, f i]
     gmapM _ f (LWL l i) = do { l' <- f l; i' <- f i; return (LWL l' i') }

  -- | Special case of 'Shrink' for 'ListWithLength'. It keeps the invariant
  -- that the 'Int' field is the length of the list.
  --
  -- Candidates are the children (the tail, with the count decremented),
  -- followed by one-step shrinks of the list. The count of each shrunk list
  -- is recomputed rather than copied from the original.
  --
  -- Why recompute: 'shrinkStep' on a list can produce lists of a different
  -- length, because under the generic instance a list's children include
  -- its tail. Copying the old count would yield values that violate the
  -- invariant.
  instance (Shrink a) => Shrink (ListWithLength a) where
    shrink t@(LWL l _) =
      children t ++ [LWL l' (length l') | l' <- shrinkStep l]
    children (LWL [] _)     = []
    children (LWL (_:tl) i) = [LWL tl (i - 1)]

  -- | Demo: print the shrink candidates for a few sample values.
  main :: IO ()
  main = do
    print (shrink ("ab", "cd"))
    print (shrink ('a', 'b'))
    print (shrink "abc")
    print (shrink [1 :: Int, 2, 3])
    print (shrink (LWL [1 :: Int, 2, 3] 3))


