Skip to content

Commit

Permalink
Auto merge of rust-lang#136428 - EnzymeAD:enable-autodiff, r=oli-obk
Browse files Browse the repository at this point in the history
test building enzyme in CI

1) This PR fixes a significant compile-time regression, by only running the expensive autodiff pipeline, if the users pass the newly introduced Enable value to the `-Zautodiff=` flag. It updates the test(s) accordingly. It gives a nice error if users forget that.
2) It fixes macos support by explicitly linking against the Enzyme build folder. This doesn't cover CI macos yet.
3) It fixes the issue that setting ENZYME_RUNPASS was ignored by enzyme and in fact did not schedule enzyme's opt pass.
4) It also re-enables support for various other values for the autodiff flag, which were ignored since the refactor.
5) I merged some improvements to Enzyme core, which means we do not longer depend on LLVM being build with the Plugin Interface enabled.
6) Unrelated to other fixes, this changes `rustc_autodiff` to `EncodeCrossCrate::Yes`. It is not enough on it's own to enable usage of Enzyme in libraries, but it is for sure a piece of the fixes needed to get this to work.

try-job: x86_64-gnu

r? `@oli-obk`

Tracking:

- rust-lang#124509
  • Loading branch information
bors committed Feb 22, 2025
2 parents b6d3be4 + 49e9630 commit 8dac72b
Show file tree
Hide file tree
Showing 20 changed files with 237 additions and 104 deletions.
1 change: 0 additions & 1 deletion compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use crate::{Ty, TyKind};
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
/// as it's already done in the C++ and Julia frontend of Enzyme.
///
/// (FIXME) remove *First variants.
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ mod llvm_enzyme {
defaultness: ast::Defaultness::Final,
sig: d_sig,
generics: Generics::default(),
contract: None,
body: Some(d_body),
});
let mut rustc_ad_attr =
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/messages.ftl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}
Expand Down
85 changes: 62 additions & 23 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,42 @@ fn thin_lto(
}
}

fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
for &val in ad {
match val {
config::AutoDiff::PrintModBefore => {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
}
config::AutoDiff::PrintAA => {
llvm::set_print_activity(true);
}
config::AutoDiff::PrintTA => {
llvm::set_print_type(true);
}
config::AutoDiff::Inline => {
llvm::set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(false);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
}
// We handle this below
config::AutoDiff::PrintModAfter => {}
// This is required and already checked
config::AutoDiff::Enable => {}
}
}
// This helps with handling enums for now.
llvm::set_strict_aliasing(false);
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
llvm::set_rust_rules(true);
}

pub(crate) fn run_pass_manager(
cgcx: &CodegenContext<LlvmCodegenBackend>,
dcx: DiagCtxtHandle<'_>,
Expand All @@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);

// If this rustc version was build with enzyme/autodiff enabled, and if users applied the
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
debug!("running llvm pm opt pipeline");
// The PostAD behavior is the same that we would have if no autodiff was used.
// It will run the default optimization pipeline. If AD is enabled we select
// the DuringAD stage, which will disable vectorization and loop unrolling, and
// schedule two autodiff optimization + differentiation passes.
// We then run the llvm_optimize function a second time, to optimize the code which we generated
// in the enzyme differentiation pass.
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
let stage =
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };

if enable_ad {
enable_autodiff_settings(&config.autodiff, module);
}

