(*  Title:      HOL/Tools/primrec_package.ML
    Author:     Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen;
                Florian Haftmann, TU Muenchen

Package for defining functions on datatypes by primitive recursion.

Each user equation is translated into an argument of the datatype's
recursion combinator; every function is then defined via that combinator,
and the user's equations are re-proved from the definitions by rewriting.
*)

signature PRIMREC_PACKAGE =
sig
  (* Define functions from already-checked fixes and equations
     (Specification.check_spec) in a local theory; returns the proven
     equations together with the updated local theory. *)
  val add_primrec: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> local_theory -> thm list * local_theory
  (* Like add_primrec, but fixes and equations are given as strings
     (Specification.read_spec), and the definitions are placed in a fresh
     LocalTheory group; this is what the primrec command calls. *)
  val add_primrec_cmd: (binding * string option * mixfix) list ->
    (Attrib.binding * string) list -> local_theory -> thm list * local_theory
  (* add_primrec inside a fresh global theory target (TheoryTarget.init NONE);
     the resulting theorems are exported back to the theory level. *)
  val add_primrec_global: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> thm list * theory
  (* As add_primrec_global, but inside an overloading target built from the
     given operations (TheoryTarget.overloading). *)
  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> thm list * theory
end;

structure PrimrecPackage : PRIMREC_PACKAGE =
struct

open DatatypeAux;

(* Raised for malformed specifications: an error message and, when known,
   the offending equation.  gen_primrec turns it into a user-level error. *)
exception PrimrecError of string * term option;

(* Raise without an equation; the equation may be attached by a handler
   further up (see the end of process_eqn and trans). *)
fun primrec_error msg = raise PrimrecError (msg, NONE);
(* Raise with the offending equation attached. *)
fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);

(* Emit a tracing message, but only while the Toplevel.debug flag is set. *)
fun message s = if ! Toplevel.debug then tracing s else ();


(* preprocessing of equations *)

(* Analyse one specification term and add it to the accumulator rec_fns.
   - is_fixed: predicate telling whether a free variable name is fixed
     (a function being defined or a variable fixed by the context)
   - spec: the equation, possibly under outer "all" quantifiers
   - rec_fns: association list
       fname -> (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])
     where rpos is the position of the constructor-pattern argument.
   The left-hand side must have the form  fname ls (C cargs) rs  with only
   distinct variables in ls, cargs and rs.  Errors raised without an equation
   are re-raised with spec attached. *)
