Skip to content

Commit

Permalink
Merge pull request #859 from o1-labs/feature/generalized-constraint
Browse files Browse the repository at this point in the history
Abstract over the specific `Constraint.t` type used by the backing constraint system
  • Loading branch information
dannywillems authored Jan 7, 2025
2 parents b2442f4 + e09e89d commit 36c1add
Show file tree
Hide file tree
Showing 16 changed files with 339 additions and 345 deletions.
48 changes: 19 additions & 29 deletions src/base/backend_extended.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,30 +59,31 @@ module type S = sig
val to_constant : t -> Field.t option
end

module R1CS_constraint_system : Constraint_system.S with module Field := Field

module Constraint : sig
type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]

type 'k with_constraint_args = ?label:string -> 'k
type t [@@deriving sexp]

val boolean : (Cvar.t -> t) with_constraint_args
val boolean : Cvar.t -> t

val equal : (Cvar.t -> Cvar.t -> t) with_constraint_args
val equal : Cvar.t -> Cvar.t -> t

val r1cs : (Cvar.t -> Cvar.t -> Cvar.t -> t) with_constraint_args
val r1cs : Cvar.t -> Cvar.t -> Cvar.t -> t

val square : (Cvar.t -> Cvar.t -> t) with_constraint_args
val square : Cvar.t -> Cvar.t -> t

val annotation : t -> string
val eval : t -> (Cvar.t -> Field.t) -> bool

val eval :
(Cvar.t, Field.t) Constraint.basic_with_annotation
-> (Cvar.t -> Field.t)
-> bool
val log_constraint : t -> (Cvar.t -> Field.t) -> string
end

module Run_state : Run_state_intf.S
module R1CS_constraint_system :
Constraint_system.S
with module Field := Field
with type constraint_ = Constraint.t

module Run_state :
Run_state_intf.S
with type field := Field.t
and type constraint_ := Constraint.t
end

module Make (Backend : Backend_intf.S) :
Expand All @@ -91,7 +92,8 @@ module Make (Backend : Backend_intf.S) :
and type Field.Vector.t = Backend.Field.Vector.t
and type Bigint.t = Backend.Bigint.t
and type R1CS_constraint_system.t = Backend.R1CS_constraint_system.t
and type 'field Run_state.t = 'field Backend.Run_state.t = struct
and type Run_state.t = Backend.Run_state.t
and type Constraint.t = Backend.Constraint.t = struct
open Backend

module Bigint = struct
Expand Down Expand Up @@ -207,19 +209,7 @@ module Make (Backend : Backend_intf.S) :
None
end

module Constraint = struct
open Constraint
include Constraint.T

type 'k with_constraint_args = ?label:string -> 'k

type t = (Cvar.t, Field.t) Constraint.t [@@deriving sexp]

let m = (module Field : Snarky_intf.Field.S with type t = Field.t)

let eval { basic; _ } get_value = Constraint.Basic.eval m get_value basic
end

module Constraint = Constraint
module R1CS_constraint_system = R1CS_constraint_system
module Run_state = Run_state
end
26 changes: 24 additions & 2 deletions src/base/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,29 @@ module type S = sig

val field_size : Bigint.t

module R1CS_constraint_system : Constraint_system.S with module Field := Field
module Constraint : sig
type t [@@deriving sexp]

module Run_state : Run_state_intf.S
val boolean : Field.t Cvar.t -> t

val equal : Field.t Cvar.t -> Field.t Cvar.t -> t

val r1cs : Field.t Cvar.t -> Field.t Cvar.t -> Field.t Cvar.t -> t

val square : Field.t Cvar.t -> Field.t Cvar.t -> t

val eval : t -> (Field.t Cvar.t -> Field.t) -> bool

val log_constraint : t -> (Field.t Cvar.t -> Field.t) -> string
end

module R1CS_constraint_system :
Constraint_system.S
with module Field := Field
with type constraint_ = Constraint.t

module Run_state :
Run_state_intf.S
with type field := Field.t
and type constraint_ := Constraint.t
end
47 changes: 25 additions & 22 deletions src/base/checked.ml
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
open Core_kernel

