diff --git a/src/base/snark0.ml b/src/base/snark0.ml index 473b7117a..11ffde363 100644 --- a/src/base/snark0.ml +++ b/src/base/snark0.ml @@ -1428,6 +1428,27 @@ module Run = struct in Staged.stage finish_computation + let request_manual (req : unit -> 'a Request.t) () : 'a = + Request.Handler.run (Run_state.handler !state) (req ()) + |> Option.value_exn ~message:"Unhandled request" + + module Async_generic (Promise : Base.Monad.S) = struct + let run_prover ~(else_ : unit -> 'a) (f : unit -> 'a Promise.t) : + 'a Promise.t = + if Run_state.has_witness !state then ( + let old = Run_state.as_prover !state in + Run_state.set_as_prover !state true ; + let%map.Promise result = f () in + Run_state.set_as_prover !state old ; + result ) + else Promise.return (else_ ()) + + let as_prover (f : unit -> unit Promise.t) : unit Promise.t = + run_prover ~else_:(fun () -> ()) f + + let unit_request req = as_prover (request_manual req) + end + let run_unchecked x = finalize_is_running (fun () -> Perform.run_unchecked ~run:as_stateful (fun () -> mark_active ~f:x) ) diff --git a/src/base/snark_intf.ml b/src/base/snark_intf.ml index a6416ca31..465daf208 100644 --- a/src/base/snark_intf.ml +++ b/src/base/snark_intf.ml @@ -1405,6 +1405,12 @@ module type Run_basic = sig (* Callback, low-level version of [as_prover] and [exists]. *) val as_prover_manual : int -> (field array option -> Field.t array) Staged.t + module Async_generic (Promise : Base.Monad.S) : sig + val as_prover : (unit -> unit Promise.t) -> unit Promise.t + + val unit_request : (unit -> unit Promise.t Request.t) -> unit Promise.t + end + (** Generate the public input vector for a given statement. *) val generate_public_input : ('input_var, 'input_value) Typ.t -> 'input_value -> Field.Constant.Vector.t