Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preliminary refactoring #161

Merged
merged 4 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .stainless-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
noxt-0.8.0-5-gf899dfd
noxt-0.8.0-8-gf15f5b5
67 changes: 29 additions & 38 deletions stainless_extraction/src/bindings.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::flags::Flags;
use super::*;

use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor};
use rustc_hir::{self as hir, HirId, Node, Pat, PatKind};
use rustc_middle::ty;
use rustc_span::symbol::Symbol;
Expand Down Expand Up @@ -46,50 +47,40 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
let xtor = &mut self.base;

// Extract ident from corresponding HIR node, sanity-check binding mode
let (id, span, mutable) = {
let node = xtor.tcx.hir().find(hir_id).unwrap();

let (ident, mutable) = if let Node::Binding(Pat {
kind: PatKind::Binding(_, _, ident, _),
hir_id,
span,
..
}) = node
{
match self
let node = xtor.tcx.hir().find(hir_id).unwrap();
let (ident, binding_mode, span) = if let Node::Binding(Pat {
kind: PatKind::Binding(_, _, ident, _),
hir_id,
span,
..
}) = node
{
(
ident,
self
.tables
.extract_binding_mode(xtor.tcx.sess, *hir_id, *span)
{
// allowed binding modes
Some(ty::BindByValue(hir::Mutability::Not))
| Some(ty::BindByReference(hir::Mutability::Not)) => (ident, false),
Some(ty::BindByValue(hir::Mutability::Mut)) => (ident, true),

// For the forbidden binding modes, return the identifier anyway
// because failure will occur later.
_ => {
xtor.unsupported(*span, "Only immutable bindings are supported");
(ident, false)
}
}
} else {
xtor.unsupported(
node.ident().map(|ident| ident.span).unwrap_or_default(),
"Cannot extract complex pattern in binding (cannot recover from this)",
);
unreachable!()
};

(
xtor.register_hir(hir_id, ident.name.to_string()),
ident.span,
mutable,
.expect("Cannot extract binding without binding mode."),
span,
)
} else {
xtor.unsupported(
node.ident().map(|ident| ident.span).unwrap_or_default(),
"Cannot extract complex pattern in binding (cannot recover from this)",
);
unreachable!()
};

if let ty::BindByReference(Mutability::Mut) = binding_mode {
xtor.unsupported(*span, "Only immutable bindings are supported");
}

let id = xtor.register_hir(hir_id, ident.name.to_string());
let mutable = matches!(binding_mode, ty::BindByValue(Mutability::Mut));

// Build a Variable node
let f = xtor.factory();
let tpe = xtor.extract_ty(self.tables.node_type(hir_id), &self.txtcx, span);
let tpe = xtor.extract_ty(self.tables.node_type(hir_id), &self.txtcx, ident.span);
let flags = flags_opt
.map(|flags| flags.to_stainless(f))
.into_iter()
Expand Down Expand Up @@ -206,7 +197,7 @@ impl<'bxtor, 'a, 'l, 'tcx> BindingsCollector<'bxtor, 'a, 'l, 'tcx> {
}
}

impl<'bxtor, 'a, 'l, 'tcx> Visitor<'tcx> for BindingsCollector<'bxtor, 'a, 'l, 'tcx> {
impl<'tcx> Visitor<'tcx> for BindingsCollector<'_, '_, '_, 'tcx> {
type Map = rustc_middle::hir::map::Map<'tcx>;

fn nested_visit_map(&mut self) -> NestedVisitorMap<Self::Map> {
Expand Down
121 changes: 59 additions & 62 deletions stainless_extraction/src/expr/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,28 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
&mut self,
item: CrateItem,
args: &'a [Expr<'a, 'tcx>],
substs_ref: SubstsRef<'tcx>,
substs: SubstsRef<'tcx>,
span: Span,
) -> st::Expr<'l> {
let (key_tpe, val_tpe) = match &self.base.extract_tys(substs.types(), &self.txtcx, span)[..] {
[key_tpe, val_tpe] => (*key_tpe, *val_tpe),
_ => {
return self.unsupported_expr(
span,
format!("Cannot extract {:?} with {} types.", item, substs.len()),
)
}
};

match (item, &self.extract_exprs(args)[..]) {
(MapNewFn, []) => self.extract_map_creation(substs_ref, span),
(MapNewFn, []) => self.extract_map_creation(key_tpe, val_tpe),

(MapIndexFn, [map, key]) => self.extract_map_apply(*map, *key, substs_ref, span),
(MapContainsKeyFn, [map, key]) => self.extract_map_contains(*map, *key, substs_ref, span),
(MapRemoveFn, [map, key]) => self.extract_map_removed(*map, *key, substs_ref, span),
(MapIndexFn, [map, key]) => self.extract_map_apply(*map, *key, val_tpe),
(MapContainsKeyFn, [map, key]) => self.extract_map_contains(*map, *key, val_tpe),
(MapRemoveFn, [map, key]) => self.extract_map_remove(*map, *key, val_tpe),
(MapGetFn, [map, key]) => self.factory().MapApply(*map, *key).into(),
(MapInsertFn, [map, key, val]) => {
self.extract_map_updated(*map, *key, *val, substs_ref, span)
}
(MapGetOrFn, [map, key, or_else]) => {
self.extract_map_get_or_else(*map, *key, *or_else, substs_ref, span)
}
(MapInsertFn, [map, key, val]) => self.extract_map_insert(*map, *key, val_tpe, *val),
(MapGetOrFn, [map, key, or_else]) => self.extract_map_get_or(*map, *key, val_tpe, *or_else),

(op, _) => self.unsupported_expr(
span,
Expand All @@ -29,32 +35,25 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
}
}

fn extract_map_creation(&mut self, substs: SubstsRef<'tcx>, span: Span) -> st::Expr<'l> {
let f = self.factory();
let tps = self.base.extract_tys(substs.types(), &self.txtcx, span);

match &tps[..] {
[key_tpe, val_tpe] => f
.FiniteMap(
vec![],
self.synth().std_option_none(*val_tpe),
*key_tpe,
self.synth().std_option_type(*val_tpe),
)
.into(),
_ => unreachable!(),
}
fn extract_map_creation(&mut self, key_tpe: st::Type<'l>, val_tpe: st::Type<'l>) -> st::Expr<'l> {
self
.factory()
.FiniteMap(
vec![],
self.synth().std_option_none(val_tpe),
key_tpe,
self.synth().std_option_type(val_tpe),
)
.into()
}

fn extract_map_apply(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
let some_tpe = self.synth().std_option_some_type(val_tpe);
f.Assert(
f.IsInstanceOf(f.MapApply(map, key).into(), some_tpe).into(),
Expand All @@ -70,11 +69,9 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
f.Not(
f.Equals(
f.MapApply(map, key).into(),
Expand All @@ -85,52 +82,52 @@ impl<'a, 'l, 'tcx> BodyExtractor<'a, 'l, 'tcx> {
.into()
}

fn extract_map_removed(
fn extract_map_get_or(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
or_else: st::Expr<'l>,
) -> st::Expr<'l> {
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
self
.factory()
.MapUpdated(map, key, self.synth().std_option_none(val_tpe))
.into()
let f = self.factory();
let some_tpe = self.synth().std_option_some_type(val_tpe);
f.IfExpr(
f.IsInstanceOf(f.MapApply(map, key).into(), some_tpe).into(),
self
.synth()
.std_option_some_value(f.AsInstanceOf(f.MapApply(map, key).into(), some_tpe).into()),
or_else,
)
.into()
}

fn extract_map_updated(
fn extract_map_insert(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
val_tpe: st::Type<'l>,
val: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
f.MapUpdated(map, key, self.synth().std_option_some(val, val_tpe))
.into()
let update_value = self.synth().std_option_some(val, val_tpe);
self.extract_map_update(map, key, update_value)
}

fn extract_map_get_or_else(
fn extract_map_remove(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
or_else: st::Expr<'l>,
substs: SubstsRef<'tcx>,
span: Span,
val_tpe: st::Type<'l>,
) -> st::Expr<'l> {
let f = self.factory();
let val_tpe = self.base.extract_ty(substs.type_at(1), &self.txtcx, span);
let some_tpe = self.synth().std_option_some_type(val_tpe);
f.IfExpr(
f.IsInstanceOf(f.MapApply(map, key).into(), some_tpe).into(),
self
.synth()
.std_option_some_value(f.AsInstanceOf(f.MapApply(map, key).into(), some_tpe).into()),
or_else,
)
.into()
let update_value = self.synth().std_option_none(val_tpe);
self.extract_map_update(map, key, update_value)
}

fn extract_map_update(
&mut self,
map: st::Expr<'l>,
key: st::Expr<'l>,
update_value: st::Expr<'l>,
) -> st::Expr<'l> {
self.factory().MapUpdated(map, key, update_value).into()
}
}
7 changes: 7 additions & 0 deletions stainless_extraction/src/fns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ impl<'a> FnItem<'a> {
}
}

pub struct FnSignature<'l> {
pub id: &'l st::SymbolIdentifier<'l>,
pub tparams: Vec<&'l st::TypeParameterDef<'l>>,
pub params: Params<'l>,
pub return_tpe: st::Type<'l>,
}

/// Identifies the specific implementation/instance of a type class that is
/// needed at a method call site.
#[derive(Debug, Eq, PartialEq)]
Expand Down
46 changes: 25 additions & 21 deletions stainless_extraction/src/krate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use rustc_middle::ty::{AssocKind, List, TraitRef};
use rustc_span::DUMMY_SP;

use stainless_data::ast as st;
use stainless_data::ast::{SymbolIdentifier, TypeParameterDef};

use crate::fns::FnItem;
use crate::fns::{FnItem, FnSignature};
use std::iter;

/// Top-level extraction
Expand Down Expand Up @@ -260,15 +259,20 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
"Expected non-local def id, got: {:?}",
def_id
);
let (id, tparams, params, rtp) = self.extract_fn_signature(def_id);
let FnSignature {
id,
tparams,
params,
return_tpe,
} = self.extract_fn_signature(def_id);

let f = self.factory();
let empty_body = f.NoTree(rtp).into();
let empty_body = f.NoTree(return_tpe).into();
f.FunDef(
id,
tparams,
params,
rtp,
return_tpe,
empty_body,
vec![f.Extern().into()],
)
Expand All @@ -282,17 +286,22 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
})
}

