Skip to content

Commit

Permalink
Preliminary refactoring (#161)
Browse files Browse the repository at this point in the history
* Refactor map extraction.

* Struct for fn signature data.

* Smaller refactoring.

* Update stainless version.
  • Loading branch information
yannbolliger authored Jun 8, 2021
1 parent d6447be commit 8688427
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 123 deletions.
2 changes: 1 addition & 1 deletion .stainless-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
noxt-0.8.0-5-gf899dfd
noxt-0.8.0-8-gf15f5b5
67 changes: 29 additions & 38 deletions stainless_extraction/src/bindings.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::flags::Flags;
use super::*;

use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor};
use rustc_hir::{self as hir, HirId, Node, Pat, PatKind};
use rustc_middle::ty;
use rustc_span::symbol::Symbol;
Expand Down Expand Up @@ -46,50 +47,40 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
let xtor = &mut self.base;

// Extract ident from corresponding HIR node, sanity-check binding mode
let (id, span, mutable) = {
let node = xtor.tcx.hir().find(hir_id).unwrap();

let (ident, mutable) = if let Node::Binding(Pat {
kind: PatKind::Binding(_, _, ident, _),
hir_id,
span,
..
}) = node
{
match self
let node = xtor.tcx.hir().find(hir_id).unwrap();
let (ident, binding_mode, span) = if let Node::Binding(Pat {
kind: PatKind::Binding(_, _, ident, _),
hir_id,
span,
..
}) = node
{
(
ident,
self
.tables
.extract_binding_mode(xtor.tcx.sess, *hir_id, *span)
{
// allowed binding modes
Some(ty::BindByValue(hir::Mutability::Not))
| Some(ty::BindByReference(hir::Mutability::Not)) => (ident, false),
Some(ty::BindByValue(hir::Mutability::Mut)) => (ident, true),

// For the forbidden binding modes, return the identifier anyway
// because failure will occur later.
_ => {
xtor.unsupported(*span, "Only immutable bindings are supported");
(ident, false)
}
}
} else {
xtor.unsupported(
node.ident().map(|ident| ident.span).unwrap_or_default(),
"Cannot extract complex pattern in binding (cannot recover from this)",
);
unreachable!()
};

(
xtor.register_hir(hir_id, ident.name.to_string()),
ident.span,
mutable,
.expect("Cannot extract binding without binding mode."),
span,
)
} else {
xtor.unsupported(
node.ident().map(|ident| ident.span).unwrap_or_default(),
"Cannot extract complex pattern in binding (cannot recover from this)",
);
unreachable!()
};

if let ty::BindByReference(Mutability::Mut) = binding_mode {
xtor.unsupported(*span, "Only immutable bindings are supported");
}

let id = xtor.register_hir(hir_id, ident.name.to_string());
let mutable = matches!(binding_mode, ty::BindByValue(Mutability::Mut));

// Build a Variable node
let f = xtor.factory();
let tpe = xtor.extract_ty(self.tables.node_type(hir_id), &self.txtcx, span);
let tpe = xtor.extract_ty(self.tables.node_type(hir_id), &self.txtcx, ident.span);
let flags = flags_opt
.map(|flags| flags.to_stainless(f))
.into_iter()
Expand Down Expand Up @@ -206,7 +197,7 @@ impl<'bxtor, 'a, 'l, 'tcx> BindingsCollector<'bxtor, 'a, 'l, 'tcx> {
}
}

impl<'bxtor, 'a, 'l, 'tcx> Visitor<'tcx> for BindingsCollector<'bxtor, 'a, 'l, 'tcx> {
impl<'tcx> Visitor<'tcx> for BindingsCollector<'_, '_, '_, 'tcx> {
type Map = rustc_middle::hir::map::Map<'tcx>;

fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
Expand Down
121 changes: 59 additions & 62 deletions stainless_extraction/src/expr/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,28 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
&mut self,
item: CrateItem,
args: &'a [Expr<'a, 'tcx>],
substs_ref: SubstsRef<'tcx>,
substs: SubstsRef<'tcx>,
span: Span,
) -> st::Expr<'l> {
let (key_tpe, val_tpe) = match &self.base.extract_tys(substs.types(), &self.txtcx, span)[..] {
[key_tpe, val_tpe] => (*key_tpe, *val_tpe),
_ => {
return self.unsupported_expr(
span,
format!("Cannot extract {:?} with {} types.", item, substs.len()),
)
}
};

match (item, &self.extract_exprs(args)[..]) {
(MapNewFn, []) => self.extract_map_creation(substs_ref, span),
(MapNewFn, []) => self.extract_map_creation(key_tpe, val_tpe),

(MapIndexFn, [map, key]) => self.extract_map_apply(*map, *key, substs_ref, span),
(MapContainsKeyFn, [map, key]) => self.extract_map_contains(*map, *key, substs_ref, span),
(MapRemoveFn, [map, key]) => self.extract_map_removed(*map, *key, substs_ref, span),
(MapIndexFn, [map, key]) => self.extract_map_apply(*map, *key, val_tpe),
(MapContainsKeyFn, [map, key]) => self.extract_map_contains(*map, *key, val_tpe),
(MapRemoveFn, [map, key]) => self.extract_map_remove(*map, *key, val_tpe),
(MapGetFn, [map, key]) => self.factory().MapApply(*map, *key).into(),
(MapInsertFn, [map, key, val]) => {
self.extract_map_updated(*map, *key, *val, substs_ref, span)
}
(MapGetOrFn, [map, key, or_else]) => {
self.extract_map_get_or_else(*map, *key, *or_else, substs_ref, span)
}
(MapInsertFn, [map, key, val]) => self.extract_map_insert(*map, *key, val_tpe, *val),
(MapGetOrFn, [map, key, or_else]) => self.extract_map_get_or(*map, *key, val_tpe, *or_else),

(op, _) => self.unsupported_expr(
span,
Expand All @@ -29,32 +35,25 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
}
}

fn extract_map_creation(&mut self, substs: SubstsRef<'tcx>, span: Span) -> st::Expr<'l> {
let f = self.factory();
let tps = self.base.extract_tys(substs.types(), &self.txtcx, span);

match &tps[..] {
[key_tpe, val_tpe] => f
.FiniteMap(
vec![],
self.synth().std_option_none(*val_tpe),
*key_tpe,
self.synth().std_option_type(*val_tpe),
)
.into(),
_ => unreachable!(),
}
fn extract_map_creation(&mut self, key_tpe: st::Type<'l>, val_tpe: st::Type<'l>) -> st::Expr<'l> {
self
.factory()
.FiniteMap(
vec![],
self.synth().std_option_none(val_tpe),
key_tpe,
self.synth().std_option_type(val_tpe),
)
.into()
}

fn extract_map_apply(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
let some_tpe = self.synth().std_option_some_type(val_tpe);
f.Assert(
f.IsInstanceOf(f.MapApply(map, key).into(), some_tpe).into(),
Expand All @@ -70,11 +69,9 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
f.Not(
f.Equals(
f.MapApply(map, key).into(),
Expand All @@ -85,52 +82,52 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
.into()
}

fn extract_map_removed(
fn extract_map_get_or(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
or_else: st::Expr<'l>,
) -> st::Expr<'l> {
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
self
.factory()
.MapUpdated(map, key, self.synth().std_option_none(val_tpe))
.into()
let f = self.factory();
let some_tpe = self.synth().std_option_some_type(val_tpe);
f.IfExpr(
f.IsInstanceOf(f.MapApply(map, key).into(), some_tpe).into(),
self
.synth()
.std_option_some_value(f.AsInstanceOf(f.MapApply(map, key).into(), some_tpe).into()),
or_else,
)
.into()
}

fn extract_map_updated(
fn extract_map_insert(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
val_tpe: st::Type<'l>,
val: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
f.MapUpdated(map, key, self.synth().std_option_some(val, val_tpe))
.into()
let update_value = self.synth().std_option_some(val, val_tpe);
self.extract_map_update(map, key, update_value)
}

fn extract_map_get_or_else(
fn extract_map_remove(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
or_else: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
let some_tpe = self.synth().std_option_some_type(val_tpe);
f.IfExpr(
f.IsInstanceOf(f.MapApply(map, key).into(), some_tpe).into(),
self
.synth()
.std_option_some_value(f.AsInstanceOf(f.MapApply(map, key).into(), some_tpe).into()),
or_else,
)
.into()
let update_value = self.synth().std_option_none(val_tpe);
self.extract_map_update(map, key, update_value)
}

fn extract_map_update(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
update_value: st::Expr<'l>,
) -> st::Expr<'l> {
self.factory().MapUpdated(map, key, update_value).into()
}
}
7 changes: 7 additions & 0 deletions stainless_extraction/src/fns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ impl<'a> FnItem<'a> {
}
}

pub struct FnSignature<'l> {
pub id: &'l st::SymbolIdentifier<'l>,
pub tparams: Vec<&'l st::TypeParameterDef<'l>>,
pub params: Params<'l>,
pub return_tpe: st::Type<'l>,
}

/// Identifies the specific implementation/instance of a type class that is
/// needed at a method call site.
#[derive(Debug, Eq, PartialEq)]
Expand Down
46 changes: 25 additions & 21 deletions stainless_extraction/src/krate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use rustc_middle::ty::{AssocKind, List, TraitRef};
use rustc_span::DUMMY_SP;

use stainless_data::ast as st;
use stainless_data::ast::{SymbolIdentifier, TypeParameterDef};

use crate::fns::FnItem;
use crate::fns::{FnItem, FnSignature};
use std::iter;

/// Top-level extraction
Expand Down Expand Up @@ -260,15 +259,20 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
"Expected non-local def id, got: {:?}",
def_id
);
let (id, tparams, params, rtp) = self.extract_fn_signature(def_id);
let FnSignature {
id,
tparams,
params,
return_tpe,
} = self.extract_fn_signature(def_id);

let f = self.factory();
let empty_body = f.NoTree(rtp).into();
let empty_body = f.NoTree(return_tpe).into();
f.FunDef(
id,
tparams,
params,
rtp,
return_tpe,
empty_body,
vec![f.Extern().into()],
)
Expand All @@ -282,17 +286,22 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
})
}

