diff --git a/stainless_data/src/ast.rs b/stainless_data/src/ast.rs index 6648ce93..12dc986e 100644 --- a/stainless_data/src/ast.rs +++ b/stainless_data/src/ast.rs @@ -149,6 +149,21 @@ impl<'l, 'a> From<&'l ValDef<'a>> for &'l Variable<'a> { } } +impl Variable<'_> { + pub fn is_mutable(&self) -> bool { + self.flags.iter().any(|f| match f { + Flag::IsVar(_) => true, + _ => false, + }) + } +} + +impl ValDef<'_> { + pub fn is_mutable(&self) -> bool { + self.v.is_mutable() + } +} + // Additional helpers that mirror those in Inox pub fn Int32Literal(value: Int) -> BVLiteral { diff --git a/stainless_extraction/src/bindings.rs b/stainless_extraction/src/bindings.rs index 9c8ca01b..0771d984 100644 --- a/stainless_extraction/src/bindings.rs +++ b/stainless_extraction/src/bindings.rs @@ -20,7 +20,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { .pat .simple_ident() .and_then(|ident| flags_by_symbol.remove(&ident.name)); - let (var, _) = self.extract_binding(param.pat.hir_id, flags); + let var = self.extract_binding(param.pat.hir_id, flags); self.dcx.add_param(self.factory().ValDef(var)); } @@ -35,11 +35,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { /// Extract a binding based on the binding node's HIR id. /// Updates `dcx` if the binding hadn't been extracted before. - fn extract_binding( - &mut self, - hir_id: HirId, - flags: Option, - ) -> (&'l st::Variable<'l>, bool) { + fn extract_binding(&mut self, hir_id: HirId, flags_opt: Option) -> &'l st::Variable<'l> { self.dcx.get_var(hir_id).unwrap_or_else(|| { let xtor = &mut self.base; @@ -85,13 +81,18 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { }; // Build a Variable node + let f = xtor.factory(); let tpe = xtor.extract_ty(self.tables.node_type(hir_id), &self.txtcx, span); - let flags = flags - .map(|flags| flags.to_stainless(xtor.factory())) - .unwrap_or_default(); - let var = xtor.factory().Variable(id, tpe, flags); - self.dcx.add_var(hir_id, var, mutable); - (var, mutable) + let flags = flags_opt + .map(|flags| flags.to_stainless(f)) + .into_iter() + .flatten() + // Add @var flag if the param is mutable + .chain(mutable.then(|| f.IsVar().into())) + .collect(); + let var = f.Variable(id, tpe, flags); + self.dcx.add_var(hir_id, var); + var }) } } @@ -99,7 +100,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { /// DefContext tracks available bindings #[derive(Clone, Debug)] pub(super) struct DefContext<'l> { - vars: HashMap, bool)>, + vars: HashMap>, params: Vec<&'l st::ValDef<'l>>, } @@ -115,14 +116,9 @@ impl<'l> DefContext<'l> { &self.params[..] } - pub(super) fn add_var( - &mut self, - hir_id: HirId, - var: &'l st::Variable<'l>, - mutable: bool, - ) -> &mut Self { + pub(super) fn add_var(&mut self, hir_id: HirId, var: &'l st::Variable<'l>) -> &mut Self { assert!(!self.vars.contains_key(&hir_id)); - self.vars.insert(hir_id, (var, mutable)); + self.vars.insert(hir_id, var); self } @@ -133,7 +129,7 @@ impl<'l> DefContext<'l> { } #[inline] - pub(super) fn get_var(&self, hir_id: HirId) -> Option<(&'l st::Variable<'l>, bool)> { + pub(super) fn get_var(&self, hir_id: HirId) -> Option<&'l st::Variable<'l>> { self.vars.get(&hir_id).copied() } } diff --git a/stainless_extraction/src/expr.rs b/stainless_extraction/src/expr.rs index 9dd80de4..18f43f36 100644 --- a/stainless_extraction/src/expr.rs +++ b/stainless_extraction/src/expr.rs @@ -31,7 +31,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { ExprKind::LogicalOp { .. } => self.extract_logical_op(expr), ExprKind::Tuple { .. } => self.extract_tuple(expr), ExprKind::Field { lhs, name } => self.extract_field(lhs, name), - ExprKind::VarRef { id } => self.fetch_var(id).0.into(), + ExprKind::VarRef { id } => self.fetch_var(id).into(), ExprKind::Call { ty, ref args, .. } => self.extract_call_like(ty, args, expr.span), ExprKind::Adt { .. } => self.extract_adt_construction(expr), ExprKind::Block { body: ast_block } => { @@ -619,7 +619,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { box kind @ PatKind::Binding { .. } => { assert!(binder.is_none()); match self.try_pattern_to_var(&kind, true) { - Ok((binder, _)) => match kind { + Ok(binder) => match kind { PatKind::Binding { subpattern: Some(subpattern), .. @@ -716,10 +716,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { let lhs = self.mirror(lhs); let lhs = self.strip_scope(lhs); match lhs.kind { - ExprKind::VarRef { id } => self - .factory() - .Assignment(self.fetch_var(id).0, value) - .into(), + ExprKind::VarRef { id } => self.factory().Assignment(self.fetch_var(id), value).into(), ExprKind::Field { lhs, name } => { let lhs = self.mirror(lhs); @@ -827,14 +824,14 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { &format!("Cannot extract complex pattern in let: {}", reason), pattern.span, ), - Ok((vd, mutable)) => { + Ok(vd) => { // recurse the extract all the following statements let exprs = acc_exprs.clone(); acc_exprs.clear(); let body_expr = self.extract_block_(stmts, acc_exprs, acc_specs, final_expr); // wrap that body expression into the Let let init = self.extract_expr_ref(init); - let last_expr = if mutable { + let last_expr = if vd.is_mutable() { f.LetVar(vd, init, body_expr).into() } else { f.Let(vd, init, body_expr).into() @@ -902,7 +899,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { &self, pat_kind: &PatKind<'tcx>, allow_subpattern: bool, - ) -> Result<(&'l st::ValDef<'l>, bool)> { + ) -> Result<&'l st::ValDef<'l>> { match pat_kind { PatKind::Binding { subpattern: Some(_), @@ -921,9 +918,9 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { var: hir_id, .. } => { - let (var, mutable) = self.fetch_var(*hir_id); - if *mutability == Mutability::Not || mutable { - Ok((self.factory().ValDef(var), mutable)) + let var = self.fetch_var(*hir_id); + if *mutability == Mutability::Not || var.is_mutable() { + Ok(self.factory().ValDef(var)) } else { Err("Binding mode not allowed") } diff --git a/stainless_extraction/src/lib.rs b/stainless_extraction/src/lib.rs index 74be9849..bec8f0ac 100644 --- a/stainless_extraction/src/lib.rs +++ b/stainless_extraction/src/lib.rs @@ -1,5 +1,6 @@ #![feature(rustc_private)] #![feature(box_patterns)] +#![feature(bool_to_option)] #[macro_use] extern crate lazy_static; @@ -372,7 +373,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { self.base.factory() } - fn fetch_var(&self, hir_id: HirId) -> (&'l st::Variable<'l>, bool) { + fn fetch_var(&self, hir_id: HirId) -> &'l st::Variable<'l> { let span: Span = self.tcx().hir().span(hir_id); self .dcx diff --git a/stainless_extraction/src/spec.rs b/stainless_extraction/src/spec.rs index 52937504..3f135683 100644 --- a/stainless_extraction/src/spec.rs +++ b/stainless_extraction/src/spec.rs @@ -106,13 +106,13 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> { .expect("No return parameter on post spec function"); // Preregister the ret binding with the return_id - bxtor.dcx.add_var(return_param_id, return_var, false); + bxtor.dcx.add_var(return_param_id, return_var); } // Register all other bindings assert_eq!(outer_fn_params.len(), spec_param_ids.len()); for (vd, sid) in outer_fn_params.iter().zip(spec_param_ids) { - bxtor.dcx.add_var(sid, vd.v, false); + bxtor.dcx.add_var(sid, vd.v); } // Pick up any additional local bindings // (A spec neither has flags on the params, nor additional evidence params)