module Make (Field : sig
type t [@@deriving sexp]

val equal : t -> t -> bool
end)
(Types : Types.Types)
(Basic : Checked_intf.Basic with type field = Field.t with module Types := Types)
(As_prover : As_prover_intf.Basic
with type field := Basic.field
with module Types := Types) :
module Make
(Backend : Backend_extended.S)
(Types : Types.Types)
(Basic : Checked_intf.Basic
with type field = Backend.Field.t
and type constraint_ = Backend.Constraint.t
with module Types := Types)
(As_prover : As_prover_intf.Basic
with type field := Basic.field
with module Types := Types) :
Checked_intf.S
with module Types := Types
with type field = Field.t
and type run_state = Basic.run_state = struct
with type field = Backend.Field.t
and type run_state = Basic.run_state
and type constraint_ = Basic.constraint_ = struct
include Basic

let request_witness (typ : ('var, 'value) Types.Typ.t)
Expand Down Expand Up @@ -69,23 +70,25 @@ end)
in
handle t (fun request -> (Option.value_exn !handler) request)

let assert_ ?label c = add_constraint (Constraint.override_label c label)
let assert_ c = add_constraint c

let assert_r1cs ?label a b c = assert_ (Constraint.r1cs ?label a b c)
let assert_r1cs a b c = assert_ (Backend.Constraint.r1cs a b c)

let assert_square ?label a c = assert_ (Constraint.square ?label a c)
let assert_square a c = assert_ (Backend.Constraint.square a c)

let assert_all ?label cs =
let assert_all cs =
List.fold_right cs ~init:(return ()) ~f:(fun c (acc : _ t) ->
bind acc ~f:(fun () ->
add_constraint (Constraint.override_label c label) ) )
bind acc ~f:(fun () -> add_constraint c) )

let assert_equal ?label x y =
let assert_equal x y =
match (x, y) with
| Cvar.Constant x, Cvar.Constant y ->
if Field.equal x y then return ()
if Backend.Field.equal x y then return ()
else
failwithf !"assert_equal: %{sexp: Field.t} != %{sexp: Field.t}" x y ()
failwithf
!"assert_equal: %{sexp: Backend.Field.t} != %{sexp: \
Backend.Field.t}"
x y ()
| _ ->
assert_ (Constraint.equal ?label x y)
assert_ (Backend.Constraint.equal x y)
end
25 changes: 12 additions & 13 deletions src/base/checked_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module type Basic = sig

type field

type constraint_

type 'a t = 'a Types.Checked.t

type run_state

include Monad_let.S with type 'a t := 'a t

val add_constraint : (field Cvar.t, field) Constraint.t -> unit t
val add_constraint : constraint_ -> unit t

val as_prover : unit Types.As_prover.t -> unit t

