diff --git a/source/rust_verify/src/rust_to_vir.rs b/source/rust_verify/src/rust_to_vir.rs index bbe57bcf8b..140d35ca91 100644 --- a/source/rust_verify/src/rust_to_vir.rs +++ b/source/rust_verify/src/rust_to_vir.rs @@ -280,11 +280,24 @@ fn check_item<'tcx>( false, )?); } + let types = Arc::new(types); let path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, path.res.def_id()); - let trait_impl = - vir::ast::TraitImplX { impl_path: impl_path.clone(), trait_path: path.clone() }; + let (typ_params, typ_bounds) = crate::rust_to_vir_base::check_generics_bounds_fun( + ctxt.tcx, + &ctxt.verus_items, + impll.generics, + impl_def_id, + Some(&mut *ctxt.diagnostics.borrow_mut()), + )?; + let trait_impl = vir::ast::TraitImplX { + impl_path: impl_path.clone(), + typ_params, + typ_bounds, + trait_path: path.clone(), + trait_typ_args: types.clone(), + }; vir.trait_impls.push(ctxt.spanned_new(item.span, trait_impl)); - Some((path, Arc::new(types))) + Some((path, types)) } else { None }; @@ -553,6 +566,7 @@ fn check_item<'tcx>( vir.functions.append(&mut methods); let traitx = vir::ast::TraitX { name: trait_path, + visibility: visibility(), methods: Arc::new(method_names), assoc_typs: Arc::new(assoc_typs), typ_params: generics_params, diff --git a/source/rust_verify/src/verifier.rs b/source/rust_verify/src/verifier.rs index 092a937967..8e421696e6 100644 --- a/source/rust_verify/src/verifier.rs +++ b/source/rust_verify/src/verifier.rs @@ -881,6 +881,15 @@ impl Verifier { &("Datatypes".to_string()), ); + let trait_commands = vir::traits::traits_to_air(ctx, &krate); + self.run_commands( + module, + reporter, + &mut air_context, + &trait_commands, + &("Traits".to_string()), + ); + let assoc_type_impl_commands = vir::assoc_types_to_air::assoc_type_impls_to_air(ctx, &krate.assoc_type_impls); self.run_commands( @@ -1311,7 +1320,7 @@ impl Verifier { reporter.report_now(¬e_bare(format!("verifying {module_msg}{functions_msg}"))); } - let (pruned_krate, mono_abstract_datatypes, lambda_types) = + let (pruned_krate, mono_abstract_datatypes, lambda_types, bound_traits) = vir::prune::prune_krate_for_module(&krate, &module, &self.vstd_crate_name); let mut ctx = vir::context::Ctx::new( &pruned_krate, @@ -1319,6 +1328,7 @@ impl Verifier { module.clone(), mono_abstract_datatypes, lambda_types, + bound_traits, self.args.debug, )?; let poly_krate = vir::poly::poly_krate_for_module(&mut ctx, &pruned_krate); diff --git a/source/rust_verify_test/tests/recursion.rs b/source/rust_verify_test/tests/recursion.rs index b262430647..1858b47fa2 100644 --- a/source/rust_verify_test/tests/recursion.rs +++ b/source/rust_verify_test/tests/recursion.rs @@ -797,7 +797,7 @@ test_verify_one_file! { //#[verifier(decreases_by)] proof fn check_arith_sum(i: int) { } - } => Err(err) => assert_vir_error_msg(err, "proof function must be marked #[verifier(decreases_by)] or #[verifier(recommends_by)] to be used as decreases_by/recommends_by") + } => Err(err) => assert_vir_error_msg(err, "proof function must be marked #[verifier::decreases_by] or #[verifier::recommends_by] to be used as decreases_by/recommends_by") } test_verify_one_file! { diff --git a/source/rust_verify_test/tests/traits.rs b/source/rust_verify_test/tests/traits.rs index 40a9543278..ab65bec43d 100644 --- a/source/rust_verify_test/tests/traits.rs +++ b/source/rust_verify_test/tests/traits.rs @@ -74,30 +74,6 @@ test_verify_one_file! { } => Ok(()) } -test_verify_one_file! { - #[test] test_not_yet_supported_10 verus_code! { - trait T { - spec fn f(&self) -> bool; - - proof fn p(&self) - ensures exists|x: &Self| self.f() != x.f(); - } - - #[verifier::external_body] /* vattr */ - #[verifier::broadcast_forall] /* vattr */ - proof fn f_not_g() - ensures exists|x: &A, y: &A| x.f() != y.f() - { - } - - struct S {} - - fn test() { - assert(false); - } - } => Err(err) => assert_vir_error_msg(err, ": bounds on broadcast_forall function type parameters") -} - test_verify_one_file! { #[test] test_not_yet_supported_11 verus_code! { trait T { @@ -1194,6 +1170,138 @@ test_verify_one_file! { } => Ok(()) } +test_verify_one_file! { + #[test] test_broadcast_forall1 verus_code! { + trait T { + spec fn f(&self) -> bool; + + proof fn p(&self) + ensures exists|x: &Self| self.f() != x.f(); + } + + spec fn g() -> bool { + exists|x: &A, y: &A| x.f() != y.f() + } + spec fn t() -> bool { true } + + #[verifier::external_body] /* vattr */ + #[verifier::broadcast_forall] /* vattr */ + proof fn f_not_g() + ensures + #[trigger] t::(), + g::(), + { + } + + struct S1 {} + impl T for S1 { + spec fn f(&self) -> bool { + true + } + + proof fn p(&self) { + assert(exists|x: &Self| self.f() != x.f()); // FAILS + } + } + + struct S2 {} + + struct S3(bool); + impl T for S3 { + spec fn f(&self) -> bool { + self.0 + } + + proof fn p(&self) { + assert(self.f() != S3(!self.0).f()) + } + } + + fn test1() { + assert(t::()); + assert(false); + } + + fn test2() { + assert(t::()); + assert(false); // FAILS + } + + fn test3() { + assert(t::()); + assert(false); // FAILS + } + } => Err(err) => assert_fails(err, 3) +} + +test_verify_one_file! { + #[test] test_broadcast_forall2 verus_code! { + trait T1 {} + trait T2 {} + + struct S(E); + + impl T1 for S {} + impl> T2 for S<(Y, u8)> {} + + spec fn f(i: int) -> bool; + + #[verifier::external_body] + #[verifier::broadcast_forall] + proof fn p, u16>>(i: int) + ensures f::(i) + { + } + + proof fn test1() { + assert(f::>(3)); + } + + proof fn test2() { + assert(f::>(3)); // FAILS + } + + proof fn test3() { + assert(f::>(3)); // FAILS + } + } => Err(err) => assert_fails(err, 2) +} + +test_verify_one_file! { + #[test] test_decreases_trait_bound verus_code! { + trait T { + proof fn impossible() + ensures false; + } + + spec fn f(i: int) -> bool + decreases 0int when true via f_decreases:: + { + !f::(i - 0) + } + + #[verifier::decreases_by] + proof fn f_decreases(i: int) { + A::impossible(); + } + + proof fn test1(i: int) { + assert(f::(i) == !f::(i - 0)); + assert(false); + } + + proof fn test2() { + // We'd like to test that f's definition axiom only applies to A that implement T. + // Ideally, we'd test this by applying f to an A that doesn't implement T + // and seeing that the definition axiom doesn't apply. + // Unfortunately, it's hard to test this because Rust's type checker already (correctly) + // stops us from saying f::(x) for ty that doesn't implement T. + // So we have to manually check the AIR code for the axiom off line. + assert(false); // FAILS + } + } => Err(err) => assert_fails(err, 1) +} + test_verify_one_file! { #[test] test_synthetic_type_params verus_code!{ spec fn global_type_id() -> int; diff --git a/source/rust_verify_test/tests/traits_modules.rs b/source/rust_verify_test/tests/traits_modules.rs index 94cab8b009..ec4fb80dcb 100644 --- a/source/rust_verify_test/tests/traits_modules.rs +++ b/source/rust_verify_test/tests/traits_modules.rs @@ -68,38 +68,6 @@ test_verify_one_file! { } => Ok(()) } -test_verify_one_file! { - #[test] test_not_yet_supported_10 verus_code! { - mod M1 { - pub trait T { - spec fn f(&self) -> bool; - - proof fn p(&self) - ensures exists|x: &Self| self.f() != x.f(); - } - } - - mod M2 { - #[verifier::external_body] /* vattr */ - #[verifier::broadcast_forall] /* vattr */ - proof fn f_not_g() - ensures exists|x: &A, y: &A| x.f() != y.f() - { - } - } - - mod M3 { - struct S {} - } - - mod M4 { - fn test() { - assert(false); - } - } - } => Err(err) => assert_vir_error_msg(err, ": bounds on broadcast_forall function type parameters") -} - test_verify_one_file! { #[test] test_ill_formed_7 code! { mod M1 { diff --git a/source/rust_verify_test/tests/traits_modules_pub_crate.rs b/source/rust_verify_test/tests/traits_modules_pub_crate.rs index 7b83f02ad9..cbd872bedb 100644 --- a/source/rust_verify_test/tests/traits_modules_pub_crate.rs +++ b/source/rust_verify_test/tests/traits_modules_pub_crate.rs @@ -69,38 +69,6 @@ test_verify_one_file! { } => Ok(()) } -test_verify_one_file! { - #[test] test_not_yet_supported_10 verus_code! { - mod M1 { - pub(crate) trait T { - spec fn f(&self) -> bool; - - proof fn p(&self) - ensures exists|x: &Self| self.f() != x.f(); - } - } - - mod M2 { - #[verifier::external_body] /* vattr */ - #[verifier::broadcast_forall] /* vattr */ - proof fn f_not_g() - ensures exists|x: &A, y: &A| x.f() != y.f() - { - } - } - - mod M3 { - struct S {} - } - - mod M4 { - fn test() { - assert(false); - } - } - } => Err(err) => assert_vir_error_msg(err, ": bounds on broadcast_forall function type parameters") -} - test_verify_one_file! { #[test] test_ill_formed_7 code! { mod M1 { diff --git a/source/vir/src/ast.rs b/source/vir/src/ast.rs index 387fce2d9f..28665658a9 100644 --- a/source/vir/src/ast.rs +++ b/source/vir/src/ast.rs @@ -224,6 +224,8 @@ pub enum ModeCoercion { pub enum NullaryOpr { /// convert a const generic into an expression, as in fn f() -> usize { N } ConstGeneric(Typ), + /// predicate representing a satisfied trait bound T(t1, ..., tn) for trait T + TraitBound(Path, Typs), } /// Primitive unary operations @@ -925,6 +927,7 @@ pub type Trait = Arc>; #[derive(Clone, Debug, Serialize, Deserialize, ToDebugSNode)] pub struct TraitX { pub name: Path, + pub visibility: Visibility, // REVIEW: typ_params does not yet explicitly include Self (right now, Self is implicit) pub typ_params: TypPositives, pub typ_bounds: GenericBounds, @@ -942,7 +945,7 @@ pub struct AssocTypeImplX { pub typ_params: Idents, pub typ_bounds: GenericBounds, pub trait_path: Path, - pub trait_typ_args: Arc>, + pub trait_typ_args: Typs, pub typ: Typ, } @@ -951,7 +954,11 @@ pub type TraitImpl = Arc>; pub struct TraitImplX { /// Path of the impl (e.g. "impl2") pub impl_path: Path, + // typ_params of impl (unrelated to typ_params of trait) + pub typ_params: Idents, + pub typ_bounds: GenericBounds, pub trait_path: Path, + pub trait_typ_args: Typs, } #[derive(Clone, Debug, Hash, Serialize, Deserialize, ToDebugSNode, PartialEq, Eq)] diff --git a/source/vir/src/ast_simplify.rs b/source/vir/src/ast_simplify.rs index 1f507adfb7..b9f766cba3 100644 --- a/source/vir/src/ast_simplify.rs +++ b/source/vir/src/ast_simplify.rs @@ -6,7 +6,7 @@ use crate::ast::{ AssocTypeImpl, AutospecUsage, BinaryOp, Binder, BuiltinSpecFun, CallTarget, ChainedOp, Constant, Datatype, DatatypeTransparency, DatatypeX, Expr, ExprX, Exprs, Field, FieldOpr, Function, FunctionKind, Ident, IntRange, Krate, KrateX, Mode, MultiOp, Path, Pattern, PatternX, - SpannedTyped, Stmt, StmtX, Typ, TypX, UnaryOp, UnaryOpr, VirErr, Visibility, + SpannedTyped, Stmt, StmtX, TraitImpl, Typ, TypX, UnaryOp, UnaryOpr, VirErr, Visibility, }; use crate::ast_util::int_range_from_type; use crate::ast_util::is_integer_type; @@ -790,6 +790,16 @@ fn simplify_datatype(state: &mut State, datatype: &Datatype) -> Result Result { + let mut local = LocalCtxt { span: imp.span.clone(), typ_params: Vec::new() }; + for x in imp.x.typ_params.iter() { + local.typ_params.push(x.clone()); + } + crate::ast_visitor::map_trait_impl_visitor_env(imp, state, &|state, typ| { + simplify_one_typ(&local, state, typ) + }) +} + fn simplify_assoc_type_impl( state: &mut State, assoc: &AssocTypeImpl, @@ -855,6 +865,7 @@ pub fn simplify_krate(ctx: &mut GlobalCtx, krate: &Krate) -> Result Result { + let ts = map_typs_visitor_env(ts, env, ft)?; + ExprX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(p.clone(), ts)) + } ExprX::Unary(op, e1) => { let expr1 = map_expr_visitor_env(e1, map, env, fe, fs, ft)?; ExprX::Unary(*op, expr1) @@ -1129,6 +1134,25 @@ where Ok(Spanned::new(datatype.span.clone(), DatatypeX { variants, typ_bounds, ..datatypex })) } +pub(crate) fn map_trait_impl_visitor_env( + imp: &TraitImpl, + env: &mut E, + ft: &FT, +) -> Result +where + FT: Fn(&mut E, &Typ) -> Result, +{ + let TraitImplX { impl_path, typ_params, typ_bounds, trait_path, trait_typ_args } = &imp.x; + let impx = TraitImplX { + impl_path: impl_path.clone(), + typ_params: typ_params.clone(), + typ_bounds: map_generic_bounds_visitor(typ_bounds, env, ft)?, + trait_path: trait_path.clone(), + trait_typ_args: map_typs_visitor_env(trait_typ_args, env, ft)?, + }; + Ok(Spanned::new(imp.span.clone(), impx)) +} + pub(crate) fn map_assoc_type_impl_visitor_env( assoc: &AssocTypeImpl, env: &mut E, diff --git a/source/vir/src/context.rs b/source/vir/src/context.rs index a0853394bd..5b62cff1e5 100644 --- a/source/vir/src/context.rs +++ b/source/vir/src/context.rs @@ -69,6 +69,7 @@ pub struct Ctx { pub(crate) datatypes_with_invariant: HashSet, pub(crate) mono_types: Vec, pub(crate) lambda_types: Vec, + pub(crate) bound_traits: HashSet, pub functions: Vec, pub func_map: HashMap, // Ensure a unique identifier for each quantifier in a given function @@ -283,6 +284,7 @@ impl Ctx { module: Path, mono_types: Vec, lambda_types: Vec, + bound_traits: HashSet, debug: bool, ) -> Result { let mut datatype_is_transparent: HashMap = HashMap::new(); @@ -315,6 +317,7 @@ impl Ctx { datatypes_with_invariant, mono_types, lambda_types, + bound_traits, functions, func_map, quantifier_count, diff --git a/source/vir/src/def.rs b/source/vir/src/def.rs index 9e7a0b1046..1e42d4fdd4 100644 --- a/source/vir/src/def.rs +++ b/source/vir/src/def.rs @@ -58,6 +58,7 @@ const PREFIX_LAMBDA_TYPE: &str = "fun%"; const PREFIX_IMPL_IDENT: &str = "impl&%"; const PREFIX_PROJECT: &str = "proj%"; const PREFIX_PROJECT_DECORATION: &str = "proj%%"; +const PREFIX_TRAIT_BOUND: &str = "tr_bound%"; const SLICE_TYPE: &str = "slice%"; const ARRAY_TYPE: &str = "array%"; const PREFIX_SNAPSHOT: &str = "snap%"; @@ -175,6 +176,7 @@ pub const QID_HEIGHT_APPLY: &str = "height_apply"; pub const QID_ACCESSOR: &str = "accessor"; pub const QID_INVARIANT: &str = "invariant"; pub const QID_HAS_TYPE_ALWAYS: &str = "has_type_always"; +pub const QID_TRAIT_IMPL: &str = "trait_impl"; pub const QID_ASSOC_TYPE_IMPL: &str = "assoc_type_impl"; pub const VERUS_SPEC: &str = "VERUS_SPEC__"; @@ -366,6 +368,10 @@ pub fn projection(decoration: bool, trait_path: &Path, name: &Ident) -> Ident { )) } +pub fn trait_bound(trait_path: &Path) -> Ident { + Arc::new(format!("{}{}", PREFIX_TRAIT_BOUND, path_to_string(trait_path))) +} + pub fn prefix_type_id_fun(i: usize) -> Ident { prefix_type_id(&prefix_lambda_type(i)) } diff --git a/source/vir/src/func_to_air.rs b/source/vir/src/func_to_air.rs index c8d4662468..35e09f9d79 100644 --- a/source/vir/src/func_to_air.rs +++ b/source/vir/src/func_to_air.rs @@ -158,11 +158,15 @@ fn func_body_to_air( state.fun_ssts.borrow_mut().insert(function.x.name.clone(), info); let mut decrease_by_stms: Vec = Vec::new(); - let decrease_by_reqs = if let Some(req) = &function.x.decrease_when { + let def_reqs = if let Some(req) = &function.x.decrease_when { + // "when" means the function is only defined if the requirements hold, + // including trait bound requirements + let mut def_reqs = crate::traits::trait_bounds_to_air(ctx, &function.x.typ_bounds); let exp = crate::ast_to_sst::expr_to_exp(ctx, diagnostics, &state.fun_ssts, &pars, req)?; let expr = exp_to_expr(ctx, &exp, &ExprCtxt::new_mode(ExprMode::Spec))?; decrease_by_stms.push(Spanned::new(req.span.clone(), StmX::Assume(exp))); - vec![expr] + def_reqs.push(expr); + def_reqs } else { vec![] }; @@ -283,7 +287,7 @@ fn func_body_to_air( let name_body = format!("{}_fuel_to_body", &fun_to_air_ident(&name)); let bind_zero = func_bind(ctx, name_zero, &function.x.typ_params, &pars, &rec_f_fuel, true); let bind_body = func_bind(ctx, name_body, &function.x.typ_params, &pars, &rec_f_succ, true); - let implies_body = mk_implies(&mk_and(&decrease_by_reqs), &eq_body); + let implies_body = mk_implies(&mk_and(&def_reqs), &eq_body); let forall_zero = mk_bind_expr(&bind_zero, &eq_zero); let forall_body = mk_bind_expr(&bind_body, &implies_body); let fuel_nat_decl = Arc::new(DeclX::Const(fuel_nat_f, str_typ(FUEL_TYPE))); @@ -301,7 +305,7 @@ fn func_body_to_air( &function.x.typ_params, &typ_args, &pars, - &decrease_by_reqs, + &def_reqs, def_body, )?; let fuel_bool = str_apply(FUEL_BOOL, &vec![ident_var(&id_fuel)]); @@ -657,9 +661,6 @@ pub fn func_axioms_to_air( use crate::triggers::{typ_boxing, TriggerBoxing}; let mut vars: Vec<(Ident, TriggerBoxing)> = Vec::new(); let mut binders: Vec> = Vec::new(); - if function.x.typ_bounds.len() != 0 { - todo!() - } for name in function.x.typ_params.iter() { vars.push((suffix_typ_param_id(&name), TriggerBoxing::TypeId)); let typ = Arc::new(TypX::TypeId); @@ -772,6 +773,11 @@ pub fn func_def_to_air( let mut req_stms: Vec = Vec::new(); let mut reqs: Vec = Vec::new(); + reqs.extend(crate::traits::trait_bounds_to_sst( + ctx, + &function.span, + &function.x.typ_bounds, + )); for e in req_ens_function.x.require.iter() { let e_with_req_ens_params = map_expr_rename_vars(e, &req_ens_e_rename)?; if ctx.checking_recommends() { diff --git a/source/vir/src/modes.rs b/source/vir/src/modes.rs index 9cf4548e67..5d547827fa 100644 --- a/source/vir/src/modes.rs +++ b/source/vir/src/modes.rs @@ -636,6 +636,7 @@ fn check_expr_handle_mut_arg( Ok(mode) } ExprX::NullaryOpr(crate::ast::NullaryOpr::ConstGeneric(_)) => Ok(Mode::Exec), + ExprX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(..)) => Ok(Mode::Spec), ExprX::Unary(UnaryOp::CoerceMode { op_mode, from_mode, to_mode, kind }, e1) => { // same as a call to an op_mode function with parameter from_mode and return to_mode if typing.check_ghost_blocks { diff --git a/source/vir/src/poly.rs b/source/vir/src/poly.rs index 7c05dc581b..cbbd3fe76c 100644 --- a/source/vir/src/poly.rs +++ b/source/vir/src/poly.rs @@ -386,6 +386,7 @@ fn poly_expr(ctx: &Ctx, state: &mut State, expr: &Expr) -> Expr { mk_expr(ExprX::Ctor(path.clone(), variant.clone(), Arc::new(bs), None)) } ExprX::NullaryOpr(crate::ast::NullaryOpr::ConstGeneric(_)) => expr.clone(), + ExprX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(..)) => expr.clone(), ExprX::Unary(op, e1) => { let e1 = poly_expr(ctx, state, e1); match op { @@ -870,7 +871,14 @@ fn poly_function(ctx: &Ctx, function: &Function) -> Function { let broadcast_params = Arc::new(new_params); let span = &function.span; - let req = crate::ast_util::conjoin(span, &*function.x.require); + let mut reqs: Vec = Vec::new(); + reqs.extend(crate::traits::trait_bounds_to_ast( + ctx, + &function.span, + &function.x.typ_bounds, + )); + reqs.extend((*function.x.require).clone()); + let req = crate::ast_util::conjoin(span, &reqs); let ens = crate::ast_util::conjoin(span, &*function.x.ensure); let req_ens = crate::ast_util::mk_implies(span, &req, &ens); let req_ens = coerce_expr_to_native(ctx, &poly_expr(ctx, &mut state, &req_ens)); diff --git a/source/vir/src/prune.rs b/source/vir/src/prune.rs index 5bb8192db1..207382f8aa 100644 --- a/source/vir/src/prune.rs +++ b/source/vir/src/prune.rs @@ -35,10 +35,29 @@ enum ReachedType { // Group all AssocTypeImpls with the same (ReachedType(self_typ), (trait_path, name)): type AssocTypeGroup = (ReachedType, (Path, Ident)); +type TraitName = Path; +type ImplName = Path; + +#[derive(Debug)] +struct TraitImpl { + // For an impl "...T'(...t'...)... ==> trait T(...t...)", + // list all traits T' and types t' in the bounds: + bound_traits: Vec, + bound_types: Vec, + // list T and all t: + trait_name: TraitName, + trait_typ_args: Vec, +} + struct Ctxt { module: Path, function_map: HashMap, datatype_map: HashMap, + // For an impl "bounds ==> trait T(...t...)", point T to impl: + trait_to_trait_impls: HashMap>, + // For an impl "bounds ==> trait T(...t...)", point t to impl: + typ_to_trait_impls: HashMap>, + trait_impl_map: HashMap, assoc_type_impl_map: HashMap>, // Map (D, T.f) -> D.f if D implements T.f: method_map: HashMap<(ReachedType, Fun), Vec>, @@ -50,11 +69,15 @@ struct Ctxt { struct State { reached_functions: HashSet, reached_types: HashSet, + reached_bound_traits: HashSet, + reached_trait_impls: HashSet, reached_assoc_type_decls: HashSet<(Path, Ident)>, reached_assoc_type_impls: HashSet, reached_modules: HashSet, worklist_functions: Vec, worklist_types: Vec, + worklist_bound_traits: Vec, + worklist_trait_impls: Vec, worklist_assoc_type_decls: Vec<(Path, Ident)>, worklist_assoc_type_impls: Vec, worklist_modules: Vec, @@ -112,6 +135,25 @@ fn reach_function(ctxt: &Ctxt, state: &mut State, name: &Fun) { } } +fn reach_bound_trait(_ctxt: &Ctxt, state: &mut State, name: &TraitName) { + reach(&mut state.reached_bound_traits, &mut state.worklist_bound_traits, name); +} + +fn reach_trait_impl(ctxt: &Ctxt, state: &mut State, imp: &ImplName) { + if let Some(trait_impl) = ctxt.trait_impl_map.get(imp) { + // We only reach the impl "bounds ==> trait T(...t...)" when all of T and t have been reached. + // Otherwise, we consider the impl irrelevant. + for t in &trait_impl.trait_typ_args { + if *t != ReachedType::None && !state.reached_types.contains(t) { + return; + } + } + if state.reached_bound_traits.contains(&trait_impl.trait_name) { + reach(&mut state.reached_trait_impls, &mut state.worklist_trait_impls, imp); + } + } +} + fn reach_assoc_type_decl(_ctxt: &Ctxt, state: &mut State, name: &(Path, Ident)) { reach(&mut state.reached_assoc_type_decls, &mut state.worklist_assoc_type_decls, name); } @@ -207,6 +249,12 @@ fn traverse_reachable(ctxt: &Ctxt, state: &mut State) { if let FunctionKind::TraitMethodImpl { method, .. } = &function.x.kind { reach_function(ctxt, state, method); } + if function.x.mode == crate::ast::Mode::Spec || function.x.attrs.broadcast_forall { + for bound in function.x.typ_bounds.iter() { + let crate::ast::GenericBoundX::Trait(path, _) = &**bound; + reach_bound_trait(ctxt, state, path); + } + } let fe = |state: &mut State, _: &mut ScopeMap, e: &Expr| { // note: the visitor automatically reaches e.typ match &e.x { @@ -263,6 +311,11 @@ fn traverse_reachable(ctxt: &Ctxt, state: &mut State) { } _ => {} } + if let Some(imps) = ctxt.typ_to_trait_impls.get(&t) { + for imp in imps { + reach_trait_impl(ctxt, state, imp); + } + } let methods = reached_methods(ctxt, state.reached_functions.iter().map(|f| (&t, f))); reach_methods(ctxt, state, methods); let assoc_decls: Vec<(Path, Ident)> = @@ -272,6 +325,23 @@ fn traverse_reachable(ctxt: &Ctxt, state: &mut State) { } continue; } + if let Some(b) = state.worklist_bound_traits.pop() { + if let Some(impls) = ctxt.trait_to_trait_impls.get(&b) { + for imp in impls { + reach_trait_impl(ctxt, state, imp); + } + } + } + if let Some(i) = state.worklist_trait_impls.pop() { + if let Some(trait_impl) = ctxt.trait_impl_map.get(&i) { + for bound_trait in &trait_impl.bound_traits { + reach_bound_trait(ctxt, state, bound_trait); + } + for bound_type in &trait_impl.bound_types { + reach_type(ctxt, state, bound_type); + } + } + } if let Some(a) = state.worklist_assoc_type_decls.pop() { let typs: Vec = state.reached_types.iter().cloned().collect(); for t in typs { @@ -321,7 +391,7 @@ pub fn prune_krate_for_module( krate: &Krate, module: &Path, vstd_crate_name: &Option, -) -> (Krate, Vec, Vec) { +) -> (Krate, Vec, Vec, HashSet) { let mut state: State = Default::default(); state.reached_modules.insert(module.clone()); state.worklist_modules.push(module.clone()); @@ -419,6 +489,9 @@ pub fn prune_krate_for_module( let mut function_map: HashMap = HashMap::new(); let mut datatype_map: HashMap = HashMap::new(); let mut assoc_type_impl_map: HashMap> = HashMap::new(); + let mut trait_to_trait_impls: HashMap> = HashMap::new(); + let mut typ_to_trait_impls: HashMap> = HashMap::new(); + let mut trait_impl_map: HashMap = HashMap::new(); let mut method_map: HashMap<(ReachedType, Fun), Vec> = HashMap::new(); let mut all_functions_in_each_module: HashMap> = HashMap::new(); for f in &functions { @@ -441,6 +514,36 @@ pub fn prune_krate_for_module( datatype_map.insert(d.x.path.clone(), d.clone()); } + for imp in krate.trait_impls.iter() { + let mut bound_traits: Vec = Vec::new(); + let mut bound_types: Vec = Vec::new(); + for bound in imp.x.typ_bounds.iter() { + let crate::ast::GenericBoundX::Trait(path, typ_args) = &**bound; + bound_traits.push(path.clone()); + for t in typ_args.iter() { + bound_types.push(typ_to_reached_type(t)); + } + } + let trait_impl = TraitImpl { + bound_traits, + bound_types, + trait_name: imp.x.trait_path.clone(), + trait_typ_args: imp.x.trait_typ_args.iter().map(typ_to_reached_type).collect(), + }; + if !trait_to_trait_impls.contains_key(&imp.x.trait_path) { + trait_to_trait_impls.insert(imp.x.trait_path.clone(), Vec::new()); + } + trait_to_trait_impls.get_mut(&imp.x.trait_path).unwrap().push(imp.x.impl_path.clone()); + for t in &trait_impl.trait_typ_args { + if !typ_to_trait_impls.contains_key(t) { + typ_to_trait_impls.insert(t.clone(), Vec::new()); + } + typ_to_trait_impls.get_mut(&t).unwrap().push(imp.x.impl_path.clone()); + } + assert!(!trait_impl_map.contains_key(&imp.x.impl_path)); + trait_impl_map.insert(imp.x.impl_path.clone(), trait_impl); + } + for a in &krate.assoc_type_impls { let key = a.x.prune_name(); if !assoc_type_impl_map.contains_key(&key) { @@ -452,6 +555,9 @@ pub fn prune_krate_for_module( module: module.clone(), function_map, datatype_map, + trait_to_trait_impls, + typ_to_trait_impls, + trait_impl_map, assoc_type_impl_map, method_map, all_functions_in_each_module, @@ -487,7 +593,12 @@ pub fn prune_krate_for_module( .cloned() .collect(), traits, - trait_impls: krate.trait_impls.clone(), + trait_impls: krate + .trait_impls + .iter() + .filter(|i| state.reached_trait_impls.contains(&i.x.impl_path)) + .cloned() + .collect(), module_ids: krate.module_ids.clone(), external_fns: krate.external_fns.clone(), external_types: krate.external_types.clone(), @@ -498,5 +609,6 @@ pub fn prune_krate_for_module( let mut mono_abstract_datatypes: Vec = state.mono_abstract_datatypes.into_iter().collect(); mono_abstract_datatypes.sort(); - (Arc::new(kratex), mono_abstract_datatypes, lambda_types) + let State { reached_bound_traits, .. } = state; + (Arc::new(kratex), mono_abstract_datatypes, lambda_types, reached_bound_traits) } diff --git a/source/vir/src/recursive_types.rs b/source/vir/src/recursive_types.rs index ad002ae0a8..9c1e3c87c9 100644 --- a/source/vir/src/recursive_types.rs +++ b/source/vir/src/recursive_types.rs @@ -295,13 +295,6 @@ pub(crate) fn check_recursive_types(krate: &Krate) -> Result<(), VirErr> { if let FunctionKind::TraitMethodDecl { .. } = function.x.kind { assert!(&function.x.typ_params[0] == &crate::def::trait_self_type_param()); } - if function.x.typ_bounds.len() != 0 && function.x.attrs.broadcast_forall { - // See the todo!() in func_to_air.rs - return error( - &function.span, - "not yet supported: bounds on broadcast_forall function type parameters", - ); - } } for tr in &krate.traits { diff --git a/source/vir/src/sst_to_air.rs b/source/vir/src/sst_to_air.rs index 3996a5ef96..47b611ce8d 100644 --- a/source/vir/src/sst_to_air.rs +++ b/source/vir/src/sst_to_air.rs @@ -841,6 +841,13 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< (ExpX::NullaryOpr(crate::ast::NullaryOpr::ConstGeneric(c)), false) => { str_apply(crate::def::CONST_INT, &vec![typ_to_id(c)]) } + (ExpX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(p, ts)), false) => { + if let Some(e) = crate::traits::trait_bound_to_air(ctx, p, ts) { + e + } else { + air::ast_util::mk_true() + } + } (ExpX::Unary(op, arg), true) => { if !allowed_bitvector_type(&arg.typ) { return error( diff --git a/source/vir/src/sst_util.rs b/source/vir/src/sst_util.rs index 3ce8340f09..21592645b4 100644 --- a/source/vir/src/sst_util.rs +++ b/source/vir/src/sst_util.rs @@ -289,6 +289,7 @@ impl ExpX { NullaryOpr(crate::ast::NullaryOpr::ConstGeneric(_)) => { ("const_generic".to_string(), 99) } + NullaryOpr(crate::ast::NullaryOpr::TraitBound(..)) => ("trait_bound".to_string(), 99), Unary(op, exp) => match op { UnaryOp::Not | UnaryOp::BitNot => (format!("!{}", exp.x.to_string_prec(99)), 90), UnaryOp::Clip { .. } => (format!("clip({})", exp), 99), diff --git a/source/vir/src/sst_visitor.rs b/source/vir/src/sst_visitor.rs index 464f361c5b..6a26ca5621 100644 --- a/source/vir/src/sst_visitor.rs +++ b/source/vir/src/sst_visitor.rs @@ -506,6 +506,10 @@ where let t = ft(env, t)?; ok_exp(ExpX::NullaryOpr(crate::ast::NullaryOpr::ConstGeneric(t))) } + ExpX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(p, ts)) => { + let ts: Result, VirErr> = ts.iter().map(|t| ft(env, t)).collect(); + ok_exp(ExpX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(p.clone(), Arc::new(ts?)))) + } ExpX::Unary(op, e1) => ok_exp(ExpX::Unary(*op, fe(env, e1)?)), ExpX::UnaryOpr(op, e1) => { let op = match op { diff --git a/source/vir/src/traits.rs b/source/vir/src/traits.rs index cdc8b967a6..18680cd156 100644 --- a/source/vir/src/traits.rs +++ b/source/vir/src/traits.rs @@ -1,10 +1,13 @@ use crate::ast::{ - CallTarget, CallTargetKind, Expr, ExprX, Fun, Function, FunctionKind, Ident, Krate, Mode, Path, - Typ, VirErr, WellKnownItem, + CallTarget, CallTargetKind, Expr, ExprX, Fun, Function, FunctionKind, GenericBounds, Ident, + Krate, Mode, Path, SpannedTyped, Typ, TypX, Typs, VirErr, WellKnownItem, }; -use crate::ast_util::error; +use crate::ast_util::{error, path_as_friendly_rust_name}; +use crate::context::Ctx; use crate::def::Spanned; -use air::ast::Span; +use crate::sst_to_air::typ_to_ids; +use air::ast::{Command, CommandX, Commands, DeclX, Span}; +use air::ast_util::{ident_apply, mk_bind_expr, mk_implies, str_typ}; use air::scope_map::ScopeMap; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -150,3 +153,125 @@ fn demote_one_expr(traits: &HashSet, expr: &Expr) -> Result _ => Ok(expr.clone()), } } + +pub(crate) fn trait_bounds_to_ast(ctx: &Ctx, span: &Span, typ_bounds: &GenericBounds) -> Vec { + let mut bound_exprs: Vec = Vec::new(); + for bound in typ_bounds.iter() { + let crate::ast::GenericBoundX::Trait(path, typ_args) = &**bound; + if !ctx.trait_map.contains_key(path) || !ctx.bound_traits.contains(path) { + continue; + } + let op = crate::ast::NullaryOpr::TraitBound(path.clone(), typ_args.clone()); + let exprx = ExprX::NullaryOpr(op); + bound_exprs.push(SpannedTyped::new(span, &Arc::new(TypX::Bool), exprx)); + } + bound_exprs +} + +pub(crate) fn trait_bounds_to_sst( + ctx: &Ctx, + span: &Span, + typ_bounds: &GenericBounds, +) -> Vec { + let mut bound_exps: Vec = Vec::new(); + for bound in typ_bounds.iter() { + let crate::ast::GenericBoundX::Trait(path, typ_args) = &**bound; + if !ctx.trait_map.contains_key(path) || !ctx.bound_traits.contains(path) { + continue; + } + let op = crate::ast::NullaryOpr::TraitBound(path.clone(), typ_args.clone()); + let expx = crate::sst::ExpX::NullaryOpr(op); + bound_exps.push(SpannedTyped::new(span, &Arc::new(TypX::Bool), expx)); + } + bound_exps +} + +pub(crate) fn trait_bound_to_air( + ctx: &Ctx, + path: &Path, + typ_args: &Typs, +) -> Option { + if !ctx.trait_map.contains_key(path) || !ctx.bound_traits.contains(path) { + return None; + } + let mut typ_exprs: Vec = Vec::new(); + for t in typ_args.iter() { + typ_exprs.extend(typ_to_ids(t)); + } + Some(ident_apply(&crate::def::trait_bound(path), &typ_exprs)) +} + +pub(crate) fn trait_bounds_to_air(ctx: &Ctx, typ_bounds: &GenericBounds) -> Vec { + let mut bound_exprs: Vec = Vec::new(); + for bound in typ_bounds.iter() { + let crate::ast::GenericBoundX::Trait(path, typ_args) = &**bound; + if let Some(bound) = trait_bound_to_air(ctx, path, typ_args) { + bound_exprs.push(bound); + } + } + bound_exprs +} + +pub fn traits_to_air(ctx: &Ctx, krate: &Krate) -> Commands { + // Axioms about broadcast_forall and spec functions need justification + // for any trait bounds. + let mut commands: Vec = Vec::new(); + + // Declare predicates for bounds + // (declare-fun tr_bound%T (... Dcr Type ...) Bool) + for tr in krate.traits.iter() { + if ctx.bound_traits.contains(&tr.x.name) { + let mut tparams: Vec = Vec::new(); + tparams.extend(crate::def::types().iter().map(|s| str_typ(s))); // Self + for _ in tr.x.typ_params.iter() { + tparams.extend(crate::def::types().iter().map(|s| str_typ(s))); + } + let decl_trait_bound = Arc::new(DeclX::Fun( + crate::def::trait_bound(&tr.x.name), + Arc::new(tparams), + air::ast_util::bool_typ(), + )); + commands.push(Arc::new(CommandX::Global(decl_trait_bound))); + } + } + + // Axioms for bounds predicates (based on trait impls) + for imp in krate.trait_impls.iter() { + assert!(ctx.bound_traits.contains(&imp.x.trait_path)); + // forall typ_params. typ_bounds ==> tr_bound%T(...typ_args...) + // Example: + // trait T1 {} + // trait T2 {} + // impl T2> for S> + // --> + // forall A. tr_bound%T1(A) ==> tr_bound%T2(S>, Set) + let tr_bound = if let Some(tr_bound) = + trait_bound_to_air(ctx, &imp.x.trait_path, &imp.x.trait_typ_args) + { + tr_bound + } else { + continue; + }; + let name = format!( + "{}_{}", + path_as_friendly_rust_name(&imp.x.impl_path), + crate::def::QID_TRAIT_IMPL + ); + let trigs = vec![tr_bound.clone()]; + let bind = crate::func_to_air::func_bind_trig( + ctx, + name, + &imp.x.typ_params, + &Arc::new(vec![]), + &trigs, + false, + ); + let req_bounds = trait_bounds_to_air(ctx, &imp.x.typ_bounds); + let imply = mk_implies(&air::ast_util::mk_and(&req_bounds), &tr_bound); + let forall = mk_bind_expr(&bind, &imply); + let axiom = Arc::new(DeclX::Axiom(forall)); + commands.push(Arc::new(CommandX::Global(axiom))); + } + + Arc::new(commands) +} diff --git a/source/vir/src/triggers.rs b/source/vir/src/triggers.rs index cc2f91cd7e..3e4609867b 100644 --- a/source/vir/src/triggers.rs +++ b/source/vir/src/triggers.rs @@ -241,6 +241,9 @@ fn check_trigger_expr( ExpX::VarAt(_, VarAt::Pre) => Ok(()), ExpX::Old(_, _) => panic!("internal error: Old"), ExpX::NullaryOpr(crate::ast::NullaryOpr::ConstGeneric(_)) => Ok(()), + ExpX::NullaryOpr(crate::ast::NullaryOpr::TraitBound(..)) => { + error(&exp.span, "triggers cannot contain trait bounds") + } ExpX::Unary(op, arg) => match op { UnaryOp::StrLen | UnaryOp::StrIsAscii => check_trigger_expr_arg(state, true, arg), UnaryOp::Clip { .. } | UnaryOp::BitNot | UnaryOp::CharToInt => { diff --git a/source/vir/src/well_formed.rs b/source/vir/src/well_formed.rs index a3b195625a..a2dc5ca496 100644 --- a/source/vir/src/well_formed.rs +++ b/source/vir/src/well_formed.rs @@ -950,7 +950,7 @@ pub fn check_crate(krate: &Krate, diags: &mut Vec) -> Result<(), VirEr }; if !proof_function.x.attrs.is_decrease_by { return Err(air::messages::error( - "proof function must be marked #[verifier(decreases_by)] or #[verifier(recommends_by)] to be used as decreases_by/recommends_by", + "proof function must be marked #[verifier::decreases_by] or #[verifier::recommends_by] to be used as decreases_by/recommends_by", &proof_function.span, ) .secondary_span(&function.span)); @@ -1024,7 +1024,7 @@ pub fn check_crate(krate: &Krate, diags: &mut Vec) -> Result<(), VirEr { return error( &function.span, - "function cannot be marked #[verifier(decreases_by)] or #[verifier(recommends_by)] unless it is used in some decreases_by/recommends_by", + "function cannot be marked #[verifier::decreases_by] or #[verifier::recommends_by] unless it is used in some decreases_by/recommends_by", ); } }