pub fn extract_abstract_fn(&mut self, def_id: DefId) -> &'l st::FunDef<'l> {
let (id, tparams, params, rtp) = self.extract_fn_signature(def_id);
fn extract_abstract_fn(&mut self, def_id: DefId) -> &'l st::FunDef<'l> {
let FnSignature {
id,
tparams,
params,
return_tpe,
} = self.extract_fn_signature(def_id);
let class_def = self.get_class_of_method(id);

let f = self.factory();
let empty_body = f.NoTree(rtp).into();
let empty_body = f.NoTree(return_tpe).into();
f.FunDef(
id,
self.filter_class_tparams(tparams, class_def),
params,
rtp,
return_tpe,
empty_body,
class_def
.iter()
Expand All @@ -302,15 +311,7 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
)
}

fn extract_fn_signature(
&mut self,
def_id: DefId,
) -> (
&'l SymbolIdentifier<'l>,
Vec<&'l TypeParameterDef<'l>>,
Params<'l>,
st::Type<'l>,
) {
fn extract_fn_signature(&mut self, def_id: DefId) -> FnSignature<'l> {
let f = self.factory();

// Extract the function signature
Expand All @@ -328,10 +329,13 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
&*f.ValDef(var)
})
.collect();
let return_tpe = self.extract_ty(fn_sig.output(), &txtcx, DUMMY_SP);

let fun_id = self.get_or_extract_fn_ref(def_id);
(fun_id, tparams, params, return_tpe)
FnSignature {
id: self.get_or_extract_fn_ref(def_id),
tparams,
params,
return_tpe: self.extract_ty(fn_sig.output(), &txtcx, DUMMY_SP),
}
}

pub fn get_or_extract_local_fn(&mut self, fn_item: &FnItem<'l>) -> &'l st::FunDef<'l> {
Expand Down
2 changes: 1 addition & 1 deletion stainless_extraction/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use std::collections::{HashMap, HashSet};
use std::rc::Rc;

use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor};
use rustc_hir::{self as hir, HirId};
use rustc_middle::mir::Mutability;
use rustc_middle::span_bug;
use rustc_middle::ty::{TyCtxt, TypeckResults, WithOptConstParam};
use rustc_mir_build::thir;
Expand Down

0 comments on commit 8688427

Please sign in to comment.