Expand All @@ -29,7 +31,7 @@ module type Basic = sig
val direct : (run_state -> run_state * 'a) -> 'a t

val constraint_count :
?weight:((field Cvar.t, field) Constraint.t -> int)
?weight:(constraint_ -> int)
-> ?log:(?start:bool -> string -> int -> unit)
-> (unit -> 'a t)
-> int
Expand All @@ -40,6 +42,8 @@ module type S = sig

type field

type constraint_

type run_state

type 'a t = 'a Types.Checked.t
Expand Down Expand Up @@ -89,25 +93,20 @@ module type S = sig

val with_label : string -> (unit -> 'a t) -> 'a t

val assert_ :
?label:Base.string -> (field Cvar.t, field) Constraint.t -> unit t
val assert_ : constraint_ -> unit t

val assert_r1cs :
?label:Base.string -> field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t
val assert_r1cs : field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t

val assert_square :
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
val assert_square : field Cvar.t -> field Cvar.t -> unit t

val assert_all :
?label:Base.string -> (field Cvar.t, field) Constraint.t list -> unit t
val assert_all : constraint_ list -> unit t

val assert_equal :
?label:Base.string -> field Cvar.t -> field Cvar.t -> unit t
val assert_equal : field Cvar.t -> field Cvar.t -> unit t

val direct : (run_state -> run_state * 'a) -> 'a t

val constraint_count :
?weight:((field Cvar.t, field) Constraint.t -> int)
?weight:(constraint_ -> int)
-> ?log:(?start:bool -> string -> int -> unit)
-> (unit -> 'a t)
-> int
Expand Down
64 changes: 16 additions & 48 deletions src/base/checked_runner.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
open Core_kernel
module Constraint0 = Constraint

let stack_to_string = String.concat ~sep:"\n"

Expand All @@ -10,9 +9,7 @@ let eval_constraints_ref = eval_constraints
module T (Backend : Backend_extended.S) = struct
type 'a t =
| Pure of 'a
| Function of
( Backend.Field.t Backend.Run_state.t
-> Backend.Field.t Backend.Run_state.t * 'a )
| Function of (Backend.Run_state.t -> Backend.Run_state.t * 'a)
end

module Simple_types (Backend : Backend_extended.S) = Types.Make_types (struct
Expand Down Expand Up @@ -40,15 +37,15 @@ module Make_checked
with type field := Backend.Field.t
with module Types := Types) =
struct
type run_state = Backend.Field.t Backend.Run_state.t
type run_state = Backend.Run_state.t

type constraint_ = Backend.Constraint.t

type field = Backend.Field.t

type 'a t = 'a T(Backend).t =
| Pure of 'a
| Function of
( Backend.Field.t Backend.Run_state.t
-> Backend.Field.t Backend.Run_state.t * 'a )
| Function of (Backend.Run_state.t -> Backend.Run_state.t * 'a)

let eval (t : 'a t) : run_state -> run_state * 'a =
match t with Pure a -> fun s -> (s, a) | Function g -> g
Expand Down Expand Up @@ -83,7 +80,7 @@ struct

open Backend

let get_value (t : Field.t Run_state.t) : Cvar.t -> Field.t =
let get_value (t : Run_state.t) : Cvar.t -> Field.t =
let get_one i = Run_state.get_variable_value t i in
Cvar.eval (`Return_values_will_be_mutated get_one)

Expand Down Expand Up @@ -143,36 +140,10 @@ struct
f ~at_label_boundary:(`End, lab) None ) ;
(Run_state.set_stack s' stack, y) )

let log_constraint ({ basic; _ } : Constraint.t) s =
let open Constraint0 in
match basic with
| Boolean var ->
Format.(asprintf "Boolean %s" (Field.to_string (get_value s var)))
| Equal (var1, var2) ->
Format.(
asprintf "Equal %s %s"
(Field.to_string (get_value s var1))
(Field.to_string (get_value s var2)))
| Square (var1, var2) ->
Format.(
asprintf "Square %s %s"
(Field.to_string (get_value s var1))
(Field.to_string (get_value s var2)))
| R1CS (var1, var2, var3) ->
Format.(
asprintf "R1CS %s %s %s"
(Field.to_string (get_value s var1))
(Field.to_string (get_value s var2))
(Field.to_string (get_value s var3)))
| _ ->
Format.asprintf
!"%{sexp:(Field.t, Field.t) Constraint0.basic}"
(Constraint0.Basic.map basic ~f:(get_value s))

let add_constraint ~stack ({ basic; annotation } : Constraint.t)
(Constraint_system.T ((module C), system) : Field.t Constraint_system.t) =
let label = Option.value annotation ~default:"<unknown>" in
C.add_constraint system basic ~label:(stack_to_string (label :: stack))
let add_constraint (basic : Constraint.t)
(Constraint_system.T ((module C), system) :
(Field.t, Constraint.t) Constraint_system.t ) =
C.add_constraint system basic

let add_constraint c : _ t =
Function
Expand All @@ -189,19 +160,18 @@ struct
then
failwithf
"Constraint unsatisfied (unreduced):\n\
%s\n\
%s\n\n\
Constraint:\n\
%s\n\
Data:\n\
%s"
(Constraint.annotation c)
(stack_to_string (Run_state.stack s))
(Sexp.to_string (Constraint.sexp_of_t c))
(log_constraint c s) () ;
(Backend.Constraint.log_constraint c (get_value s))
() ;
if not (Run_state.as_prover s) then
Option.iter (Run_state.system s) ~f:(fun system ->
add_constraint ~stack:(Run_state.stack s) c system ) ;
add_constraint c system ) ;
(s, ()) ) )

let with_handler h t : _ t =
Expand Down Expand Up @@ -422,17 +392,15 @@ module type S = sig
module State : sig
val make :
num_inputs:int
-> input:field Run_state.Vector.t
-> input:field Run_state_intf.Vector.t
-> next_auxiliary:int ref
-> aux:field Run_state.Vector.t
-> aux:field Run_state_intf.Vector.t
-> ?system:r1cs
-> ?eval_constraints:bool
-> ?handler:Request.Handler.t
-> with_witness:bool
-> ?log_constraint:
( ?at_label_boundary:[ `End | `Start ] * string
-> (field Cvar.t, field) Constraint.t option
-> unit )
(?at_label_boundary:[ `End | `Start ] * string -> constr -> unit)
-> unit
-> run_state
end
Expand Down
Loading

0 comments on commit 36c1add

Please sign in to comment.