-- | Arrow wrapper that records a flow summary and a list of label
-- constraints next to the underlying computation. 'cert' checks both
-- before handing the underlying computation back to the caller.
--
-- The 'Privilege' constructor is deliberately not exported; outside this
-- module a privilege is obtained through 'authenticate'.
module FlowArrow
( FlowArrow
, Privilege
, AuthDB
, tag
, declassify
, cert
, authenticate
)
where

import Data.List as List
import Lattice
import Control.Arrow

-- | A side condition gathered while flow arrows are composed. The
-- conditions are evaluated by @check_constraint@ when an arrow is
-- certified.
data Constraint l
    = LEQ l l
      -- ^ @LEQ l1 l2@ is satisfied when @label_leq l1 l2@ holds.
    | USERGEQ l
      -- ^ @USERGEQ l@ is satisfied when @label_leq l p@ holds, where @p@
      -- is the label of the privilege used for certification.
    deriving (Eq, Show)

-- | Flow summary attached to a computation.
data Flow l
    = Trans l l
      -- ^ @Trans l1 l2@: @check_levels@ accepts an input label below @l1@
      -- and an output label above @l2@.
    | Flat
      -- ^ No label information; acts as the neutral element of
      -- @flow_seq@ and @flow_par@.
    deriving (Eq, Show)

-- | Sequential composition of two flow summaries.
--
-- 'Flat' on either side leaves the other summary untouched and adds no
-- constraint. Two 'Trans' summaries are chained end to end, and the
-- returned constraint demands that the first output label is below the
-- second input label.
flow_seq :: Flow l -> Flow l -> (Flow l, [Constraint l])
flow_seq Flat after = (after, [])
flow_seq before Flat = (before, [])
flow_seq (Trans src mid_out) (Trans mid_in dst) =
    (Trans src dst, side_conditions)
  where
    side_conditions = [LEQ mid_out mid_in]

-- | Parallel composition of two flow summaries.
--
-- 'Flat' on either side yields the other summary. For two 'Trans'
-- summaries the input labels are combined with @label_meet@ and the
-- output labels with @label_join@.
flow_par :: (Lattice l) => Flow l -> Flow l -> Flow l
flow_par Flat other = other
flow_par other Flat = other
flow_par (Trans in_a out_a) (Trans in_b out_b) =
    Trans (in_a `label_meet` in_b) (out_a `label_join` out_b)

-- | An arrow @a b c@ paired with its flow summary and the constraints
-- that must hold before it may be run. @l@ is the label type.
data FlowArrow l a b c = FA
    { computation :: a b c
      -- ^ The wrapped arrow; @certify@ returns it once all checks pass.
    , flow        :: Flow l
      -- ^ Flow summary, checked by @check_levels@.
    , constraints :: [Constraint l]
      -- ^ Side conditions accumulated by the combinators, checked by
      -- @check_constraints@.
    }
  
-- | Lifts the 'Arrow' combinators of the underlying arrow while keeping
-- the flow summary and constraint list up to date.
--
-- NOTE(review): 'pure' and '>>>' are defined here as 'Arrow' class
-- methods, which matches older releases of base; confirm the targeted
-- compiler before touching this instance.
instance (Lattice l, Arrow a) =>
                     Arrow (FlowArrow l a) where
  -- A lifted function has no label information and no constraints.
  pure fn = FA (pure fn) Flat []

  -- Sequencing chains the summaries with flow_seq and keeps the
  -- constraints of both sides followed by the one flow_seq produces.
  (FA comp_l flow_l cs_l) >>> (FA comp_r flow_r cs_r) =
      FA (comp_l >>> comp_r) seq_flow (cs_l ++ cs_r ++ seq_cs)
    where
      (seq_flow, seq_cs) = flow_seq flow_l flow_r

  -- first/second only transform the computation; summary and
  -- constraints are carried over unchanged.
  first (FA comp fl cs) = FA (first comp) fl cs

  second (FA comp fl cs) = FA (second comp) fl cs

  -- Both product combinators merge the summaries with flow_par and
  -- concatenate the constraints.
  (FA comp_l flow_l cs_l) *** (FA comp_r flow_r cs_r) =
      FA (comp_l *** comp_r) (flow_par flow_l flow_r) (cs_l ++ cs_r)

  (FA comp_l flow_l cs_l) &&& (FA comp_r flow_r cs_r) =
      FA (comp_l &&& comp_r) (flow_par flow_l flow_r) (cs_l ++ cs_r)

-- | Lifts the 'ArrowChoice' combinators of the underlying arrow.
instance (Lattice l, ArrowChoice a) =>
    ArrowChoice (FlowArrow l a) where
  -- left/right only transform the computation; summary and constraints
  -- are carried over unchanged.
  left (FA comp fl cs) = FA (left comp) fl cs

  right (FA comp fl cs) = FA (right comp) fl cs

  -- Both sum combinators merge the summaries with flow_par and
  -- concatenate the constraints.
  (FA comp_l flow_l cs_l) +++ (FA comp_r flow_r cs_r) =
      FA (comp_l +++ comp_r) (flow_par flow_l flow_r) (cs_l ++ cs_r)

  (FA comp_l flow_l cs_l) ||| (FA comp_r flow_r cs_r) =
      FA (comp_l ||| comp_r) (flow_par flow_l flow_r) (cs_l ++ cs_r)

