Skip to content

Commit

Permalink
Push side effects to rhs of let bindings.
Browse files Browse the repository at this point in the history
  • Loading branch information
maximebuyse committed Jan 9, 2025
1 parent 1258db5 commit 17ed29f
Show file tree
Hide file tree
Showing 19 changed files with 441 additions and 223 deletions.
44 changes: 43 additions & 1 deletion engine/lib/side_effect_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,45 @@ struct
collect_and_hoist_effects_object#visit_expr CollectContext.empty e
in
(lets_of_bindings lbs e, effects)

(* This visitor binds in `let_ = e in ()` all expressions
of type unit that are not already in a let binding.
This ensures that all side effects happen in the rhs of a let binding. *)
let bind_unit_return_position =
object (self)
inherit [_] Visitors.map as super

method! visit_expr in_let e =
match e.e with
| Let { monadic; rhs; lhs; body } ->
{
e with
e =
Let
{
monadic;
rhs = self#visit_expr true rhs;
lhs = self#visit_pat false lhs;
body = self#visit_expr false body;
};
}
| _ ->
let span = e.span in
if [%eq: expr'] e.e (U.unit_expr span).e then e
else if [%eq: ty] e.typ U.unit_typ && not in_let then
{
e with
e =
Let
{
monadic = None;
rhs = self#visit_expr true e;
lhs = U.M.pat_PWild ~span ~typ:e.typ;
body = U.unit_expr span;
};
}
else super#visit_expr false e
end
end
end

Expand Down Expand Up @@ -538,7 +577,10 @@ struct
open ID

let dexpr (expr : A.expr) : B.expr =
Hoist.collect_and_hoist_effects expr |> fst |> dexpr
Hoist.collect_and_hoist_effects expr
|> fst
|> Hoist.bind_unit_return_position#visit_expr false
|> dexpr

[%%inline_defs "Item.*"]

Expand Down
6 changes: 4 additions & 2 deletions test-harness/src/snapshots/toolchain__assert into-coq.snap
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ Definition asserts (_ : unit) : unit :=
let _ := assert (t_PartialEq_f_eq (1) (1)) in
let _ := match (2,2) with
| (left_val,right_val) =>
assert (t_PartialEq_f_eq (left_val) (right_val))
let _ := assert (t_PartialEq_f_eq (left_val) (right_val)) in
tt
end in
let _ := match (1,2) with
| (left_val,right_val) =>
assert (negb (t_PartialEq_f_eq (left_val) (right_val)))
let _ := assert (negb (t_PartialEq_f_eq (left_val) (right_val))) in
tt
end in
tt.
'''
8 changes: 6 additions & 2 deletions test-harness/src/snapshots/toolchain__assert into-fstar.snap
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ let asserts (_: Prims.unit) : Prims.unit =
let _:Prims.unit = Hax_lib.v_assert (1l =. 1l <: bool) in
let _:Prims.unit =
match 2l, 2l <: (i32 & i32) with
| left_val, right_val -> Hax_lib.v_assert (left_val =. right_val <: bool)
| left_val, right_val ->
let _:Prims.unit = Hax_lib.v_assert (left_val =. right_val <: bool) in
()
in
let _:Prims.unit =
match 1l, 2l <: (i32 & i32) with
| left_val, right_val -> Hax_lib.v_assert (~.(left_val =. right_val <: bool) <: bool)
| left_val, right_val ->
let _:Prims.unit = Hax_lib.v_assert (~.(left_val =. right_val <: bool) <: bool) in
()
in
()
'''
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ Equations asserts {L1 : {fset Location}} {I1 : Interface} (_ : both L1 I1 'unit)
letb _ := assert ((ret_both (1 : int32)) =.? (ret_both (1 : int32))) in
letb _ := matchb prod_b (ret_both (2 : int32),ret_both (2 : int32)) with
| '(left_val,right_val) =>
solve_lift (assert (left_val =.? right_val))
letb _ := assert (left_val =.? right_val) in
solve_lift (ret_both (tt : 'unit))
end in
letb _ := matchb prod_b (ret_both (1 : int32),ret_both (2 : int32)) with
| '(left_val,right_val) =>
solve_lift (assert (not (left_val =.? right_val)))
letb _ := assert (not (left_val =.? right_val)) in
solve_lift (ret_both (tt : 'unit))
end in
solve_lift (ret_both (tt : 'unit)) : both L1 I1 'unit.
Fail Next Obligation.
Expand Down
19 changes: 15 additions & 4 deletions test-harness/src/snapshots/toolchain__attributes into-fstar.snap
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,16 @@ let impl: t_Foo Prims.unit =
f_h_pre = (fun (x: u8) (y: u8) -> true);
f_h_post
=
(fun (x: u8) (y: u8) (output_variable: Prims.unit) -> output_variable =. (() <: Prims.unit));
f_h = (fun (x: u8) (y: u8) -> () <: Prims.unit);
(fun (x: u8) (y: u8) (output_variable: Prims.unit) ->
(let _:Prims.unit = output_variable in
()) =.
(let _:Prims.unit = () <: Prims.unit in
()));
f_h
=
(fun (x: u8) (y: u8) ->
let _:Prims.unit = () <: Prims.unit in
());
f_i_pre = (fun (x: u8) (y: u8) -> true);
f_i_post = (fun (x: u8) (y: u8) (y_future: u8) -> y_future =. y);
f_i
Expand Down Expand Up @@ -452,11 +460,14 @@ let another_panicfree_function (_: Prims.unit)
let not_much:i32 = 0l in
let nothing:i32 = 0l in
let still_not_much:i32 = not_much +! nothing in
admit () (* Panic freedom *)
let _:Prims.unit = admit () (* Panic freedom *) in
()

#push-options "--admit_smt_queries true"

let a_function_which_only_laxes (_: Prims.unit) : Prims.unit = Hax_lib.v_assert false
let a_function_which_only_laxes (_: Prims.unit) : Prims.unit =
let _:Prims.unit = Hax_lib.v_assert false in
()

#pop-options
'''
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,18 @@ open FStar.Mul

let f (_: Prims.unit) : Prims.unit = ()

let g (_: Prims.unit) : Prims.unit = f ()
let g (_: Prims.unit) : Prims.unit =
let _:Prims.unit = f () in
()

let h (_: Prims.unit) : Prims.unit =
let _:Prims.unit = g () in
Cyclic_modules.C.i ()
let _:Prims.unit = Cyclic_modules.C.i () in
()

let h2 (_: Prims.unit) : Prims.unit = Cyclic_modules.C.i ()
let h2 (_: Prims.unit) : Prims.unit =
let _:Prims.unit = Cyclic_modules.C.i () in
()
'''
"Cyclic_modules.D.Cyclic_bundle_81544935.fst" = '''
module Cyclic_modules.D.Cyclic_bundle_81544935
Expand All @@ -66,11 +71,17 @@ open FStar.Mul

let d1 (_: Prims.unit) : Prims.unit = ()

let e1 (_: Prims.unit) : Prims.unit = d1 ()
let e1 (_: Prims.unit) : Prims.unit =
let _:Prims.unit = d1 () in
()

let de1 (_: Prims.unit) : Prims.unit = e1 ()
let de1 (_: Prims.unit) : Prims.unit =
let _:Prims.unit = e1 () in
()

let d2 (_: Prims.unit) : Prims.unit = de1 ()
let d2 (_: Prims.unit) : Prims.unit =
let _:Prims.unit = de1 () in
()
'''
"Cyclic_modules.D.fst" = '''
module Cyclic_modules.D
Expand Down Expand Up @@ -100,9 +111,13 @@ let g (_: Prims.unit) : Prims.unit = ()

let h (_: Prims.unit) : Prims.unit = ()

let f (_: Prims.unit) : Prims.unit = h ()
let f (_: Prims.unit) : Prims.unit =
let _:Prims.unit = h () in
()

let i (_: Prims.unit) : Prims.unit = g ()
let i (_: Prims.unit) : Prims.unit =
let _:Prims.unit = g () in
()
'''
"Cyclic_modules.Disjoint_cycle_a.fst" = '''
module Cyclic_modules.Disjoint_cycle_a
Expand Down Expand Up @@ -202,10 +217,13 @@ module Cyclic_modules.Late_skip_a.Cyclic_bundle_658016071
open Core
open FStar.Mul

let rec ff_749016415 (_: Prims.unit) : Prims.unit = ff_377825240 ()
let rec ff_749016415 (_: Prims.unit) : Prims.unit =
let _:Prims.unit = ff_377825240 () in
()

and ff_377825240 (_: Prims.unit) : Prims.Pure Prims.unit (requires true) (fun _ -> Prims.l_True) =
ff_749016415 ()
let _:Prims.unit = ff_749016415 () in
()
'''
"Cyclic_modules.Late_skip_a.fst" = '''
module Cyclic_modules.Late_skip_a
Expand All @@ -231,13 +249,16 @@ open FStar.Mul

let c (_: Prims.unit) : Prims.unit = ()

let a (_: Prims.unit) : Prims.unit = c ()
let a (_: Prims.unit) : Prims.unit =
let _:Prims.unit = c () in
()

let d (_: Prims.unit) : Prims.unit = ()

let b (_: Prims.unit) : Prims.unit =
let _:Prims.unit = a () in
d ()
let _:Prims.unit = d () in
()
'''
"Cyclic_modules.M1.fst" = '''
module Cyclic_modules.M1
Expand Down
15 changes: 12 additions & 3 deletions test-harness/src/snapshots/toolchain__generics into-fstar.snap
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ let impl__Test__set_alpn_protocols
(self: t_Test)
(v__protocols: impl_995885649_)
: Core.Result.t_Result Prims.unit Prims.unit =
Core.Result.Result_Ok (() <: Prims.unit) <: Core.Result.t_Result Prims.unit Prims.unit
Core.Result.Result_Ok
(let _:Prims.unit = () <: Prims.unit in
())
<:
Core.Result.t_Result Prims.unit Prims.unit

let impl__Test__set_ciphersuites
(#v_S #impl_995885649_: Type0)
Expand All @@ -65,7 +69,11 @@ let impl__Test__set_ciphersuites
(self: t_Test)
(ciphers: impl_995885649_)
: Core.Result.t_Result Prims.unit Prims.unit =
Core.Result.Result_Ok (() <: Prims.unit) <: Core.Result.t_Result Prims.unit Prims.unit
Core.Result.Result_Ok
(let _:Prims.unit = () <: Prims.unit in
())
<:
Core.Result.t_Result Prims.unit Prims.unit
'''
"Generics.fst" = '''
module Generics
Expand Down Expand Up @@ -153,7 +161,8 @@ let foo (v_LEN: usize) (arr: t_Array usize v_LEN) : usize =
(fun acc i ->
let acc:usize = acc in
let i:usize = i in
acc +! (arr.[ i ] <: usize) <: usize)
let acc:usize = acc +! (arr.[ i ] <: usize) in
acc)
in
acc

Expand Down
9 changes: 8 additions & 1 deletion test-harness/src/snapshots/toolchain__literals into-coq.snap
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,26 @@ Definition numeric (_ : unit) : unit :=
Definition patterns (_ : unit) : unit :=
let _ := match 1 with
| 2 =>
let _ := tt in
tt
| _ =>
let _ := tt in
tt
end in
let _ := match ("hello"%string,(123,["a"%string; "b"%string])) with
| ("hello"%string,(123,v__todo)) =>
let _ := tt in
tt
| _ =>
let _ := tt in
tt
end in
let _ := match Build_t_Foo (4) with
| Foo (3) =>
let _ := tt in
tt
| _ =>
let _ := tt in
tt
end in
tt.
Expand All @@ -139,7 +145,8 @@ Definition patterns (_ : unit) : unit :=


Definition panic_with_msg (_ : unit) : unit :=
never_to_any (panic_fmt (impl_2__new_const (["with msg"%string]))).
let _ := never_to_any (panic_fmt (impl_2__new_const (["with msg"%string]))) in
tt.

Definition empty_array (_ : unit) : unit :=
let _ : t_Slice t_u8 := unsize ([]) in
Expand Down
43 changes: 29 additions & 14 deletions test-harness/src/snapshots/toolchain__literals into-fstar.snap
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ let numeric (_: Prims.unit) : Prims.unit =
let patterns (_: Prims.unit) : Prims.unit =
let _:Prims.unit =
match 1uy <: u8 with
| 2uy -> () <: Prims.unit
| _ -> () <: Prims.unit
| 2uy ->
let _:Prims.unit = () <: Prims.unit in
()
| _ ->
let _:Prims.unit = () <: Prims.unit in
()
in
let _:Prims.unit =
match
Expand All @@ -169,13 +173,21 @@ let patterns (_: Prims.unit) : Prims.unit =
<:
(string & (i32 & t_Array string (sz 2)))
with
| "hello", (123l, v__todo) -> () <: Prims.unit
| _ -> () <: Prims.unit
| "hello", (123l, v__todo) ->
let _:Prims.unit = () <: Prims.unit in
()
| _ ->
let _:Prims.unit = () <: Prims.unit in
()
in
let _:Prims.unit =
match { f_field = 4uy } <: t_Foo with
| { f_field = 3uy } -> () <: Prims.unit
| _ -> () <: Prims.unit
| { f_field = 3uy } ->
let _:Prims.unit = () <: Prims.unit in
()
| _ ->
let _:Prims.unit = () <: Prims.unit in
()
in
()

Expand All @@ -198,14 +210,17 @@ val impl_2': Core.Cmp.t_Eq t_Foo
let impl_2 = impl_2'
let panic_with_msg (_: Prims.unit) : Prims.unit =
Rust_primitives.Hax.never_to_any (Core.Panicking.panic_fmt (Core.Fmt.impl_2__new_const (sz 1)
(let list = ["with msg"] in
FStar.Pervasives.assert_norm (Prims.eq2 (List.Tot.length list) 1);
Rust_primitives.Hax.array_of_list 1 list)
<:
Core.Fmt.t_Arguments)
<:
Rust_primitives.Hax.t_Never)
let _:Prims.unit =
Rust_primitives.Hax.never_to_any (Core.Panicking.panic_fmt (Core.Fmt.impl_2__new_const (sz 1)
(let list = ["with msg"] in
FStar.Pervasives.assert_norm (Prims.eq2 (List.Tot.length list) 1);
Rust_primitives.Hax.array_of_list 1 list)
<:
Core.Fmt.t_Arguments)
<:
Rust_primitives.Hax.t_Never)
in
()

let empty_array (_: Prims.unit) : Prims.unit =
let _:t_Slice u8 =
Expand Down
Loading

0 comments on commit 17ed29f

Please sign in to comment.