pub fn extract_abstract_fn(&mut self, def_id: DefId) -> &'l st::FunDef<'l> {
let (id, tparams, params, rtp) = self.extract_fn_signature(def_id);
fn extract_abstract_fn(&mut self, def_id: DefId) -> &'l st::FunDef<'l> {
let FnSignature {
id,
tparams,
params,
return_tpe,
} = self.extract_fn_signature(def_id);
let class_def = self.get_class_of_method(id);

let f = self.factory();
let empty_body = f.NoTree(rtp).into();
let empty_body = f.NoTree(return_tpe).into();
f.FunDef(
id,
self.filter_class_tparams(tparams, class_def),
params,
rtp,
return_tpe,
empty_body,
class_def
.iter()
Expand All @@ -302,15 +311,7 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
)
}

fn extract_fn_signature(
&mut self,
def_id: DefId,
) -> (
&'l SymbolIdentifier<'l>,
Vec<&'l TypeParameterDef<'l>>,
Params<'l>,
st::Type<'l>,
) {
fn extract_fn_signature(&mut self, def_id: DefId) -> FnSignature<'l> {
let f = self.factory();

// Extract the function signature
Expand All @@ -328,10 +329,13 @@ impl<'l, 'tcx> BaseExtractor<'l, 'tcx> {
&*f.ValDef(var)
})
.collect();
let return_tpe = self.extract_ty(fn_sig.output(), &txtcx, DUMMY_SP);

let fun_id = self.get_or_extract_fn_ref(def_id);
(fun_id, tparams, params, return_tpe)
FnSignature {
id: self.get_or_extract_fn_ref(def_id),
tparams,
params,
return_tpe: self.extract_ty(fn_sig.output(), &txtcx, DUMMY_SP),
}
}

pub fn get_or_extract_local_fn(&mut self, fn_item: &FnItem<'l>) -> &'l st::FunDef<'l> {
Expand Down
2 changes: 1 addition & 1 deletion stainless_extraction/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use std::collections::{HashMap, HashSet};
use std::rc::Rc;

use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::intravisit::{self, NestedVisitorMap, Visitor};
use rustc_hir::{self as hir, HirId};
use rustc_middle::mir::Mutability;
use rustc_middle::span_bug;
use rustc_middle::ty::{TyCtxt, TypeckResults, WithOptConstParam};
use rustc_mir_build::thir;
Expand Down