diff --git a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs index ff98880ac4d72..986cd00e32a50 100644 --- a/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs +++ b/kani-compiler/src/codegen_cprover_gotoc/compiler_interface.rs @@ -7,14 +7,14 @@ use crate::args::ReachabilityType; use crate::codegen_cprover_gotoc::GotocCtx; use crate::kani_middle::analysis; use crate::kani_middle::attributes::{is_test_harness_description, KaniAttributes}; +use crate::kani_middle::check_reachable_items; use crate::kani_middle::codegen_units::{CodegenUnit, CodegenUnits}; use crate::kani_middle::metadata::gen_test_metadata; use crate::kani_middle::provide; use crate::kani_middle::reachability::{ collect_reachable_items, filter_const_crate_items, filter_crate_items, }; -use crate::kani_middle::transform::BodyTransformation; -use crate::kani_middle::{check_reachable_items, dump_mir_items}; +use crate::kani_middle::transform::{BodyTransformation, GlobalPasses}; use crate::kani_queries::QueryDb; use cbmc::goto_program::Location; use cbmc::irep::goto_binary_serde::write_goto_binary_file; @@ -87,11 +87,33 @@ impl GotocCodegenBackend { check_contract: Option, mut transformer: BodyTransformation, ) -> (GotocCtx<'tcx>, Vec, Option) { - let items = with_timer( + let (items, call_graph) = with_timer( || collect_reachable_items(tcx, &mut transformer, starting_items), "codegen reachability analysis", ); - dump_mir_items(tcx, &mut transformer, &items, &symtab_goto.with_extension("kani.mir")); + + // Retrieve all instances from the currently codegened items. + let instances = items + .iter() + .filter_map(|item| match item { + MonoItem::Fn(instance) => Some(*instance), + MonoItem::Static(static_def) => { + let instance: Instance = (*static_def).into(); + instance.has_body().then_some(instance) + } + MonoItem::GlobalAsm(_) => None, + }) + .collect(); + + // Apply all transformation passes, including global passes. + let mut global_passes = GlobalPasses::new(&self.queries.lock().unwrap(), tcx); + global_passes.run_global_passes( + &mut transformer, + tcx, + starting_items, + instances, + call_graph, + ); // Follow rustc naming convention (cx is abbrev for context). // https://rustc-dev-guide.rust-lang.org/conventions.html#naming-conventions diff --git a/kani-compiler/src/kani_middle/mod.rs b/kani-compiler/src/kani_middle/mod.rs index 17b08b687e30c..3c300e9da52cf 100644 --- a/kani-compiler/src/kani_middle/mod.rs +++ b/kani-compiler/src/kani_middle/mod.rs @@ -4,9 +4,7 @@ //! and transformations. use std::collections::HashSet; -use std::path::Path; -use crate::kani_middle::transform::BodyTransformation; use crate::kani_queries::QueryDb; use rustc_hir::{def::DefKind, def_id::LOCAL_CRATE}; use rustc_middle::span_bug; @@ -15,18 +13,14 @@ use rustc_middle::ty::layout::{ LayoutOfHelpers, TyAndLayout, }; use rustc_middle::ty::{self, Instance as InstanceInternal, Ty as TyInternal, TyCtxt}; -use rustc_session::config::OutputType; use rustc_smir::rustc_internal; use rustc_span::source_map::respan; use rustc_span::Span; use rustc_target::abi::call::FnAbi; use rustc_target::abi::{HasDataLayout, TargetDataLayout}; -use stable_mir::mir::mono::{Instance, MonoItem}; +use stable_mir::mir::mono::MonoItem; use stable_mir::ty::{FnDef, RigidTy, Span as SpanStable, TyKind}; use stable_mir::CrateDef; -use std::fs::File; -use std::io::BufWriter; -use std::io::Write; use self::attributes::KaniAttributes; @@ -92,41 +86,6 @@ pub fn check_reachable_items(tcx: TyCtxt, queries: &QueryDb, items: &[MonoItem]) tcx.dcx().abort_if_errors(); } -/// Print MIR for the reachable items if the `--emit mir` option was provided to rustc. -pub fn dump_mir_items( - tcx: TyCtxt, - transformer: &mut BodyTransformation, - items: &[MonoItem], - output: &Path, -) { - /// Convert MonoItem into a DefId. - /// Skip stuff that we cannot generate the MIR items. - fn get_instance(item: &MonoItem) -> Option { - match item { - // Exclude FnShims and others that cannot be dumped. - MonoItem::Fn(instance) => Some(*instance), - MonoItem::Static(def) => { - let instance: Instance = (*def).into(); - instance.has_body().then_some(instance) - } - MonoItem::GlobalAsm(_) => None, - } - } - - if tcx.sess.opts.output_types.contains_key(&OutputType::Mir) { - // Create output buffer. - let out_file = File::create(output).unwrap(); - let mut writer = BufWriter::new(out_file); - - // For each def_id, dump their MIR - for instance in items.iter().filter_map(get_instance) { - writeln!(writer, "// Item: {} ({})", instance.name(), instance.mangled_name()).unwrap(); - let body = transformer.body(tcx, instance); - let _ = body.dump(&mut writer, &instance.name()); - } - } -} - /// Structure that represents the source location of a definition. /// TODO: Use `InternedString` once we move it out of the cprover_bindings. /// diff --git a/kani-compiler/src/kani_middle/reachability.rs b/kani-compiler/src/kani_middle/reachability.rs index 279dcf8cc1077..d2c9d50515c4f 100644 --- a/kani-compiler/src/kani_middle/reachability.rs +++ b/kani-compiler/src/kani_middle/reachability.rs @@ -22,6 +22,7 @@ use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::FxHashSet; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; use rustc_middle::ty::{TyCtxt, VtblEntry}; +use rustc_session::config::OutputType; use rustc_smir::rustc_internal; use stable_mir::mir::alloc::{AllocId, GlobalAlloc}; use stable_mir::mir::mono::{Instance, InstanceKind, MonoItem, StaticDef}; @@ -32,6 +33,12 @@ use stable_mir::mir::{ use stable_mir::ty::{Allocation, ClosureKind, ConstantKind, RigidTy, Ty, TyKind}; use stable_mir::CrateItem; use stable_mir::{CrateDef, ItemKind}; +use std::fmt::{Display, Formatter}; +use std::{ + collections::{HashMap, HashSet}, + fs::File, + io::{BufWriter, Write}, +}; use crate::kani_middle::coercion; use crate::kani_middle::coercion::CoercionBase; @@ -42,7 +49,7 @@ pub fn collect_reachable_items( tcx: TyCtxt, transformer: &mut BodyTransformation, starting_points: &[MonoItem], -) -> Vec { +) -> (Vec, CallGraph) { // For each harness, collect items using the same collector. // I.e.: This will return any item that is reachable from one or more of the starting points. let mut collector = MonoItemsCollector::new(tcx, transformer); @@ -62,7 +69,7 @@ pub fn collect_reachable_items( // order of the errors and warnings is stable. let mut sorted_items: Vec<_> = collector.collected.into_iter().collect(); sorted_items.sort_by_cached_key(|item| to_fingerprint(tcx, item)); - sorted_items + (sorted_items, collector.call_graph) } /// Collect all (top-level) items in the crate that matches the given predicate. @@ -118,7 +125,24 @@ where } } } - roots + roots.into_iter().map(|root| root.item).collect() +} + +/// Reason for introducing an edge in the call graph. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +enum CollectionReason { + DirectCall, + IndirectCall, + VTableMethod, + Static, + StaticDrop, +} + +/// A destination of the edge in the call graph. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct CollectedItem { + item: MonoItem, + reason: CollectionReason, } struct MonoItemsCollector<'tcx, 'a> { @@ -130,8 +154,8 @@ struct MonoItemsCollector<'tcx, 'a> { collected: FxHashSet, /// Items enqueued for visiting. queue: Vec, - #[cfg(debug_assertions)] - call_graph: debug::CallGraph, + /// Call graph used for dataflow analysis. + call_graph: CallGraph, } impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { @@ -140,8 +164,7 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { tcx, collected: FxHashSet::default(), queue: vec![], - #[cfg(debug_assertions)] - call_graph: debug::CallGraph::default(), + call_graph: CallGraph::default(), transformer, } } @@ -167,17 +190,17 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { vec![] } }; - #[cfg(debug_assertions)] self.call_graph.add_edges(to_visit, &next_items); - self.queue - .extend(next_items.into_iter().filter(|item| !self.collected.contains(item))); + self.queue.extend(next_items.into_iter().filter_map( + |CollectedItem { item, .. }| (!self.collected.contains(&item)).then_some(item), + )); } } } /// Visit a function and collect all mono-items reachable from its instructions. - fn visit_fn(&mut self, instance: Instance) -> Vec { + fn visit_fn(&mut self, instance: Instance) -> Vec { let _guard = debug_span!("visit_fn", function=?instance).entered(); let body = self.transformer.body(self.tcx, instance); let mut collector = @@ -187,19 +210,24 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { } /// Visit a static object and collect drop / initialization functions. - fn visit_static(&mut self, def: StaticDef) -> Vec { + fn visit_static(&mut self, def: StaticDef) -> Vec { let _guard = debug_span!("visit_static", ?def).entered(); let mut next_items = vec![]; // Collect drop function. let static_ty = def.ty(); let instance = Instance::resolve_drop_in_place(static_ty); - next_items.push(instance.into()); + next_items + .push(CollectedItem { item: instance.into(), reason: CollectionReason::StaticDrop }); // Collect initialization. let alloc = def.eval_initializer().unwrap(); for (_, prov) in alloc.provenance.ptrs { - next_items.extend(collect_alloc_items(prov.0).into_iter()); + next_items.extend( + collect_alloc_items(prov.0) + .into_iter() + .map(|item| CollectedItem { item, reason: CollectionReason::Static }), + ); } next_items @@ -213,7 +241,7 @@ impl<'tcx, 'a> MonoItemsCollector<'tcx, 'a> { struct MonoItemsFnCollector<'a, 'tcx> { tcx: TyCtxt<'tcx>, - collected: FxHashSet, + collected: FxHashSet, body: &'a Body, } @@ -251,7 +279,9 @@ impl<'a, 'tcx> MonoItemsFnCollector<'a, 'tcx> { } }); trace!(methods=?methods.clone().collect::>(), "collect_vtable_methods"); - self.collected.extend(methods); + self.collected.extend( + methods.map(|item| CollectedItem { item, reason: CollectionReason::VTableMethod }), + ); } // Add the destructor for the concrete type. @@ -282,7 +312,12 @@ impl<'a, 'tcx> MonoItemsFnCollector<'a, 'tcx> { }; if should_collect && should_codegen_locally(&instance) { trace!(?instance, "collect_instance"); - self.collected.insert(instance.into()); + let reason = if is_direct_call { + CollectionReason::DirectCall + } else { + CollectionReason::IndirectCall + }; + self.collected.insert(CollectedItem { item: instance.into(), reason }); } } @@ -290,7 +325,11 @@ impl<'a, 'tcx> MonoItemsFnCollector<'a, 'tcx> { fn collect_allocation(&mut self, alloc: &Allocation) { debug!(?alloc, "collect_allocation"); for (_, id) in &alloc.provenance.ptrs { - self.collected.extend(collect_alloc_items(id.0).into_iter()) + self.collected.extend( + collect_alloc_items(id.0) + .into_iter() + .map(|item| CollectedItem { item, reason: CollectionReason::Static }), + ) } } } @@ -366,7 +405,8 @@ impl<'a, 'tcx> MirVisitor for MonoItemsFnCollector<'a, 'tcx> { } Rvalue::ThreadLocalRef(item) => { trace!(?item, "visit_rvalue thread_local"); - self.collected.insert(MonoItem::Static(StaticDef::try_from(item).unwrap())); + let item = MonoItem::Static(StaticDef::try_from(item).unwrap()); + self.collected.insert(CollectedItem { item, reason: CollectionReason::Static }); } _ => { /* not interesting */ } } @@ -485,128 +525,150 @@ fn collect_alloc_items(alloc_id: AllocId) -> Vec { items } -#[cfg(debug_assertions)] -#[allow(dead_code)] -mod debug { - - use std::fmt::{Display, Formatter}; - use std::{ - collections::{HashMap, HashSet}, - fs::File, - io::{BufWriter, Write}, - }; - - use rustc_session::config::OutputType; - - use super::*; +/// Call graph with edges annotated with the reason why they were added to the graph. +#[derive(Debug, Default)] +pub struct CallGraph { + /// Nodes of the graph. + nodes: HashSet, + /// Edges of the graph. + edges: HashMap>, + /// Since the graph is directed, we also store back edges. + back_edges: HashMap>, +} - #[derive(Debug, Default)] - pub struct CallGraph { - // Nodes of the graph. - nodes: HashSet, - edges: HashMap>, - back_edges: HashMap>, +/// Newtype around MonoItem. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct Node(pub MonoItem); + +/// Newtype around CollectedItem. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct CollectedNode(pub CollectedItem); + +impl CallGraph { + /// Add a new node into a graph. + fn add_node(&mut self, item: MonoItem) { + let node = Node(item); + self.nodes.insert(node.clone()); + self.edges.entry(node.clone()).or_default(); + self.back_edges.entry(node).or_default(); } - #[derive(Clone, Debug, Eq, PartialEq, Hash)] - struct Node(pub MonoItem); - - impl CallGraph { - pub fn add_node(&mut self, item: MonoItem) { - let node = Node(item); - self.nodes.insert(node.clone()); - self.edges.entry(node.clone()).or_default(); - self.back_edges.entry(node).or_default(); - } + /// Add a new edge "from" -> "to". + fn add_edge(&mut self, from: MonoItem, to: MonoItem, collection_reason: CollectionReason) { + let from_node = Node(from.clone()); + let to_node = Node(to.clone()); + self.add_node(from.clone()); + self.add_node(to.clone()); + self.edges + .get_mut(&from_node) + .unwrap() + .push(CollectedNode(CollectedItem { item: to, reason: collection_reason })); + self.back_edges + .get_mut(&to_node) + .unwrap() + .push(CollectedNode(CollectedItem { item: from, reason: collection_reason })); + } - /// Add a new edge "from" -> "to". - pub fn add_edge(&mut self, from: MonoItem, to: MonoItem) { - let from_node = Node(from.clone()); - let to_node = Node(to.clone()); - self.add_node(from); - self.add_node(to); - self.edges.get_mut(&from_node).unwrap().push(to_node.clone()); - self.back_edges.get_mut(&to_node).unwrap().push(from_node); + /// Add multiple new edges for the "from" node. + fn add_edges(&mut self, from: MonoItem, to: &[CollectedItem]) { + self.add_node(from.clone()); + for CollectedItem { item, reason } in to { + self.add_edge(from.clone(), item.clone(), *reason); } + } - /// Add multiple new edges for the "from" node. - pub fn add_edges(&mut self, from: MonoItem, to: &[MonoItem]) { - self.add_node(from.clone()); - for item in to { - self.add_edge(from.clone(), item.clone()); + /// Print the graph in DOT format to a file. + /// See for more information. + fn dump_dot(&self, tcx: TyCtxt) -> std::io::Result<()> { + if let Ok(target) = std::env::var("KANI_REACH_DEBUG") { + debug!(?target, "dump_dot"); + let outputs = tcx.output_filenames(()); + let base_path = outputs.path(OutputType::Metadata); + let path = base_path.as_path().with_extension("dot"); + let out_file = File::create(path)?; + let mut writer = BufWriter::new(out_file); + writeln!(writer, "digraph ReachabilityGraph {{")?; + if target.is_empty() { + self.dump_all(&mut writer)?; + } else { + // Only dump nodes that led the reachability analysis to the target node. + self.dump_reason(&mut writer, &target)?; } + writeln!(writer, "}}")?; } - /// Print the graph in DOT format to a file. - /// See for more information. - pub fn dump_dot(&self, tcx: TyCtxt) -> std::io::Result<()> { - if let Ok(target) = std::env::var("KANI_REACH_DEBUG") { - debug!(?target, "dump_dot"); - let outputs = tcx.output_filenames(()); - let base_path = outputs.path(OutputType::Metadata); - let path = base_path.as_path().with_extension("dot"); - let out_file = File::create(path)?; - let mut writer = BufWriter::new(out_file); - writeln!(writer, "digraph ReachabilityGraph {{")?; - if target.is_empty() { - self.dump_all(&mut writer)?; - } else { - // Only dump nodes that led the reachability analysis to the target node. - self.dump_reason(&mut writer, &target)?; - } - writeln!(writer, "}}")?; - } + Ok(()) + } - Ok(()) + /// Write all notes to the given writer. + fn dump_all(&self, writer: &mut W) -> std::io::Result<()> { + tracing::info!(nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_all"); + for node in &self.nodes { + writeln!(writer, r#""{node}""#)?; + for succ in self.edges.get(node).unwrap() { + let reason = succ.0.reason; + writeln!(writer, r#""{node}" -> "{succ}" [label={reason:?}] "#)?; + } } + Ok(()) + } - /// Write all notes to the given writer. - fn dump_all(&self, writer: &mut W) -> std::io::Result<()> { - tracing::info!(nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_all"); - for node in &self.nodes { - writeln!(writer, r#""{node}""#)?; - for succ in self.edges.get(node).unwrap() { - writeln!(writer, r#""{node}" -> "{succ}" "#)?; - } + /// Write all notes that may have led to the discovery of the given target. + fn dump_reason(&self, writer: &mut W, target: &str) -> std::io::Result<()> { + let mut queue: Vec = + self.nodes.iter().filter(|item| item.to_string().contains(target)).cloned().collect(); + let mut visited: HashSet = HashSet::default(); + tracing::info!(target=?queue, nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_reason"); + while let Some(to_visit) = queue.pop() { + if !visited.contains(&to_visit) { + visited.insert(to_visit.clone()); + queue.extend( + self.back_edges + .get(&to_visit) + .unwrap() + .iter() + .map(|item| Node::from(item.clone())), + ); } - Ok(()) } - /// Write all notes that may have led to the discovery of the given target. - fn dump_reason(&self, writer: &mut W, target: &str) -> std::io::Result<()> { - let mut queue = self - .nodes - .iter() - .filter(|item| item.to_string().contains(target)) - .collect::>(); - let mut visited: HashSet<&Node> = HashSet::default(); - tracing::info!(target=?queue, nodes=?self.nodes.len(), edges=?self.edges.len(), "dump_reason"); - while let Some(to_visit) = queue.pop() { - if !visited.contains(to_visit) { - visited.insert(to_visit); - queue.extend(self.back_edges.get(to_visit).unwrap()); - } + for node in &visited { + writeln!(writer, r#""{node}""#)?; + let edges = self.edges.get(node).unwrap(); + for succ in edges.iter().filter(|item| { + let node = Node::from((*item).clone()); + visited.contains(&node) + }) { + let reason = succ.0.reason; + writeln!(writer, r#""{node}" -> "{succ}" [label={reason:?}] "#)?; } + } + Ok(()) + } +} - for node in &visited { - writeln!(writer, r#""{node}""#)?; - for succ in - self.edges.get(node).unwrap().iter().filter(|item| visited.contains(item)) - { - writeln!(writer, r#""{node}" -> "{succ}" "#)?; - } - } - Ok(()) +impl Display for Node { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self.0 { + MonoItem::Fn(instance) => write!(f, "{}", instance.name()), + MonoItem::Static(def) => write!(f, "{}", def.name()), + MonoItem::GlobalAsm(asm) => write!(f, "{asm:?}"), } } +} - impl Display for Node { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.0 { - MonoItem::Fn(instance) => write!(f, "{}", instance.name()), - MonoItem::Static(def) => write!(f, "{}", def.name()), - MonoItem::GlobalAsm(asm) => write!(f, "{asm:?}"), - } +impl Display for CollectedNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self.0.item { + MonoItem::Fn(instance) => write!(f, "{}", instance.name()), + MonoItem::Static(def) => write!(f, "{}", def.name()), + MonoItem::GlobalAsm(asm) => write!(f, "{asm:?}"), } } } + +impl From for Node { + fn from(value: CollectedNode) -> Self { + Node(value.0.item) + } +} diff --git a/kani-compiler/src/kani_middle/transform/dump_mir_pass.rs b/kani-compiler/src/kani_middle/transform/dump_mir_pass.rs new file mode 100644 index 0000000000000..9393ec0d88c9a --- /dev/null +++ b/kani-compiler/src/kani_middle/transform/dump_mir_pass.rs @@ -0,0 +1,69 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Global transformation pass, which does not modify bodies but dumps MIR whenever the appropriate debug flag is passed. + +use crate::kani_middle::reachability::CallGraph; +use crate::kani_middle::transform::GlobalPass; +use crate::kani_queries::QueryDb; +use kani_metadata::ArtifactType; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::OutputType; +use stable_mir::mir::mono::{Instance, MonoItem}; +use std::fs::File; +use std::io::BufWriter; +use std::io::Write; + +use super::BodyTransformation; + +/// Dump all MIR bodies. +#[derive(Debug)] +pub struct DumpMirPass { + enabled: bool, +} + +impl DumpMirPass { + pub fn new(tcx: TyCtxt) -> Self { + Self { enabled: tcx.sess.opts.output_types.contains_key(&OutputType::Mir) } + } +} + +impl GlobalPass for DumpMirPass { + fn is_enabled(&self, _query_db: &QueryDb) -> bool { + self.enabled + } + + fn transform( + &mut self, + tcx: TyCtxt, + _call_graph: &CallGraph, + starting_items: &[MonoItem], + instances: Vec, + transformer: &mut BodyTransformation, + ) { + // Create output buffer. + let file_path = { + let base_path = tcx.output_filenames(()).path(OutputType::Object); + let base_name = base_path.as_path(); + let entry_point = (starting_items.len() == 1).then_some(starting_items[0].clone()); + // If there is a single entry point, use it as a file name. + if let Some(MonoItem::Fn(starting_instance)) = entry_point { + let mangled_name = starting_instance.mangled_name(); + let file_stem = + format!("{}_{mangled_name}", base_name.file_stem().unwrap().to_str().unwrap()); + base_name.with_file_name(file_stem).with_extension(ArtifactType::SymTabGoto) + } else { + // Otherwise, use the object output path from the compiler. + base_name.with_extension(ArtifactType::SymTabGoto) + } + }; + let out_file = File::create(file_path.with_extension("kani.mir")).unwrap(); + let mut writer = BufWriter::new(out_file); + + // For each def_id, dump their MIR. + for instance in instances.iter() { + writeln!(writer, "// Item: {} ({})", instance.name(), instance.mangled_name()).unwrap(); + let _ = transformer.body(tcx, *instance).dump(&mut writer, &instance.name()); + } + } +} diff --git a/kani-compiler/src/kani_middle/transform/mod.rs b/kani-compiler/src/kani_middle/transform/mod.rs index 26a95978fcaff..5b497b09619de 100644 --- a/kani-compiler/src/kani_middle/transform/mod.rs +++ b/kani-compiler/src/kani_middle/transform/mod.rs @@ -17,6 +17,7 @@ //! For all instrumentation passes, always use exhaustive matches to ensure soundness in case a new //! case is added. use crate::kani_middle::codegen_units::CodegenUnit; +use crate::kani_middle::reachability::CallGraph; use crate::kani_middle::transform::body::CheckType; use crate::kani_middle::transform::check_uninit::UninitPass; use crate::kani_middle::transform::check_values::ValidValuePass; @@ -24,8 +25,9 @@ use crate::kani_middle::transform::contracts::AnyModifiesPass; use crate::kani_middle::transform::kani_intrinsics::IntrinsicGeneratorPass; use crate::kani_middle::transform::stubs::{ExternFnStubPass, FnStubPass}; use crate::kani_queries::QueryDb; +use dump_mir_pass::DumpMirPass; use rustc_middle::ty::TyCtxt; -use stable_mir::mir::mono::Instance; +use stable_mir::mir::mono::{Instance, MonoItem}; use stable_mir::mir::Body; use std::collections::HashMap; use std::fmt::Debug; @@ -34,6 +36,7 @@ pub(crate) mod body; mod check_uninit; mod check_values; mod contracts; +mod dump_mir_pass; mod kani_intrinsics; mod stubs; @@ -90,7 +93,8 @@ impl BodyTransformation { transformer } - /// Retrieve the body of an instance. + /// Retrieve the body of an instance. This does not apply global passes, but will retrieve the + /// body after global passes running if they were previously applied. /// /// Note that this assumes that the instance does have a body since existing consumers already /// assume that. Use `instance.has_body()` to check if an instance has a body. @@ -152,6 +156,23 @@ pub(crate) trait TransformPass: Debug { fn transform(&mut self, tcx: TyCtxt, body: Body, instance: Instance) -> (bool, Body); } +/// A trait to represent transformation passes that operate on the whole codegen unit. +pub(crate) trait GlobalPass: Debug { + fn is_enabled(&self, query_db: &QueryDb) -> bool + where + Self: Sized; + + /// Run a transformation pass on the whole codegen unit. + fn transform( + &mut self, + tcx: TyCtxt, + call_graph: &CallGraph, + starting_items: &[MonoItem], + instances: Vec, + transformer: &mut BodyTransformation, + ); +} + /// The transformation result. /// We currently only cache the body of functions that were instrumented. #[derive(Clone, Debug)] @@ -159,3 +180,37 @@ enum TransformationResult { Modified(Body), NotModified, } + +pub struct GlobalPasses { + /// The passes that operate on the whole codegen unit, they run after all previous passes are + /// done. + global_passes: Vec>, +} + +impl GlobalPasses { + pub fn new(queries: &QueryDb, tcx: TyCtxt) -> Self { + let mut global_passes = GlobalPasses { global_passes: vec![] }; + global_passes.add_global_pass(queries, DumpMirPass::new(tcx)); + global_passes + } + + fn add_global_pass(&mut self, query_db: &QueryDb, pass: P) { + if pass.is_enabled(&query_db) { + self.global_passes.push(Box::new(pass)) + } + } + + /// Run all global passes and store the results in a cache that can later be queried by `body`. + pub fn run_global_passes( + &mut self, + transformer: &mut BodyTransformation, + tcx: TyCtxt, + starting_items: &[MonoItem], + instances: Vec, + call_graph: CallGraph, + ) { + for global_pass in self.global_passes.iter_mut() { + global_pass.transform(tcx, &call_graph, starting_items, instances.clone(), transformer); + } + } +}