@@ -54,7 +54,7 @@ let row ?kind_label ~title prf =
5454 match prf.prf_desc with
5555 | Rtag ({txt; loc} , _ , [] ) -> txt, mk ~loc ?kind_label ~title txt code
5656 | Rtag ({txt; loc} , _ , (h :: _ )) ->
57- let enc = Encoding. core ~wrap: false h in
57+ let enc = Encoding. core h in
5858 txt, mk ~loc ~enc ?kind_label ~title txt code
5959 | _ ->
6060 Location. raise_errorf ~loc " inherit not handled"
@@ -63,29 +63,156 @@ let expressions ?kind_label ~title t =
6363 let loc = t.ptype_loc in
6464 match t.ptype_kind, t.ptype_manifest with
6565 | Ptype_abstract , Some {ptyp_desc =Ptyp_variant (l , _ , _ ); _} ->
66- List. map (row ?kind_label ~title ) l
67- | _ -> Location. raise_errorf ~loc " error cases only from variants"
66+ `variant (List. map (row ?kind_label ~title ) l)
67+ | Ptype_open , None -> `type_ext t.ptype_name.txt
68+ | _ -> Location. raise_errorf ~loc " error cases only from variants and type extension"
6869
69- let str_gen ~loc ~path :_ (rec_flag , l ) debug title kind_label =
70+ let str_gen ~loc ~path :_ (_rec_flag , l ) debug title kind_label =
7071 let l = List. map (fun t ->
7172 let loc = t.ptype_loc in
72- let cases = expressions ?kind_label ~title t in
73- List. map (fun (name , expr ) ->
74- let pat = ppat_constraint ~loc (pvar ~loc (String. lowercase_ascii name ^ " _case" ))
75- [% type : [% t ptyp_constr ~loc (Utils. llid ~loc t.ptype_name.txt) []] EzAPI.Err. case] in
76- value_binding ~loc ~pat ~expr ) cases) l in
73+ let r = expressions ?kind_label ~title t in
74+ match r with
75+ | `variant cases ->
76+ List. map (fun (name , expr ) ->
77+ let pat = ppat_constraint ~loc (pvar ~loc (String. lowercase_ascii name ^ " _case" ))
78+ [% type : [% t ptyp_constr ~loc (Utils. llid ~loc t.ptype_name.txt) []] EzAPI.Err. case] in
79+ value_binding ~loc ~pat ~expr ) cases
80+ | `type_ext name ->
81+ let t = ptyp_constr ~loc (Utils. llid ~loc name) [] in
82+ let pat = [% pat? ([% p pvar ~loc (" _error_selects_" ^ name)] : (int * ([% t t] -> [% t t] option )) list ref )] in
83+ let selects = value_binding ~loc ~pat ~expr: [% expr ref []] in
84+ let pat = [% pat? ([% p pvar ~loc (" _error_cases_" ^ name)] : (int * [% t t] Json_encoding. case) list ref )] in
85+ let cases = value_binding ~loc ~pat ~expr: [% expr ref []] in
86+ [ selects; cases ]
87+ ) l in
7788 let l = List. flatten l in
78- let rec_flag = if List. length l < 2 then Nonrecursive else rec_flag in
79- let s = [ pstr_value ~loc rec_flag l ] in
80- if debug then Format. printf " %s@." (Pprintast. string_of_structure s);
89+ let s = [ pstr_value ~loc Nonrecursive l ] in
90+ if debug then Format. printf " %a@." Pprintast. structure s;
8191 s
8292
93+ let attribute_code ~code attrs =
94+ let c = List. find_map (fun a -> match a.attr_name.txt, a.attr_payload with
95+ | "code" , PStr [ { pstr_desc = Pstr_eval ({ pexp_desc = Pexp_constant Pconst_integer (s, _); _ }, _); _ } ] ->
96+ Some (int_of_string s)
97+ | _ -> None ) attrs in
98+ match c, code with Some c , _ | _ , Some c -> c | _ -> 500
99+
100+ let str_type_ext ~loc :_ ~path :_ t debug code =
101+ let loc = t.ptyext_loc in
102+ let name = Longident. name t.ptyext_path.txt in
103+ let l = List. filter_map (fun pext ->
104+ let loc = pext.pext_loc in
105+ match pext.pext_kind with
106+ | Pext_decl ([] , args , None) ->
107+ let code = attribute_code ~code pext.pext_attributes in
108+ let case = Encoding. resolve_case ~loc @@ Encoding. constructor_label ~wrap: true ~case: `snake
109+ ~loc ~name: pext.pext_name.txt ~attrs: pext.pext_attributes args in
110+ let select = pext.pext_name.txt, (match args with Pcstr_tuple [] -> false | _ -> true ) in
111+ Some (code, case, select)
112+ | _ -> None
113+ ) t.ptyext_constructors in
114+ let cases = elist ~loc @@ List. map (fun (code , case , _ ) ->
115+ [% expr [% e eint ~loc code], [% e case]]) l in
116+ let select_grouped = List. fold_left (fun acc (code , _ , select ) ->
117+ match List. assoc_opt code acc with
118+ | None -> acc @ [code, [ select ]]
119+ | Some l -> (List. remove_assoc code acc) @ [ code, l @ [ select ] ]
120+ ) [] l in
121+ let select_merged cons = pexp_function ~loc (
122+ (List. map (fun (name , has_arg ) ->
123+ case ~guard: None
124+ ~lhs: (ppat_alias ~loc (ppat_construct ~loc (Utils. llid ~loc name) (if has_arg then Some [% pat? _] else None )) {txt= " x" ; loc})
125+ ~rhs: [% expr Some x]) cons) @ [
126+ case ~guard: None ~lhs: [% pat? _] ~rhs: [% expr None ]
127+ ]) in
128+ let selects = elist ~loc @@ List. map (fun (code , cons ) ->
129+ [% expr [% e eint ~loc code], [% e select_merged cons] ]) select_grouped in
130+ let cases_name = " _error_cases_" ^ name in
131+ let selects_name = " _error_selects_" ^ name in
132+ let expr = [% expr
133+ [% e evar ~loc cases_name] := ! [% e evar ~loc cases_name] @ [% e cases];
134+ [% e evar ~loc selects_name] := ! [% e evar ~loc selects_name] @ [% e selects];
135+ ] in
136+ let s = [
137+ pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat: [% pat? () ] ~expr ]
138+ ] in
139+ if debug then Format. printf " %a@." Pprintast. structure s;
140+ s
141+
142+ let remove_spaces s =
143+ let b = Bytes. create (String. length s) in
144+ let n = String. fold_left (fun i -> function ' ' -> i | c -> Bytes. set b i c; i+ 1 ) 0 s in
145+ Bytes. (to_string @@ sub b 0 n)
146+
147+ let type_ext_err_case ~loc ~typ ?(def =true ) code =
148+ match EzAPI.Error_codes. error code with
149+ | None -> Location. raise_errorf ~loc " code is not standard"
150+ | Some name ->
151+ let enc = [% expr
152+ Json_encoding. union @@ List. filter_map (fun (code , case ) ->
153+ if code = [% e eint ~loc code] then Some case else None ) ! [% e evar ~loc (" _error_cases_" ^ typ)]
154+ ] in
155+ let enc =
156+ if not def then enc
157+ else [% expr Json_encoding. def [% e estring ~loc (remove_spaces name)] [% e enc]] in
158+ [% expr
159+ let select = EzAPI.Err. merge_selects @@ List. filter_map (fun (code , case ) ->
160+ if code = [% e eint ~loc code] then Some case else None ) ! [% e evar ~loc (" _error_selects_" ^ typ)] in
161+ EzAPI.Err. make ~code: [% e eint ~loc code] ~name: [% e estring ~loc name]
162+ ~encoding: [% e enc] ~select ~deselect: Fun. id ]
163+
164+ let remove_poly c = match c.ptyp_desc with Ptyp_poly (_ , c ) -> c | _ -> c
165+
166+ let get_err_case_options ~loc l =
167+ let code, debug, def = List. fold_left (fun (code , debug , def ) (lid , e ) -> match Longident. name lid.txt, e.pexp_desc with
168+ | "code" , Pexp_constant Pconst_integer (s , _ ) -> Some (int_of_string s), debug, def
169+ | "debug" , _ -> code, true , def
170+ | "nodef" , _ -> code, debug, false
171+ | "def" , Pexp_construct ({txt =Lident "false" ; _} , None) -> code, debug, false
172+ | s , _ -> Format. eprintf " %s option not understood@." s; code, debug, def
173+ ) (None , false , true ) l in
174+ match code with
175+ | None -> Location. raise_errorf ~loc " code not found"
176+ | Some code -> code, debug, def
177+
178+ let transform =
179+ object
180+ inherit Ast_traverse. map
181+ method! structure_item it = match it.pstr_desc with
182+ | Pstr_extension (({txt ="err_case" ; _} , PStr [{pstr_desc =Pstr_value (_ , [ vb ]); pstr_loc =loc ; _} ]), _ ) ->
183+ let typ, e, pat = match vb.pvb_expr.pexp_desc, vb.pvb_pat.ppat_desc with
184+ | Pexp_constraint (e , typ ), (Ppat_constraint ({ppat_desc =p ; _} , _ ) | p ) ->
185+ remove_poly typ, e, { vb.pvb_pat with ppat_desc= p }
186+ | _ , Ppat_constraint (p , typ ) ->
187+ remove_poly typ, vb.pvb_expr, p
188+ | _ -> Location. raise_errorf ~loc " no error type given to derive the error case" in
189+ let code, debug, def = match e.pexp_desc with
190+ | Pexp_constant Pconst_integer (s , _ ) -> int_of_string s, false , true
191+ | Pexp_record (l , None) -> get_err_case_options ~loc: e.pexp_loc l
192+ | _ -> Location. raise_errorf ~loc: e.pexp_loc " code not found" in
193+ let typ = match typ.ptyp_desc with
194+ | Ptyp_constr ({txt; _}, [] )
195+ | Ptyp_constr ({txt= (Ldot (Ldot (Lident " EzAPI" , " Err" ), " case" ) | Ldot (Lident " Err" , " case" )) ; _}, [
196+ { ptyp_desc = Ptyp_constr ({txt; _}, [] ); _ }
197+ ]) -> Longident. name txt
198+ | _ -> Location. raise_errorf ~loc: typ.ptyp_loc " couldn't find type to derive error case" in
199+ let expr = type_ext_err_case ~loc ~typ ~def code in
200+ let it = pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat ~expr ] in
201+ if debug then Format. printf " %a@." Pprintast. structure_item it;
202+ it
203+ | _ -> it
204+ end
205+
83206let () =
84- let args_str = Deriving.Args. (
207+ let open Deriving in
208+ let args_str = Args. (
85209 empty
86210 +> flag " debug"
87211 +> flag " title"
88212 +> arg " kind_label" (estring __)
89213 ) in
90- let str_type_decl = Deriving.Generator. make args_str str_gen in
91- Deriving. ignore @@ Deriving. add " err_case" ~str_type_decl
214+ let str_type_decl = Generator. make args_str str_gen in
215+ let args_type_ext = Args. (empty +> flag " debug" +> arg " code" (eint __)) in
216+ let str_type_ext = Generator. make args_type_ext str_type_ext in
217+ ignore @@ add " err_case" ~str_type_decl ~str_type_ext ;
218+ Driver. register_transformation " err_case" ~impl: transform#structure
0 commit comments