open Camlp4.PreCast;
open Format;
open Fan_easy;
open List;
open Fan_expr;


value mk_variant_eq _cons  = fun 
  [ [] -> <:expr< True >>
  | [{ty_expr;_}] -> ty_expr 
  | [{ty_expr; _} ::ys] ->
      fold_left
        (fun acc -> fun [ {ty_expr=e; _} ->
          <:expr< $acc$ && $e$ >> ] ) ty_expr ys ];
  
value mk_tuple_eq exprs = mk_variant_eq "" exprs ;
  
value mk_record_eq cols =
    cols
    |> map (fun [ {record_info;} -> record_info])
    |> mk_variant_eq "" ;

value gen_eq =
  gen_str_item ~id:(`Pre "eq_")  ~names:[]
    ~arity:2   ~mk_tuple:mk_tuple_eq ~mk_record:mk_record_eq mk_variant_eq
    ~trail:(fun
      [ (_,1) -> <:match_case< >>
      | (_,0) -> <:match_case< (_,_) -> True >>
      | (TyAbstr,_) -> <:match_case< (_,_) -> False >> 
      | (_,_) -> <:match_case< (_,_) -> False >> ] ) ;      

value extract info = info
    |> map (fun [{ty_name_expr;ty_id;_} -> [ty_name_expr;ty_id] ])
    |> concat ;

value mkfmt pre sep post fields =
    <:expr<
      Format.fprintf fmt
      $str: pre^ String.concat sep fields ^ post $ >> ;
  
value mk_variant_print cons params =
    let len = List.length params in
    let pre =
        if len >= 1 then
          mkfmt ("@[<1>("^cons^"@ ")
            "@ " ")@]" (init len (fun _ -> "%a"))
        else
          mkfmt cons "" "" [] in
    params |> extract |> apply pre ;
    
value mk_tuple_print params =
    let len = List.length params in
    let pre = mkfmt "@[<1>(" ",@," ")@]" (init len (fun _ -> "%a")) in 
    params |> extract |> apply pre  ;
    
value mk_record_print cols = 
    let pre = cols
       |> map (fun [ {record_label;_} -> record_label^":%a" ])
       |>  mkfmt "@[<hv 1>{" ";@," "}@]" in 
    cols |> map(fun [ {record_info;_} -> record_info ])
         |> extract |> apply pre  ;
  
value gen_print =
  gen_str_item  ~id:(`Pre "pp_print_")  ~names:["fmt"
    ~mk_tuple:mk_tuple_print  ~mk_record:mk_record_print   mk_variant_print;    

value gen_print_obj =
  gen_object ~kind:Iter ~mk_tuple:mk_tuple_print
    ~base:"printbase" ~class_name:"print"
    ~names:["fmt"]  ~mk_record:mk_record_print mk_variant_print;
  
value mk_variant_meta_expr cons params =
    let len = List.length params in 
    if ends_with cons "Ant" then
      match len with
      [ n when n > 1 -> of_ident_number <:ident< Ast.ExAnt >> len
      | 1 ->  <:expr< Ast.ExAnt _loc $id:xid 0$ >>
      | _ -> do{
        eprintf "%s can not be handled" cons;
        exit 1} ]
    else
      params
      |> map (fun [ {ty_expr;_} -> ty_expr ])
      |> fold_left mee_app (mee_of_str cons)  ;
        
value mk_record_meta_expr cols = cols |> map
      (fun [ {record_label; record_info={ty_expr;_};_} -> (record_label, ty_expr)])
      |> mk_record_ee ;

value mk_tuple_meta_expr params =
    params |> map (fun [{ty_expr;_} -> ty_expr]) |> mk_tuple_ee ;

value gen_meta_expr = 
  gen_str_item  ~id:(`Pre "meta_")  ~names:["_loc"]
    ~mk_tuple:mk_tuple_meta_expr
    ~mk_record:mk_record_meta_expr mk_variant_meta_expr
;    

value mk_variant_meta_patt cons params =
    let len = List.length params in 
    if ends_with cons "Ant" then
      match len with
      [ n when n > 1 -> of_ident_number <:ident< Ast.PaAnt >> len
      | 1 -> <:expr< Ast.PaAnt _loc $id:xid 0$ >>
      | _ -> do{
        eprintf "%s can not be handled" cons;
        exit 1 } ]
    else
      params
      |> map (fun [ {ty_expr;_} -> ty_expr ])
      |> fold_left mep_app (mep_of_str cons);
        
value mk_record_meta_patt cols = cols |> map
      (fun [ {record_label; record_info={ty_expr;_};_}
             -> (record_label, ty_expr)])
         |> mk_record_ep ;

value mk_tuple_meta_patt params = params |> map
      (fun [{ty_expr;_} -> ty_expr]) |> mk_tuple_ep;

value gen_meta_patt =  gen_str_item  ~id:(`Pre "meta_") ~names:["_loc"]
    ~mk_tuple:mk_tuple_meta_patt ~mk_record:mk_record_meta_patt mk_variant_meta_patt
;    

  
value (gen_map,gen_map2) =
  let mk_variant cons params =
    params |> map (fun [ {ty_expr;_} -> ty_expr]) |> apply (of_str cons) in
  let mk_tuple params =
    params |> map (fun [{ty_expr; _ } -> ty_expr]) |> tuple_of_list in 
  let mk_record cols =
    cols |> map (fun [ {record_label; record_info={ty_expr;_ } ; _ }  ->
          (record_label,ty_expr) ] )  |> mk_record 
  in
  (gen_object ~kind:Map ~mk_tuple ~mk_record
     ~base:"mapbase" ~class_name:"map"
     mk_variant ~names:[],
   gen_object ~kind:Map ~mk_tuple ~mk_record
     ~base:"mapbase2" ~class_name:"map2" mk_variant ~names:[]
     ~arity:2 ~trail:(
     fun
     [(_,1) -> <:match_case< >>
     |(_,_) ->  <:match_case< (_,_) -> invalid_arg "map2 failure" >> ] ))
;

value (gen_fold,gen_fold2) = 
  let mk_variant cons params =
    let rec aux exprs = match exprs with
      [ [] -> <:expr< self>>
      | [{ty_expr;_}] -> ty_expr
      | [{ty_expr;_ }::xs] ->
          <:expr< let self = $ty_expr$ in $aux xs$ >> ] in
    aux params  in 
  let mk_tuple  = mk_variant ""  in 
  let mk_record cols =
    cols |> map (fun [ {record_label; record_info ; _ } -> record_info ] )
         |> mk_variant "" in 
  (gen_object ~kind:Fold ~mk_tuple ~mk_record
     ~base:"foldbase" ~class_name:"fold" mk_variant ~names:[],
   gen_object ~kind:Fold ~mk_tuple ~mk_record
     ~base:"foldbase2" ~class_name:"fold2"
     mk_variant ~names:[]
     ~arity:2 ~trail:(fun
     [(_,1) -> <:match_case< >>
     |(_,_) ->  <:match_case< (_,_) -> invalid_arg "fold2 failure" >> ] ))
;


open Fan_asthook;  
begin
  [
   ("Print", gen_print) ;
   ("Eq",gen_eq) ;
   ("MetaExpr",gen_meta_expr) ;
   ("MetaPatt",gen_meta_patt) ;
   ("Map",gen_map);
   ("Map2",gen_map2);
   ("Fold",gen_fold);
   ("Fold2",gen_fold2);
   ("OPrint", gen_print_obj);  
  ] |> iter register;
end;