Skip to content

Commit

Permalink
Simplify by using IsVar flag to mark mutability.
Browse files Browse the repository at this point in the history
  • Loading branch information
yannbolliger committed Mar 4, 2021
1 parent 8c46cbe commit dcdefba
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 36 deletions.
15 changes: 15 additions & 0 deletions stainless_data/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,21 @@ impl<'l, 'a> From<&'l ValDef<'a>> for &'l Variable<'a> {
}
}

impl Variable<'_> {
pub fn is_mutable(&self) -> bool {
self.flags.iter().any(|f| match f {
Flag::IsVar(_) => true,
_ => false,
})
}
}

impl ValDef<'_> {
pub fn is_mutable(&self) -> bool {
self.v.is_mutable()
}
}

// Additional helpers that mirror those in Inox

pub fn Int32Literal(value: Int) -> BVLiteral {
Expand Down
38 changes: 17 additions & 21 deletions stainless_extraction/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
.pat
.simple_ident()
.and_then(|ident| flags_by_symbol.remove(&ident.name));
let (var, _) = self.extract_binding(param.pat.hir_id, flags);
let var = self.extract_binding(param.pat.hir_id, flags);
self.dcx.add_param(self.factory().ValDef(var));
}

Expand All @@ -35,11 +35,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {

/// Extract a binding based on the binding node's HIR id.
/// Updates `dcx` if the binding hadn't been extracted before.
fn extract_binding(
&mut self,
hir_id: HirId,
flags: Option<Flags>,
) -> (&'l st::Variable<'l>, bool) {
fn extract_binding(&mut self, hir_id: HirId, flags_opt: Option<Flags>) -> &'l st::Variable<'l> {
self.dcx.get_var(hir_id).unwrap_or_else(|| {
let xtor = &mut self.base;

Expand Down Expand Up @@ -85,21 +81,26 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
};

// Build a Variable node
let f = xtor.factory();
let tpe = xtor.extract_ty(self.tables.node_type(hir_id), &self.txtcx, span);
let flags = flags
.map(|flags| flags.to_stainless(xtor.factory()))
.unwrap_or_default();
let var = xtor.factory().Variable(id, tpe, flags);
self.dcx.add_var(hir_id, var, mutable);
(var, mutable)
let flags = flags_opt
.map(|flags| flags.to_stainless(f))
.into_iter()
.flatten()
// Add @var flag if the param is mutable
.chain(mutable.then(|| f.IsVar().into()))
.collect();
let var = f.Variable(id, tpe, flags);
self.dcx.add_var(hir_id, var);
var
})
}
}

/// DefContext tracks available bindings
#[derive(Clone, Debug)]
pub(super) struct DefContext<'l> {
vars: HashMap<HirId, (&'l st::Variable<'l>, bool)>,
vars: HashMap<HirId, &'l st::Variable<'l>>,
params: Vec<&'l st::ValDef<'l>>,
}

Expand All @@ -115,14 +116,9 @@ impl<'l> DefContext<'l> {
&self.params[..]
}