unsafe {
write::llvm_optimize(
cgcx,
dcx,
module,
config,
opt_level,
opt_stage,
write::AutodiffStage::DuringAD,
)?;
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
}
// FIXME(ZuseZ4): Make this more granular
if cfg!(llvm_enzyme) && !thin {

if cfg!(llvm_enzyme) && enable_ad {
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
unsafe {
write::llvm_optimize(
cgcx,
dcx,
module,
config,
opt_level,
llvm::OptStage::FatLTO,
write::AutodiffStage::PostAD,
)?;
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
}

// This is the final IR, so people should be able to inspect the optimized autodiff output.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
}

debug!("lto done");
Ok(())
}
Expand Down
15 changes: 5 additions & 10 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,19 +564,16 @@ pub(crate) unsafe fn llvm_optimize(
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
// differentiated.

let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
let unroll_loops;
let vectorize_slp;
let vectorize_loop;
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;

// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
// We therefore have two calls to llvm_optimize, if autodiff is used.
//
// FIXME(ZuseZ4): Before shipping on nightly,
// we should make this more granular, or at least check that the user has at least one autodiff
// call in their code, to justify altering the compilation pipeline.
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
unroll_loops = false;
vectorize_slp = false;
vectorize_loop = false;
Expand Down Expand Up @@ -706,10 +703,8 @@ pub(crate) unsafe fn optimize(

// If we know that we will later run AD, then we disable vectorization and loop unrolling.
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
// usages, not just if we build rustc with autodiff support.
let autodiff_stage =
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
let autodiff_stage = if consider_ad { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
return unsafe {
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
};
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::back::write::llvm_err;
use crate::builder::SBuilder;
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::LlvmError;
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::value::Value;
Expand Down Expand Up @@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
let output = attrs.ret_activity;

// We have to pick the name depending on whether we want forward or reverse mode autodiff.
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
// it will handle higher-order derivatives correctly automatically (in theory). Currently
// higher-order derivatives fail, so we should debug that before adjusting this code.
let mut ad_name: String = match attrs.mode {
DiffMode::Forward => "__enzyme_fwddiff",
DiffMode::Reverse => "__enzyme_autodiff",
Expand Down Expand Up @@ -291,6 +288,14 @@ pub(crate) fn differentiate<'ll>(
let diag_handler = cgcx.create_dcx();
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };

// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
if !diff_items.is_empty()
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
{
let dcx = cgcx.create_dcx();
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
}

// Before dumping the module, we want all the TypeTrees to become part of the module.
for item in diff_items.iter() {
let name = item.source.clone();
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_codegen_llvm/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,12 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {

#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_without_lto)]
#[note]
pub(crate) struct AutoDiffWithoutLTO;

#[derive(Diagnostic)]
#[diag(codegen_llvm_autodiff_without_enable)]
pub(crate) struct AutoDiffWithoutEnable;

#[derive(Diagnostic)]
#[diag(codegen_llvm_lto_disallowed)]
pub(crate) struct LtoDisallowed;
Expand Down
94 changes: 94 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,97 @@ pub enum LLVMRustVerifierFailureAction {
LLVMPrintMessageAction = 1,
LLVMReturnStatusAction = 2,
}

#[cfg(llvm_enzyme)]
pub use self::Enzyme_AD::*;

#[cfg(llvm_enzyme)]
pub mod Enzyme_AD {
use libc::c_void;
extern "C" {
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
}
extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
static mut EnzymePrintType: c_void;
static mut EnzymePrint: c_void;
static mut EnzymeStrictAliasing: c_void;
static mut looseTypeAnalysis: c_void;
static mut EnzymeInline: c_void;
static mut RustTypeRules: c_void;
}
pub fn set_print_perf(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
}
}
pub fn set_print_activity(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
}
}
pub fn set_print_type(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
}
}
pub fn set_print(print: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
}
}
pub fn set_strict_aliasing(strict: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
}
}
pub fn set_loose_types(loose: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
}
}
pub fn set_inline(val: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
}
}
pub fn set_rust_rules(val: bool) {
unsafe {
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
}
}
}

#[cfg(not(llvm_enzyme))]
pub use self::Fallback_AD::*;

#[cfg(not(llvm_enzyme))]
pub mod Fallback_AD {
#![allow(unused_variables)]

pub fn set_inline(val: bool) {
unimplemented!()
}
pub fn set_print_perf(print: bool) {
unimplemented!()
}
pub fn set_print_activity(print: bool) {
unimplemented!()
}
pub fn set_print_type(print: bool) {
unimplemented!()
}
pub fn set_print(print: bool) {
unimplemented!()
}
pub fn set_strict_aliasing(strict: bool) {
unimplemented!()
}
pub fn set_loose_types(loose: bool) {
unimplemented!()
}
pub fn set_rust_rules(val: bool) {
unimplemented!()
}
}
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ fn generate_lto_work<B: ExtraBackendMethods>(
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
if cgcx.lto == Lto::Fat && !autodiff.is_empty() {
let config = cgcx.config(ModuleKind::Regular);
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
module =
unsafe { module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise()) };
}
// We are adding a single work item, so the cost doesn't matter.
vec![(WorkItem::LTO(module), 0)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_feature/src/builtin_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ pub static BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
rustc_attr!(
rustc_autodiff, Normal,
template!(Word, List: r#""...""#), DuplicatesOk,
EncodeCrossCrate::No, INTERNAL_UNSTABLE
EncodeCrossCrate::Yes, INTERNAL_UNSTABLE
),

// ==========================================================================
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_interface/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ fn test_unstable_options_tracking_hash() {
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(assume_incomplete_release, true);
tracked!(autodiff, vec![AutoDiff::Print]);
tracked!(autodiff, vec![AutoDiff::Enable]);
tracked!(binary_dep_depinfo, true);
tracked!(box_noalias, false);
tracked!(
Expand Down
9 changes: 6 additions & 3 deletions compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,12 @@ struct LLVMRustSanitizerOptions {
bool SanitizeKernelAddressRecover;
};

// This symbol won't be available or used when Enzyme is not enabled
// This symbol won't be available or used when Enzyme is not enabled.
// Always set AugmentPassBuilder to true, since it registers optimizations which
// will improve the performance for Enzyme.
#ifdef ENZYME
extern "C" void registerEnzyme(llvm::PassBuilder &PB);
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
/* augmentPassBuilder */ bool);
#endif

extern "C" LLVMRustResult LLVMRustOptimize(
Expand Down Expand Up @@ -1023,7 +1026,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
// now load "-enzyme" pass:
#ifdef ENZYME
if (RunEnzyme) {
registerEnzyme(PB);
registerEnzymeAndPassPipeline(PB, true);
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub(crate) fn find_autodiff_source_functions<'tcx>(
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
for (item, instance) in autodiff_mono_items {
let target_id = instance.def_id();
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
let Some(target_attrs) = cg_fn_attr else {
continue;
};
Expand Down
Loading

0 comments on commit 8dac72b

Please sign in to comment.