-- | Lifts 'loop' of the underlying arrow. The summary is kept, and for
-- a 'Trans' summary one extra constraint is appended requiring the
-- output label to be below the input label.
instance (Lattice l, ArrowLoop a) =>
              ArrowLoop (FlowArrow l a) where
  loop (FA body fl cs) = FA (loop body) fl (cs ++ feedback)
    where
      -- Flat carries no labels, so it adds nothing.
      feedback = case fl of
        Flat             -> []
        Trans l_in l_out -> [LEQ l_out l_in]


-- | Identity arrow whose flow summary is @Trans lbl lbl@, with no
-- constraints attached.
tag :: (Lattice l, Arrow a) => l -> FlowArrow l a b b
tag lbl = FA (pure id) (Trans lbl lbl) []

-- | Identity arrow whose flow summary is @Trans l1 l2@. It carries the
-- constraint @USERGEQ l1@, so certification only succeeds with a
-- privilege whose label is at least @l1@.
declassify :: (Lattice l, Arrow a) =>
              l -> l -> FlowArrow l a b b
declassify l1 l2 = FA (pure id) (Trans l1 l2) [USERGEQ l1]

-- | Checks a flow summary against the expected input and output labels.
-- A 'Flat' summary needs the input label to be below the output label;
-- @Trans l1 l2@ needs the input label below @l1@ and @l2@ below the
-- output label.
check_levels label_in label_out fl =
    case fl of
      Flat        -> label_leq label_in label_out
      Trans l1 l2 -> (label_leq label_in l1) && (label_leq l2 label_out)

-- | Evaluates one constraint; @p@ is the label of the certifying
-- privilege and is only consulted for 'USERGEQ'.
check_constraint p c =
    case c of
      LEQ l1 l2 -> label_leq l1 l2
      USERGEQ l -> label_leq l p

-- | True when every constraint in the list holds for privilege label @p@.
check_constraints p cs = and [check_constraint p c | c <- cs]

-- | Wraps the label that @certify@ checks 'USERGEQ' constraints against.
-- The module's export list names this type without its constructor.
data Privilege l = PR l

-- | Releases the computation inside a 'FlowArrow' after checking it.
--
-- @l_in@ and @l_out@ are the labels the flow summary is checked against
-- (see @check_levels@); the privilege label is used to evaluate the
-- accumulated constraints (see @check_constraints@).
--
-- Calls 'error' when either check fails; the message includes the
-- offending flow summary or constraint list.
certify :: (Lattice l) => l -> l -> Privilege l
           -> FlowArrow l a b c -> a b c
certify l_in l_out (PR label_user) (FA c f t)
  | not (check_levels l_in l_out f) =
      -- ": " keeps the shown value from running into the message text.
      error $ "security level mismatch: " ++ show f
  | not (check_constraints label_user t) =
      error $ "constraints cannot be met: " ++ show t
  | otherwise = c

-- | 'certify' with both the input and the output label fixed to
-- @label_bottom@.
--
-- The arguments are written out on purpose: as a bare
-- @cert = certify ...@ binding without a type signature, the
-- monomorphism restriction stops the label type from being generalised,
-- so the exported function would only work for the single label type it
-- is used at inside this module. Do not eta-reduce.
cert priv fa = certify label_bottom label_bottom priv fa

-- | Credential table used by 'authenticate'. Each entry is
-- (user name, password, label): 'authenticate' compares the first two
-- components with the supplied credentials and wraps the third in a
-- 'Privilege'. @TriLabel@ is not defined in this file (presumably it
-- comes from "Lattice").
type AuthDB = [(String, String, TriLabel)]

-- | Looks up a user name / password pair in a labelled credential table.
--
-- Returns the matching user name together with a privilege built from
-- that entry's label, or @("nobody", PR LOW)@ when no entry matches.
-- The lookup result goes through @declassify HIGH LOW@ and the whole
-- arrow is certified with a @HIGH@ privilege before it is run on @()@.
--
-- The @proc@ notation needs the arrow syntax extension; no pragma for it
-- is visible in this file, so it is presumably enabled by compiler flags.
authenticate :: (FlowArrow TriLabel (->) () AuthDB) 
  ->String->String -> (String, Privilege TriLabel)
authenticate adb username password =
  cert (PR HIGH) res ()
  where
    -- The proc body is indented deeper than its binding so the layout
    -- does not depend on NondecreasingIndentation.
    res = adb >>> proc db ->
      declassify HIGH LOW -< lookup_user db
    lookup_user db =
      case find matches db of
        Nothing           -> ("nobody", PR LOW)
        Just (u, _, priv) -> (u, PR priv)
    matches (u, p, _) = username == u && password == p