fun process_eqn is_fixed spec rec_fns =
  let
    val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    val body = strip_qnt_body "all" spec;
    (* rename the quantified variables apart from the free names in body *)
    val (vs', _) = Name.variants vs (Name.make_context (fold_aterms
      (fn Free (v, _) => insert (op =) v | _ => I) body []));
    (* replace the loose bound variables by the new frees; the list is
       reversed, presumably so that the innermost quantifier lines up with
       de Bruijn index 0 -- assumes subst_bounds maps position i to Bound i *)
    val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
      handle TERM _ => primrec_error "not a proper equation";
    val (recfun, args) = strip_comb lhs;
    (* the head of the lhs must be a fixed free variable *)
    val fname = case recfun of Free (v, _) => if is_fixed v then v
          else primrec_error "illegal head of function equation"
      | _ => primrec_error "illegal head of function equation";

    (* split the arguments into leading variables ls', the non-variable
       part middle and trailing variables rs' *)
    val (ls', rest) = take_prefix is_Free args;
    val (middle, rs') = take_suffix is_Free rest;
    val rpos = length ls';

    (* NOTE(review): the check that middle has at most one element happens
       only in the body below, after the following bindings have analysed
       hd middle; a pattern with several non-variables may therefore be
       reported as e.g. an ill-formed constructor instead. *)
    val (constr, cargs') = if null middle then primrec_error "constructor missing"
      else strip_comb (hd middle);
    val (cname, T) = dest_Const constr
      handle TERM _ => primrec_error "ill-formed constructor";
    (* the datatype is read off the result type of the constructor *)
    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
      primrec_error "cannot determine datatype associated with function"

    val (ls, cargs, rs) =
      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
      handle TERM _ => primrec_error "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    (* Raise if vars is non-empty, listing the offending names after s.
       NOTE(review): primrec_error takes only the message; the trailing eqn
       argument is never used because the raise happens first.  The error
       therefore carries NONE and the handler at the end of process_eqn
       attaches spec instead. *)
    fun check_vars _ [] = ()
      | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
  in
    if length middle > 1 then
      primrec_error "more than one non-variable in pattern"
    else
     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
      (* rhs frees must be pattern variables or fixed names *)
      check_vars "extra variables on rhs: "
        (map dest_Free (OldTerm.term_frees rhs) |> subtract (op =) lfrees
          |> filter_out (is_fixed o fst));
      case AList.lookup (op =) rec_fns fname of
        NONE =>
          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
      | SOME (_, rpos', eqns) =>
          (* later equations for the same function: one per constructor,
             always at the same argument position *)
          if AList.defined (op =) eqns cname then
            primrec_error "constructor already occurred as pattern"
          else if rpos <> rpos' then
            primrec_error "position of recursive argument inconsistent"
          else
            AList.update (op =)
              (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns))
              rec_fns)
  end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;


(* Translate the equations of function fname, which recurses over datatype
   number i of descr, into argument functions for the recursion combinator.
   - descr: datatype description list; entries are (index, (tname, _, constrs))
   - eqns: the accumulated result of process_eqn for all functions
   - fnames: association list datatype index -> function name, for the
     functions processed so far
   - fnss: accumulated entries (index, (fname, ls, argument functions)),
     where ls are the leading variables of fname's first stored equation
   Recursive calls on recursive constructor arguments are replaced by fresh
   variables, and process_fun is invoked for the called function as well,
   so every function reachable from fname gets processed. *)
fun process_fun descr eqns (i, fname) (fnames, fnss) =
  let
    val (_, (tname, _, constrs)) = nth descr i;

    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    (* fs is the (fnames, fnss) accumulator, threaded through the nested
       process_fun calls; subs maps recursive constructor arguments x to
       (datatype index, replacement variable y). *)

    fun subst [] t fs = (t, fs)
      | subst subs (Abs (a, T, t)) fs =
          fs
          |> subst subs t
          |-> (fn t' => pair (Abs (a, T, t')))
      | subst subs (t as (_ $ _)) fs =
          let
            val (f, ts) = strip_comb t;
          in
            if is_Free f
              andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
              (* a call of one of the functions being defined *)
              let
                val (fname', _) = dest_Free f;
                val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
                val (ls, rs) = chop rpos ts
                (* x' is the argument in the recursion position *)
                val (x', rs') = case rs of
                    x' :: rs => (x', rs)
                  | [] => primrec_error ("not enough arguments in recursive application\n"
                      ^ "of function " ^ quote fname' ^ " on rhs");
                val (x, xs) = strip_comb x';
              in case AList.lookup (op =) subs x
               of NONE =>
                    (* not a call on a recursive constructor argument:
                       just process the arguments *)
                    fs
                    |> fold_map (subst subs) ts
                    |-> (fn ts' => pair (list_comb (f, ts')))
                | SOME (i', y) =>
                    (* replace the call by y applied to the remaining
                       arguments, and make sure fname' is processed for
                       datatype i' *)
                    fs
                    |> fold_map (subst subs) (xs @ ls @ rs')
                    ||> process_fun descr eqns (i', fname')
                    |-> (fn ts' => pair (list_comb (y, ts')))
              end
            else
              fs
              |> fold_map (subst subs) (f :: ts)
              |-> (fn (f'::ts') => pair (list_comb (f', ts')))
          end
      | subst _ t fs = (t, fs);

    (* translate rec equations into function arguments suitable for rec comb *)
    (* For constructor cname, whose argument types in descr are cargs, build
       %cargs' subs ls rs. rhs', where subs are fresh variables standing for
       the results of recursive calls on the recursive constructor
       arguments.  Without an equation for cname, a warning is issued and
       HOL.undefined is used.  Errors from subst that lack an equation are
       re-raised with the equation eq attached. *)

    fun trans eqns (cname, cargs) (fnames', fnss', fns) =
      (case AList.lookup (op =) eqns cname of
        NONE => (warning ("No equation for constructor " ^ quote cname ^
          "\nin definition of function " ^ quote fname);
            (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns))
      | SOME (ls, cargs', rs, rhs, eq) =>
          let
            (* pattern variables paired with recursive argument types *)
            val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
            val rargs = map fst recs;
            (* fresh names, renamed apart from rhs, typed dummyT *)
            val subs = map (rpair dummyT o fst)
              (rev (Term.rename_wrt_term rhs rargs));
            val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
              (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
                handle PrimrecError (s, NONE) => primrec_error_eqn s eq
          in (fnames'', fnss'',
              (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
          end)

  in
    (case AList.lookup (op =) fnames i of
      NONE =>
        (* datatype i not handled yet; fname must not already be assigned
           to a different datatype *)
        if exists (fn (_, v) => fname = v) fnames then
          primrec_error ("inconsistent functions for datatype " ^ quote tname)
        else
          let
            (* this eqns shadows the parameter: only fname's equations *)
            val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
            val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
              ((i, fname)::fnames, fnss, [])
          in
            (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
          end
    | SOME fname' =>
        (* datatype i already handled: it must have been by fname *)
        if fname = fname' then (fnames, fnss)
        else primrec_error ("inconsistent functions for datatype " ^ quote tname))
  end;


(* prepare functions needed for definitions *)

(* For one descr entry and its recursion-combinator name, prepend the
   combinator arguments for that datatype to fs.  If no function recurses
   over this datatype, HOL.undefined placeholders are used (one dummyT
   argument per constructor argument plus one per recursive argument, result
   type unit) and a warning is issued.  Otherwise a definition request
   (fname, ls, rec_name, tname) is added to defs. *)
fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
  case AList.lookup (op =) fns i of
     NONE =>
       let
         val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
           replicate ((length cargs) + (length (List.filter is_rec_type cargs)))
             dummyT ---> HOLogic.unitT)) constrs;
         val _ = warning ("No function definition for datatype " ^ quote tname)
       in
         (dummy_fns @ fs, defs)
       end
   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);


(* make definition *)

(* Build the definition  fname == %ls x. rec_name fs x ls , constrained to
   the type given for fname in fixes and type-checked in ctxt.  Returns the
   pair ((binding, mixfix), ((definition binding, no attributes), rhs))
   that is passed to LocalTheory.define.  The val SOME binding is partial:
   it raises Bind if fname is not among fixes.  tname is unused here. *)
fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
  let
    val SOME (var, varT) = get_first (fn ((b, T), mx) =>
      if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
    val def_name = Thm.def_name (Long_Name.base_name fname);
    (* abstractions over ls (outermost first), then over the datatype
       argument; inside, Bound 0 is that argument and Bound (length ls)
       down to Bound 1 are ls in their original order *)
    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
      (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
    val rhs = singleton (Syntax.check_terms ctxt) (TypeInfer.constrain varT raw_rhs);
  in (var, ((Binding.name def_name, []), rhs)) end;


(* find datatypes which contain all datatypes in tnames' *)

(* Return the pairs (tname, dt) for those tname in the last argument whose
   datatype group (the type names in #descr dt) includes every name in
   tnames'.  Raises if some tname is not a registered datatype. *)
fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
  | find_dts dt_info tnames' (tname::tnames) =
      (case Symtab.lookup dt_info tname of
          NONE => primrec_error (quote tname ^ " is not a datatype")
        | SOME dt =>
            if tnames' subset (map (#1 o snd) (#descr dt)) then
              (tname, dt)::(find_dts dt_info tnames' tnames)
            else find_dts dt_info tnames' tnames);


(* primrec definition *)

local

(* Prove every user equation (a, t) by rewriting with the recursion
   combinator equations (as meta-equations) and the new definitional
   theorems (snd o snd of each LocalTheory.define result), then closing the
   goal by reflexivity.  Returns (a, [thm]) pairs for LocalTheory.note. *)
fun prove_spec ctxt names rec_rewrites defs eqs =
  let
    val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
    fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
    val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
  in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end;

(* Common implementation of add_primrec and add_primrec_cmd.
   - set_group: if true, the definitions are made inside a fresh
     LocalTheory group named by serial_string ()
   - prep_spec: Specification.check_spec or Specification.read_spec
   Steps: preprocess the equations, find the datatype group covering all
   recursion types, translate the equations, define each function via its
   recursion combinator, prove the user's equations, note each one under
   its own (qualified) name with Code.add_default_eqn_attrib, and note all
   of them as simps with the simp and Nitpick_Const_Simp_Thms attributes.
   Returns the proven equations and the updated local theory.  A
   PrimrecError is reported via error, including the offending equation
   when one is known. *)
fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
  let
    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
    (* a name counts as fixed if the context fixes it or it is one of the
       functions being defined *)
    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec [];
    (* the datatypes the functions recurse over *)
    val tnames = distinct (op =) (map (#1 o snd) eqns);
    val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames;
    (* for each candidate datatype: its index and the first function
       (in eqns order) that recurses over it *)
    val main_fns = map (fn (tname, {index, ...}) =>
      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
    (* the datatype group information is taken from the first candidate *)
    val {descr, rec_names, rec_rewrites, ...} =
      if null dts then primrec_error
        ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
      else snd (hd dts);
    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    (* every function with equations must have been reached by process_fun *)
    val names1 = map snd fnames;
    val names2 = map fst eqns;
    val _ = if gen_eq_set (op =) (names1, names2) then ()
      else primrec_error ("functions " ^ commas_quote names2 ^
        "\nare not mutually recursive");
    (* theorem names are qualified by the defined function names joined with _ *)
    val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs);
    val qualify = Binding.qualify false prefix;
    val spec' = (map o apfst)
      (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
    val simp_atts = map (Attrib.internal o K)
      [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add];
  in
    lthy
    |> set_group ? LocalTheory.set_group (serial_string ())
    (* define every function via its recursion combinator *)
    |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
    (* prove the user equations in the context after the definitions *)
    |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec'))
    (* note each equation under its own name *)
    |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
    (* note all equations together as simps *)
    |-> (fn simps' => LocalTheory.note Thm.theoremK
          ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps'))
    (* keep only the theorem list of the simps fact *)
    |>> snd
  end handle PrimrecError (msg, some_eqn) =>
    error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
      of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
       | NONE => ""));

in

val add_primrec = gen_primrec false Specification.check_spec;
val add_primrec_cmd = gen_primrec true Specification.read_spec;

end;

(* Run add_primrec in a fresh global theory target and export the resulting
   theorems from the target context back to the initial one. *)
fun add_primrec_global fixes specs thy =
  let
    val lthy = TheoryTarget.init NONE thy;
    val (simps, lthy') = add_primrec fixes specs lthy;
    val simps' = ProofContext.export lthy' lthy simps;
  in (simps', LocalTheory.exit_global lthy') end;

(* As add_primrec_global, but inside an overloading target for ops. *)
fun add_primrec_overloaded ops fixes specs thy =
  let
    val lthy = TheoryTarget.overloading ops thy;
    val (simps, lthy') = add_primrec fixes specs lthy;
    val simps' = ProofContext.export lthy' lthy simps;
  in (simps', LocalTheory.exit_global lthy') end;


(* outer syntax *)

local structure P = OuterParse and K = OuterKeyword in

(* Old-style syntax: optional prefix of the form  ( unchecked [name] )  or
   ( name ) ; yields (unchecked flag, alternative name), by default
   (false, empty name). *)
val opt_unchecked_name =
  Scan.optional (P.$$$ "(" |-- P.!!!
    (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" ||
      P.name >> pair false) --| P.$$$ ")")) (false, "");

(* Old-style syntax: the optional prefix followed by one or more
   propositions, each with an optional theorem name. *)
val old_primrec_decl =
  opt_unchecked_name --
    Scan.repeat1 ((SpecParse.opt_thm_name ":" >> apfst Binding.name_of) -- P.prop);

(* New-style syntax: optional target, fixes, and the where-specifications. *)
val primrec_decl = P.opt_target -- P.fixes -- SpecParse.where_alt_specs;

(* The primrec command: the new-style syntax is handled by add_primrec_cmd,
   the old-style syntax by OldPrimrecPackage (unchecked variant when
   requested). *)
val _ =
  OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
    ((primrec_decl >> (fn ((opt_target, fixes), specs) =>
      Toplevel.local_theory opt_target (add_primrec_cmd fixes specs #> snd)))
    || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
      Toplevel.theory (snd o
        (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec)
          alt_name (map P.triple_swap eqns)))));

end;

end;