diff --git a/source/vir/src/func_to_air.rs b/source/vir/src/func_to_air.rs index 35e09f9d79..300134b744 100644 --- a/source/vir/src/func_to_air.rs +++ b/source/vir/src/func_to_air.rs @@ -375,7 +375,7 @@ pub fn req_ens_to_air( /// if the function is a spec function. pub fn func_name_to_air( ctx: &Ctx, - diagnostics: &impl Diagnostics, + _diagnostics: &impl Diagnostics, function: &Function, ) -> Result { let mut commands: Vec = Vec::new(); @@ -400,15 +400,8 @@ pub fn func_name_to_air( commands.push(Arc::new(CommandX::Global(decl))); // Check whether we need to declare the recursive version too - if let Some(body) = &function.x.body { - let body_exp = crate::ast_to_sst::expr_to_exp_as_spec( - &ctx, - diagnostics, - &UpdateCell::new(HashMap::new()), - ¶ms_to_pars(&function.x.params, false), - &body, - )?; - if crate::recursion::is_recursive_exp(ctx, &function.x.name, &body_exp) { + if function.x.body.is_some() { + if crate::recursion::fun_is_recursive(ctx, &function.x.name) { let rec_f = suffix_global_id(&fun_to_air_ident(&prefix_recursive_fun(&function.x.name))); let mut rec_typs = diff --git a/source/vir/src/recursion.rs b/source/vir/src/recursion.rs index 32f8f5a0b8..a680192299 100644 --- a/source/vir/src/recursion.rs +++ b/source/vir/src/recursion.rs @@ -16,10 +16,7 @@ use crate::sst::{ UniqueIdent, }; use crate::sst_to_air::PostConditionKind; -use crate::sst_visitor::{ - exp_rename_vars, exp_visitor_check, exp_visitor_dfs, map_exp_visitor, map_stm_visitor, - stm_visitor_dfs, VisitorControlFlow, -}; +use crate::sst_visitor::{exp_rename_vars, exp_visitor_check, map_exp_visitor, map_stm_visitor}; use crate::util::vec_map_result; use air::ast::{Binder, Commands, Span}; use air::ast_util::{ident_binder, str_ident, str_typ}; @@ -53,15 +50,6 @@ fn get_callee(ctx: &Ctx, target: &Fun, resolved_method: &Option<(Fun, Typs)>) -> } } -fn is_self_call( - ctx: &Ctx, - target: &Fun, - resolved_method: &Option<(Fun, Typs)>, - name: &Fun, -) -> bool { - get_callee(ctx, target, resolved_method) == Some(name.clone()) -} - fn is_recursive_call(ctxt: &Ctxt, target: &Fun, resolved_method: &Option<(Fun, Typs)>) -> bool { if let Some(callee) = get_callee(ctxt.ctx, target, resolved_method) { callee == ctxt.recursive_function_name @@ -315,45 +303,8 @@ fn terminates( } } -pub(crate) fn is_recursive_exp(ctx: &Ctx, name: &Fun, body: &Exp) -> bool { - if ctx.global.func_call_graph.get_scc_size(&Node::Fun(name.clone())) > 1 { - // This function is part of a mutually recursive component - true - } else { - let mut scope_map = ScopeMap::new(); - // Check for self-recursion, which SCC computation does not account for - match exp_visitor_dfs(body, &mut scope_map, &mut |exp, _scope_map| match &exp.x { - ExpX::Call(CallFun::Fun(x, resolved_method), _, _) - if is_self_call(ctx, x, resolved_method, name) => - { - VisitorControlFlow::Stop(()) - } - _ => VisitorControlFlow::Recurse, - }) { - VisitorControlFlow::Stop(()) => true, - _ => false, - } - } -} - -pub(crate) fn is_recursive_stm(ctx: &Ctx, name: &Fun, body: &Stm) -> bool { - if ctx.global.func_call_graph.get_scc_size(&Node::Fun(name.clone())) > 1 { - // This function is part of a mutually recursive component - true - } else { - // Check for self-recursion, which SCC computation does not account for - match stm_visitor_dfs(body, &mut |stm| match &stm.x { - StmX::Call { fun, resolved_method, .. } - if is_self_call(ctx, fun, resolved_method, name) => - { - VisitorControlFlow::Stop(()) - } - _ => VisitorControlFlow::Recurse, - }) { - VisitorControlFlow::Stop(()) => true, - _ => false, - } - } +pub(crate) fn fun_is_recursive(ctx: &Ctx, name: &Fun) -> bool { + ctx.global.func_call_graph.node_is_in_cycle(&Node::Fun(name.clone())) } fn mk_decreases_at_entry(ctxt: &Ctxt, span: &Span, exps: &Vec) -> (Vec, Vec) { @@ -407,7 +358,7 @@ pub(crate) fn check_termination_exp( proof_body: Vec, uses_decreases_by: bool, ) -> Result<(bool, Commands, Exp), VirErr> { - if !is_recursive_exp(ctx, &function.x.name, body) { + if !fun_is_recursive(ctx, &function.x.name) { return Ok((false, Arc::new(vec![]), body.clone())); } let num_decreases = function.x.decrease.len(); @@ -484,7 +435,7 @@ pub(crate) fn check_termination_stm( function: &Function, body: &Stm, ) -> Result<(Vec, Stm), VirErr> { - if !is_recursive_stm(ctx, &function.x.name, body) { + if !fun_is_recursive(ctx, &function.x.name) { return Ok((vec![], body.clone())); } let num_decreases = function.x.decrease.len(); diff --git a/source/vir/src/scc.rs b/source/vir/src/scc.rs index 838f4161c0..5b2dadf959 100644 --- a/source/vir/src/scc.rs +++ b/source/vir/src/scc.rs @@ -180,6 +180,18 @@ impl Graph { sorted } + pub fn node_has_direct_edge_to_itself(&self, t: &T) -> bool { + assert!(self.has_run); + assert!(self.h.contains_key(&t)); + let v: NodeIndex = self.h[t]; + for edge in self.nodes[v].edges.iter() { + if *edge == v { + return true; + } + } + false + } + pub fn get_scc_size(&self, t: &T) -> usize { assert!(self.has_run); match self.mapping.get(&t) { @@ -188,6 +200,10 @@ impl Graph { } } + pub fn node_is_in_cycle(&self, t: &T) -> bool { + self.node_has_direct_edge_to_itself(t) || self.get_scc_size(t) > 1 + } + pub fn get_scc_rep(&self, t: &T) -> T { assert!(self.has_run); assert!(self.mapping.contains_key(&t));