pub(super) fn add_var(
&mut self,
hir_id: HirId,
var: &'l st::Variable<'l>,
mutable: bool,
) -> &mut Self {
pub(super) fn add_var(&mut self, hir_id: HirId, var: &'l st::Variable<'l>) -> &mut Self {
assert!(!self.vars.contains_key(&hir_id));
self.vars.insert(hir_id, (var, mutable));
self.vars.insert(hir_id, var);
self
}

Expand All @@ -133,7 +129,7 @@ impl<'l> DefContext<'l> {
}

#[inline]
pub(super) fn get_var(&self, hir_id: HirId) -> Option<(&'l st::Variable<'l>, bool)> {
pub(super) fn get_var(&self, hir_id: HirId) -> Option<&'l st::Variable<'l>> {
self.vars.get(&hir_id).copied()
}
}
Expand Down
21 changes: 9 additions & 12 deletions stainless_extraction/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
ExprKind::LogicalOp { .. } => self.extract_logical_op(expr),
ExprKind::Tuple { .. } => self.extract_tuple(expr),
ExprKind::Field { lhs, name } => self.extract_field(lhs, name),
ExprKind::VarRef { id } => self.fetch_var(id).0.into(),
ExprKind::VarRef { id } => self.fetch_var(id).into(),
ExprKind::Call { ty, ref args, .. } => self.extract_call_like(ty, args, expr.span),
ExprKind::Adt { .. } => self.extract_adt_construction(expr),
ExprKind::Block { body: ast_block } => {
Expand Down Expand Up @@ -619,7 +619,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
box kind @ PatKind::Binding { .. } => {
assert!(binder.is_none());
match self.try_pattern_to_var(&kind, true) {
Ok((binder, _)) => match kind {
Ok(binder) => match kind {
PatKind::Binding {
subpattern: Some(subpattern),
..
Expand Down Expand Up @@ -716,10 +716,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
let lhs = self.mirror(lhs);
let lhs = self.strip_scope(lhs);
match lhs.kind {
ExprKind::VarRef { id } => self
.factory()
.Assignment(self.fetch_var(id).0, value)
.into(),
ExprKind::VarRef { id } => self.factory().Assignment(self.fetch_var(id), value).into(),

ExprKind::Field { lhs, name } => {
let lhs = self.mirror(lhs);
Expand Down Expand Up @@ -827,14 +824,14 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
&format!("Cannot extract complex pattern in let: {}", reason),
pattern.span,
),
Ok((vd, mutable)) => {
Ok(vd) => {
// recurse the extract all the following statements
let exprs = acc_exprs.clone();
acc_exprs.clear();
let body_expr = self.extract_block_(stmts, acc_exprs, acc_specs, final_expr);
// wrap that body expression into the Let
let init = self.extract_expr_ref(init);
let last_expr = if mutable {
let last_expr = if vd.is_mutable() {
f.LetVar(vd, init, body_expr).into()
} else {
f.Let(vd, init, body_expr).into()
Expand Down Expand Up @@ -902,7 +899,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
&self,
pat_kind: &PatKind<'tcx>,
allow_subpattern: bool,
) -> Result<(&'l st::ValDef<'l>, bool)> {
) -> Result<&'l st::ValDef<'l>> {
match pat_kind {
PatKind::Binding {
subpattern: Some(_),
Expand All @@ -921,9 +918,9 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
var: hir_id,
..
} => {
let (var, mutable) = self.fetch_var(*hir_id);
if *mutability == Mutability::Not || mutable {
Ok((self.factory().ValDef(var), mutable))
let var = self.fetch_var(*hir_id);
if *mutability == Mutability::Not || var.is_mutable() {
Ok(self.factory().ValDef(var))
} else {
Err("Binding mode not allowed")
}
Expand Down
3 changes: 2 additions & 1 deletion stainless_extraction/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(rustc_private)]
#![feature(box_patterns)]
#![feature(bool_to_option)]

#[macro_use]
extern crate lazy_static;
Expand Down Expand Up @@ -372,7 +373,7 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
self.base.factory()
}

fn fetch_var(&self, hir_id: HirId) -> (&'l st::Variable<'l>, bool) {
fn fetch_var(&self, hir_id: HirId) -> &'l st::Variable<'l> {
let span: Span = self.tcx().hir().span(hir_id);
self
.dcx
Expand Down
4 changes: 2 additions & 2 deletions stainless_extraction/src/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
.expect("No return parameter on post spec function");

// Preregister the ret binding with the return_id
bxtor.dcx.add_var(return_param_id, return_var, false);
bxtor.dcx.add_var(return_param_id, return_var);
}

// Register all other bindings
assert_eq!(outer_fn_params.len(), spec_param_ids.len());
for (vd, sid) in outer_fn_params.iter().zip(spec_param_ids) {
bxtor.dcx.add_var(sid, vd.v, false);
bxtor.dcx.add_var(sid, vd.v);
}
// Pick up any additional local bindings
// (A spec neither has flags on the params, nor additional evidence params)
Expand Down

0 comments on commit dcdefba

Please sign in to comment.