From 3e4e28aa4021ebd3dc1adc9d2a899d8fcbfe7f11 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 30 Oct 2023 01:28:50 -0400 Subject: [PATCH] Cleanup Rust-Enzyme history Co-authored-by: Lorenz Schmidt git@lorenzschmidt.com Co-authored-by: William Moses gh@wsmoses.com --- .github/workflows/enzyme-ci.yml | 38 ++ .gitmodules | 3 + Cargo.lock | 1 + Cargo.toml | 1 + README.md | 63 ++- compiler/rustc_ast/src/mut_visit.rs | 2 +- compiler/rustc_codegen_llvm/src/attributes.rs | 3 + compiler/rustc_codegen_llvm/src/back/lto.rs | 3 +- compiler/rustc_codegen_llvm/src/back/write.rs | 306 ++++++++++- compiler/rustc_codegen_llvm/src/base.rs | 12 +- compiler/rustc_codegen_llvm/src/context.rs | 4 + compiler/rustc_codegen_llvm/src/lib.rs | 44 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 484 +++++++++++++++++- compiler/rustc_codegen_llvm/src/typetree.rs | 33 ++ .../src/assert_module_sources.rs | 2 +- compiler/rustc_codegen_ssa/src/back/lto.rs | 25 +- .../src/back/symbol_export.rs | 2 +- compiler/rustc_codegen_ssa/src/back/write.rs | 47 +- compiler/rustc_codegen_ssa/src/base.rs | 9 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 162 +++++- compiler/rustc_codegen_ssa/src/traits/misc.rs | 1 + .../rustc_codegen_ssa/src/traits/write.rs | 12 + compiler/rustc_feature/src/builtin_attrs.rs | 7 + compiler/rustc_interface/src/tests.rs | 1 + .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 5 + compiler/rustc_middle/src/arena.rs | 1 + .../rustc_middle/src/middle/autodiff_attrs.rs | 94 ++++ compiler/rustc_middle/src/middle/mod.rs | 2 + compiler/rustc_middle/src/middle/typetree.rs | 39 ++ compiler/rustc_middle/src/query/erase.rs | 4 + compiler/rustc_middle/src/query/mod.rs | 10 +- compiler/rustc_middle/src/ty/mod.rs | 2 + compiler/rustc_monomorphize/Cargo.toml | 1 + compiler/rustc_monomorphize/src/collector.rs | 1 + .../rustc_monomorphize/src/partitioning.rs | 4 +- compiler/rustc_passes/src/check_attr.rs | 15 + compiler/rustc_resolve/src/lib.rs | 1 + compiler/rustc_session/src/options.rs | 2 + compiler/rustc_span/src/symbol.rs | 2 + config.example.toml | 3 + library/autodiff/Cargo.lock | 314 ++++++++++++ library/autodiff/Cargo.toml | 28 + library/autodiff/examples/array.rs | 23 + library/autodiff/examples/box.rs | 24 + library/autodiff/examples/broken_matvec.rs | 34 ++ library/autodiff/examples/hessian_sin.rs | 28 + library/autodiff/examples/ndarray.rs | 25 + library/autodiff/examples/rosenbrock_fwd.rs | 34 ++ .../autodiff/examples/rosenbrock_fwd_iter.rs | 34 ++ library/autodiff/examples/rosenbrock_rev.rs | 33 ++ library/autodiff/examples/sin.rs | 36 ++ library/autodiff/examples/sqrt.rs | 21 + library/autodiff/examples/struct.rs | 33 ++ library/autodiff/examples/vec.rs | 24 + library/autodiff/examples_broken/biquad.rs | 54 ++ .../autodiff/examples_broken/broken_iter.rs | 20 + .../examples_broken/broken_recursive.rs | 66 +++ .../examples_broken/broken_second_order.rs | 17 + library/autodiff/src/gen.rs | 217 ++++++++ library/autodiff/src/lib.rs | 31 ++ library/autodiff/src/parser.rs | 464 +++++++++++++++++ .../expand/forward_duplicated.expanded.rs | 10 + .../tests/expand/forward_duplicated.rs | 6 + .../forward_duplicated_return.expanded.rs | 15 + .../tests/expand/forward_duplicated_return.rs | 6 + .../expand/reverse_duplicated.expanded.rs | 10 + .../tests/expand/reverse_duplicated.rs | 6 + .../expand/reverse_return_array.expanded.rs | 10 + .../tests/expand/reverse_return_array.rs | 6 + .../expand/reverse_return_mixed.expanded.rs | 17 + .../tests/expand/reverse_return_mixed.rs | 6 + .../tests/ui/active_in_forward_mode.rs | 6 + .../tests/ui/active_in_forward_mode.stderr | 7 + .../tests/ui/activities_inline_and_header.rs | 6 + .../ui/activities_inline_and_header.stderr | 7 + .../autodiff/tests/ui/invalid_indirection.rs | 19 + .../tests/ui/invalid_indirection.stderr | 31 ++ .../tests/ui/invalid_mutability_pairs.rs | 24 + .../tests/ui/invalid_mutability_pairs.stderr | 55 ++ library/autodiff/tests/ui/invalid_return.rs | 12 + .../autodiff/tests/ui/invalid_return.stderr | 23 + .../autodiff/tests/ui/invalid_return_type.rs | 16 + .../tests/ui/invalid_return_type.stderr | 31 ++ library/autodiff/tests/ui/no_function_name.rs | 6 + .../autodiff/tests/ui/no_function_name.stderr | 8 + library/autodiff/tests/ui/not_a_function.rs | 6 + .../autodiff/tests/ui/not_a_function.stderr | 7 + library/autodiff/tests/ui/reverse_tangent.rs | 12 + .../autodiff/tests/ui/reverse_tangent.stderr | 23 + library/autodiff/tests/ui/wrong_mode.rs | 6 + library/autodiff/tests/ui/wrong_mode.stderr | 7 + library/core/src/macros/mod.rs | 12 + src/bootstrap/configure.py | 1 + src/bootstrap/src/core/build_steps/compile.rs | 19 + src/bootstrap/src/core/build_steps/llvm.rs | 66 +++ src/bootstrap/src/core/builder.rs | 5 + src/bootstrap/src/core/config/config.rs | 25 +- src/bootstrap/src/lib.rs | 4 + src/test/ui/terminal-width/flag-human.rs | 9 + src/test/ui/terminal-width/flag-json.rs | 9 + src/test/ui/terminal-width/flag-json.stderr | 40 ++ src/tools/enzyme | 1 + tests/rustdoc-ui/doctest/terminal-width.rs | 5 + .../rustdoc-ui/doctest/terminal-width.stderr | 15 + tests/ui/json/autodiff.rs | 16 + 105 files changed, 3600 insertions(+), 42 deletions(-) create mode 100644 .github/workflows/enzyme-ci.yml create mode 100644 compiler/rustc_codegen_llvm/src/typetree.rs create mode 100644 compiler/rustc_middle/src/middle/autodiff_attrs.rs create mode 100644 compiler/rustc_middle/src/middle/typetree.rs create mode 100644 library/autodiff/Cargo.lock create mode 100644 library/autodiff/Cargo.toml create mode 100644 library/autodiff/examples/array.rs create mode 100644 library/autodiff/examples/box.rs create mode 100644 library/autodiff/examples/broken_matvec.rs create mode 100644 library/autodiff/examples/hessian_sin.rs create mode 100644 library/autodiff/examples/ndarray.rs create mode 100644 library/autodiff/examples/rosenbrock_fwd.rs create mode 100644 library/autodiff/examples/rosenbrock_fwd_iter.rs create mode 100644 library/autodiff/examples/rosenbrock_rev.rs create mode 100644 library/autodiff/examples/sin.rs create mode 100644 library/autodiff/examples/sqrt.rs create mode 100644 library/autodiff/examples/struct.rs create mode 100644 library/autodiff/examples/vec.rs create mode 100644 library/autodiff/examples_broken/biquad.rs create mode 100644 library/autodiff/examples_broken/broken_iter.rs create mode 100644 library/autodiff/examples_broken/broken_recursive.rs create mode 100644 library/autodiff/examples_broken/broken_second_order.rs create mode 100644 library/autodiff/src/gen.rs create mode 100644 library/autodiff/src/lib.rs create mode 100644 library/autodiff/src/parser.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated.expanded.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated_return.expanded.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated_return.rs create mode 100644 library/autodiff/tests/expand/reverse_duplicated.expanded.rs create mode 100644 library/autodiff/tests/expand/reverse_duplicated.rs create mode 100644 library/autodiff/tests/expand/reverse_return_array.expanded.rs create mode 100644 library/autodiff/tests/expand/reverse_return_array.rs create mode 100644 library/autodiff/tests/expand/reverse_return_mixed.expanded.rs create mode 100644 library/autodiff/tests/expand/reverse_return_mixed.rs create mode 100644 library/autodiff/tests/ui/active_in_forward_mode.rs create mode 100644 library/autodiff/tests/ui/active_in_forward_mode.stderr create mode 100644 library/autodiff/tests/ui/activities_inline_and_header.rs create mode 100644 library/autodiff/tests/ui/activities_inline_and_header.stderr create mode 100644 library/autodiff/tests/ui/invalid_indirection.rs create mode 100644 library/autodiff/tests/ui/invalid_indirection.stderr create mode 100644 library/autodiff/tests/ui/invalid_mutability_pairs.rs create mode 100644 library/autodiff/tests/ui/invalid_mutability_pairs.stderr create mode 100644 library/autodiff/tests/ui/invalid_return.rs create mode 100644 library/autodiff/tests/ui/invalid_return.stderr create mode 100644 library/autodiff/tests/ui/invalid_return_type.rs create mode 100644 library/autodiff/tests/ui/invalid_return_type.stderr create mode 100644 library/autodiff/tests/ui/no_function_name.rs create mode 100644 library/autodiff/tests/ui/no_function_name.stderr create mode 100644 library/autodiff/tests/ui/not_a_function.rs create mode 100644 library/autodiff/tests/ui/not_a_function.stderr create mode 100644 library/autodiff/tests/ui/reverse_tangent.rs create mode 100644 library/autodiff/tests/ui/reverse_tangent.stderr create mode 100644 library/autodiff/tests/ui/wrong_mode.rs create mode 100644 library/autodiff/tests/ui/wrong_mode.stderr create mode 100644 src/test/ui/terminal-width/flag-human.rs create mode 100644 src/test/ui/terminal-width/flag-json.rs create mode 100644 src/test/ui/terminal-width/flag-json.stderr create mode 160000 src/tools/enzyme create mode 100644 tests/rustdoc-ui/doctest/terminal-width.rs create mode 100644 tests/rustdoc-ui/doctest/terminal-width.stderr create mode 100644 tests/ui/json/autodiff.rs diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml new file mode 100644 index 0000000000000..4064d4709a5ed --- /dev/null +++ b/.github/workflows/enzyme-ci.yml @@ -0,0 +1,38 @@ +name: Rust CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + merge_group: + +jobs: + build: + name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [openstack22] + + timeout-minutes: 600 + steps: + - name: checkout the source code + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: build + run: | + mkdir build + cd build + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 + rustup toolchain install nightly # enables -Z unstable-options + - name: test + run: | + cargo +enzyme test --examples diff --git a/.gitmodules b/.gitmodules index f5025097a18dc..7e217cb215dd8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,6 @@ path = library/backtrace url = https://github.com/rust-lang/backtrace-rs.git shallow = true +[submodule "src/tools/enzyme"] + path = src/tools/enzyme + url = https://github.com/EnzymeAD/Enzyme.git diff --git a/Cargo.lock b/Cargo.lock index 0761268c9d411..06be5561ceec4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4312,6 +4312,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 9b11ae8744b4f..ab42109434320 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ exclude = [ "src/tools/x", # stdarch has its own Cargo workspace "library/stdarch", + "library/autodiff", ] [profile.release.package.compiler_builtins] diff --git a/README.md b/README.md index a88ee4b8bf061..2ac2d3b38d679 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,68 @@ -# The Rust Programming Language +# The Rust Programming Language +Enzyme [![Rust Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community) This is the main source code repository for [Rust]. It contains the compiler, -standard library, and documentation. +standard library, and documentation. It is modified to use Enzyme for AutoDiff. + +Please configure this fork using the following command: + +``` +mkdir build +cd build +../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs +``` + +Afterwards you can build rustc using: +``` +../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc +``` + +Afterwards rustc toolchain link will allow you to use it through cargo: +``` +rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 +rustup toolchain install nightly # enables -Z unstable-options +``` + +You can then look at examples in the `library/autodiff/examples/*` folder and run them with + +```bash +# rosenbrock forward iteration +cargo +enzyme run --example rosenbrock_fwd_iter --release + +# or all of them +cargo +enzyme test --examples +``` + +## Enzyme Config +To help with debugging, Enzyme can be configured using environment variables. +```bash +export ENZYME_PRINT_TA=1 +export ENZYME_PRINT_AA=1 +export ENZYME_PRINT=1 +export ENZYME_PRINT_MOD=1 +export ENZYME_PRINT_MOD_AFTER=1 +``` +The first three will print TypeAnalysis, ActivityAnalysis and the llvm-ir on a function basis, respectively. +The last two variables will print the whole module directly before and after Enzyme differented the functions. + +When experimenting with flags please make sure that EnzymeStrictAliasing=0 +is not changed, since it is required for Enzyme to handle enums correctly. + +## Bug reporting +Bugs are pretty much expected at this point of the development process. +In order to help us please minimize the Rust code as far as possible. +This tool might be a nicer helper: https://github.com/Nilstrieb/cargo-minimize +If you have some knowledge of LLVM-IR we also greatly appreciate it if you could help +us by compiling your minimized Rust code to LLVM-IR and reducing it further. + +The only exception to this strategy is error based on "Can not deduce type of X", +where reducing your example will make it harder for us to understand the origin of the bug. +In this case please just try to inline all dependencies into a single crate or even file, +without deleting used code. + + + [Rust]: https://www.rust-lang.org/ diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 0634ee970ec5e..23e7975edd65b 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -381,7 +381,7 @@ pub fn visit_bounds(bounds: &mut GenericBounds, vis: &mut T) { } // No `noop_` prefix because there isn't a corresponding method in `MutVisitor`. -pub fn visit_fn_sig(FnSig { header, decl, span }: &mut FnSig, vis: &mut T) { +pub fn visit_fn_sig(FnSig { header, decl, span, .. }: &mut FnSig, vis: &mut T) { vis.visit_fn_header(header); vis.visit_fn_decl(decl); vis.visit_span(span); diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index b6c01545f308c..a586559016cd5 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -285,6 +285,7 @@ pub fn from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -302,6 +303,8 @@ pub fn from_fn_attrs<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 8655aeec13dd6..c63870dfe4327 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -273,6 +273,7 @@ fn fat_lto( info!("pushing cached module {:?}", wp.cgu_name); (buffer, CString::new(wp.cgu_name).unwrap()) })); + for module in modules { match module { FatLtoInput::InMemory(m) => in_memory.push(m), @@ -734,7 +735,7 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, module_name, thin_module.data(), &diag_handler)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm }, + module_llvm: ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 9d5204034def0..153b09d867a29 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -10,11 +10,24 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{LLVMReplaceAllUsesWith, LLVMVerifyFunction, Value}; use crate::llvm_util; use crate::type_::Type; +use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; +use crate::{base, DiffTypeTree}; use llvm::{ + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, + CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction, + LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, + LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, + LLVMVoidTypeInContext, LLVMGlobalGetValueType, LLVMGetStringAttributeAtIndex, + LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, LLVMRemoveEnumAttributeAtIndex, AttributeKind, + LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetEnumAttributeAtIndex, LLVMIsEnumAttribute, + LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMCreateEnumAttribute, LLVMDumpModule, LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; use rustc_codegen_ssa::back::link::ensure_removed; @@ -24,10 +37,12 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{FatalError, Handler, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; +use rustc_middle::middle::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; use rustc_session::Session; @@ -37,7 +52,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo}; use crate::llvm::diagnostic::OptimizationDiagnosticKind; use libc::{c_char, c_int, c_uint, c_void, size_t}; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::fs; use std::io::{self, Write}; use std::path::{Path, PathBuf}; @@ -513,8 +528,18 @@ pub(crate) unsafe fn llvm_optimize( opt_level: config::OptLevel, opt_stage: llvm::OptStage, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // We want to simplify / optimize functions before AD. + // However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore activate them first, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // RIP compile time. + // let unroll_loops = + // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops = false; + let vectorize_slp = false; + let vectorize_loop = false; + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -569,8 +594,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -592,6 +617,255 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(diag_handler, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// TODO: cleanup +unsafe fn create_wrapper<'a>( + llmod: &'a llvm::Module, + //module: &'a ModuleCodegen, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { + //let llmod = module.module_llvm.llmod(); + let context = LLVMGetModuleContext(llmod); + let inner_fnc_name = "inner_".to_string() + &fnc_name; + let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); + LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); + + let c_outer_fnc_name = CString::new(fnc_name).unwrap(); + let outer_fnc: &Value = + LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), LLVMGetElementType(u_type) as &Type); + + let entry = "fnc_entry".to_string(); + let c_entry = CString::new(entry).unwrap(); + let basic_block = LLVMAppendBasicBlockInContext(context, outer_fnc, c_entry.as_ptr()); + + let outer_params: Vec<&Value> = get_params(outer_fnc); + let inner_params: Vec<&Value> = get_params(fnc); + + (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) +} + +//pub(crate) fn get_type(t: LLVMTypeRef) -> CString { +// unsafe { CString::from_raw(LLVMPrintTypeToString(t)) } +//} + +// TODO: Don't write a wrapper function, just unwrap the struct inside of the same fnc. +// Might help during debugging, if you have one function less to jump trough +pub(crate) unsafe fn extract_return_type<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> &'a Value { + //let llmod = module.module_llvm.llmod(); + let context = llvm::LLVMGetModuleContext(llmod); + //dbg!("Unpacking", fnc_name.clone()); + //dbg!("From: ", f_type, " into ", u_type); + + let inner_param_num = LLVMCountParams(fnc); + let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = + create_wrapper(llmod, fnc, u_type, fnc_name); + + if inner_param_num as usize != outer_args.len() { + panic!("Args len shouldn't differ. Please report this."); + } + + let builder = LLVMCreateBuilderInContext(context); + LLVMPositionBuilderAtEnd(builder, outer_bb); + let struct_ret = LLVMBuildCall2( + builder, + u_type, + fnc, + outer_args.as_mut_ptr(), + outer_args.len(), + c_inner_fnc_name.as_ptr(), + ); + // We can use an arbitrary name here, since it will be used to store a tmp value. + let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + let _ret = LLVMBuildRet(builder, struct_ret); + let _terminator = LLVMGetBasicBlockTerminator(outer_bb); + //assert!(LLVMIsNull(terminator)!=0, "no terminator"); + LLVMDisposeBuilder(builder); + + let _fnc_ok = + LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + //dbg!(outer_fnc); + //assert!(fnc_ok); + //if let Err(e) = verify_function(outer_fnc) { + // panic!("Creating a wrapper function failed! {}", e); + //} + + outer_fnc +} + +// As unsafe as it can be. +#[allow(unused_variables)] +#[allow(unused)] +pub(crate) unsafe fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + item: AutoDiffItem, +) -> Result<(), FatalError> { + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // get target and source function + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap(); + let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap(); + + // create enzyme typetrees + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + let input_tts = + item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); + let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + + let opt = 1; + let ret_primary_ret = false; + let diff_primary_ret = false; + let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); + let type_analysis: EnzymeTypeAnalysisRef = + CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_TA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + } + if std::env::var("ENZYME_PRINT_AA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); + } + if std::env::var("ENZYME_PRINT_PERF").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + } + if std::env::var("ENZYME_PRINT").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); + } + + let mut res: &Value = match item.attrs.mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + input_tts, + output_tt, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + diff_primary_ret, + input_tts, + output_tt, + ), + _ => unreachable!(), + }; + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); + + let void_type = LLVMVoidTypeInContext(llcx); + if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { + //dbg!("Reverse Mode sanitizer"); + //dbg!(f_type); + //dbg!(f_return_type); + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let u_type = LLVMTypeOf(target_fnc); + res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 + } + } + //dbg!(&target_fnc); + LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); + LLVMReplaceAllUsesWith(target_fnc, res); + LLVMDeleteFunction(target_fnc); + + Ok(()) +} + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + _cgcx: &CodegenContext, + diff_items: Vec, + _typetrees: FxHashMap, + _config: &ModuleConfig, +) -> Result<(), FatalError> { + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_MOD").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + if std::env::var("ENZYME_TT_DEPTH").is_ok() { + let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); + let depth = depth.parse::().unwrap(); + assert!(depth >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + } + if std::env::var("ENZYME_TT_WIDTH").is_ok() { + let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); + let width = width.parse::().unwrap(); + assert!(width >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), width); + } + + for item in diff_items { + let res = enzyme_ad(llmod, llcx, item); + assert!(res.is_ok()); + } + + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + } else { + LLVMRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + + } else { + break; + } + } + if std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -615,6 +889,28 @@ pub(crate) unsafe fn optimize( llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()); } + { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + let attr = LLVMCreateEnumAttribute(llcx, AttributeKind::SanitizeHWAddress, 0); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } + + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index b659fd02eecf6..1d9157e6355f4 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -25,6 +25,7 @@ use rustc_codegen_ssa::base::maybe_create_entry_wrapper; use rustc_codegen_ssa::mono_item::MonoItemExt; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{ModuleCodegen, ModuleKind}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::small_c_str::SmallCStr; use rustc_middle::dep_graph; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; @@ -82,9 +83,10 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen recorder.record_arg(cgu.size_estimate().to_string()); }); // Instantiate monomorphizations without filling out definitions yet... - let llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); - { + let mut llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); + let typetrees = { let cx = CodegenCx::new(tcx, cgu, &llvm_module); + let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx); for &(mono_item, data) in &mono_items { mono_item.predefine::>(&cx, data.linkage, data.visibility); @@ -132,7 +134,11 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen if cx.sess().opts.debuginfo != DebugInfo::None { cx.debuginfo_finalize(); } - } + + FxHashMap::default() + }; + + llvm_module.typetrees = typetrees; ModuleCodegen { name: cgu_name.to_string(), diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index b4b2ab1e1f8a9..2d16649ce17d2 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -624,6 +624,10 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 8a6a5f79b3bb9..011a208eb6389 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -30,6 +30,7 @@ use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; use errors::ParseTargetMachineConfig; +use llvm::TypeTree; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; @@ -44,6 +45,8 @@ use rustc_errors::{DiagnosticMessage, ErrorGuaranteed, FatalError, Handler, Subd use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_middle::ty::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; @@ -77,6 +80,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that fixes issue 53912. #[path = "llvm/mod.rs"] @@ -172,6 +176,8 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; + fn print_pass_timings(&self) { unsafe { let mut size = 0; @@ -254,6 +260,20 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -404,12 +424,20 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, // independent from llcx and llmod_raw, resources get disposed by drop impl tm: OwnedTargetMachine, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -420,7 +448,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_target_machine(tcx, mod_name) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_target_machine(tcx, mod_name), + typetrees: Default::default(), + } } } @@ -428,7 +461,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_informational_target_machine(tcx.sess) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_informational_target_machine(tcx.sess), + typetrees: Default::default(), + } } } @@ -449,7 +487,7 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm }) + Ok(ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index a038b3af03dd6..c5514d5bff823 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,9 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] +use rustc_codegen_ssa::coverageinfo::map as coverage_map; +use rustc_middle::middle::autodiff_attrs::DiffActivity; + use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace, @@ -11,6 +14,8 @@ use super::debuginfo::{ use libc::{c_char, c_int, c_uint, size_t}; use libc::{c_ulonglong, c_void}; +use core::fmt; +use std::ffi::{CStr, CString}; use std::marker::PhantomData; use super::RustString; @@ -187,7 +192,7 @@ pub enum AttributeKind { OptimizeNone = 24, ReturnsTwice = 25, ReadNone = 26, - SanitizeHWAddress = 28, + SanitizeHWAddress = 51, WillReturn = 29, StackProtectReq = 30, StackProtectStrong = 31, @@ -819,10 +824,186 @@ pub type SelfProfileBeforePassCallback = unsafe extern "C" fn(*mut c_void, *const c_char, *const c_char); pub type SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void); +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + mut ret_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { + let act = cdiffe_from(input); + assert!(act == CDIFFE_TYPE::DFT_CONSTANT || act == CDIFFE_TYPE::DFT_DUP_ARG || act == CDIFFE_TYPE::DFT_DUP_NONEED); + input_activity.push(act); + } + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_activity.len()]; + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreateForwardDiff( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ) +} + +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_activity: Vec, + ret_activity: DiffActivity, + mut ret_primary_ret: bool, + diff_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_activity); + assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); + let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_tts.len()]; + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + + let mut known_values = vec![kv_tmp; input_tts.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + diff_primary_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ) +} pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; extern "C" { + + // Enzyme + //pub fn LLVMReplaceAllUsesWith(old: &Value, new: &Value); + pub fn GibtsNicht(M: &Module) -> bool; + pub fn LLVMIsStructTy(ty: &Type) -> bool; + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMDeleteFunction(V: &Value); + pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); + pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; + pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind); + pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind) -> &Attribute; + pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; + pub fn LLVMCreateEnumAttribute(C : &Context, Kind: AttributeKind, val:u64) -> &Attribute; + pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetBasicBlockTerminator(B: &BasicBlock) -> &Value; + pub fn LLVMAddFunction<'a>(M: &Module, Name: *const c_char, Ty: &Type) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMGlobalGetValueType(val: &Value) -> &Type; + + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; pub fn LLVMRustInstallFatalErrorHandler(); pub fn LLVMRustDisableSystemDialogsOnCrash(); @@ -2091,6 +2272,8 @@ extern "C" { #[allow(improper_ctypes)] pub fn LLVMRustWriteTypeToString(Type: &Type, s: &RustString); #[allow(improper_ctypes)] + pub fn LLVMRustWriteValueNameToString(value_ref: &Value, s: &RustString); + #[allow(improper_ctypes)] pub fn LLVMRustWriteValueToString(value_ref: &Value, s: &RustString); pub fn LLVMIsAConstantInt(value_ref: &Value) -> Option<&ConstantInt>; @@ -2362,7 +2545,6 @@ extern "C" { remark_file: *const c_char, pgo_available: bool, ); - #[allow(improper_ctypes)] pub fn LLVMRustGetMangledName(V: &Value, out: &RustString); @@ -2382,3 +2564,301 @@ extern "C" { error_callback: GetSymbolsErrorCallback, ) -> *mut c_void; } +// Manuel +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], +} +pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], +} +pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], +} +pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct IntList { + pub data: *mut i64, + pub size: size_t, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeTypeTree { + _unused: [u8; 0], +} +pub type CTypeTreeRef = *mut EnzymeTypeTree; +extern "C" { + fn EnzymeNewTypeTree() -> CTypeTreeRef; +} +extern "C" { + fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); +} +extern "C" { + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); +} +extern "C" { + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); +} + +extern "C" { + pub static mut MaxIntOffset: c_void; + pub static mut MaxTypeOffset: c_void; + pub static mut EnzymeMaxTypeDepth: c_void; + + pub static mut EnzymePrintPerf: c_void; + pub static mut EnzymePrintActivity: c_void; + pub static mut EnzymePrintType: c_void; + pub static mut EnzymePrint: c_void; + pub static mut EnzymeStrictAliasing: c_void; +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, +} + +fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + }; +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, +} +extern "C" { + fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + //) -> LLVMValueRef; +} +extern "C" { + fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8,// &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +extern "C" { + pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; +} +extern "C" { + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); +} +extern "C" { + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +extern "C" { + fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} + +pub struct TypeTree { + pub inner: CTypeTreeRef, +} + +impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl fmt::Display for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } +} + +impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..091ddaa3cf213 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_middle::middle::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} diff --git a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs index 16bb7b12bd3c1..a4ba7bfe7d20b 100644 --- a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs +++ b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs @@ -46,7 +46,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let mut ams = AssertModuleSource { tcx, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cb6244050df24..f27b09c8146f3 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,9 +1,11 @@ use super::write::CodegenContext; +use crate::back::write::ModuleConfig; use crate::traits::*; use crate::ModuleCodegen; -use rustc_data_structures::memmap::Mmap; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use std::ffi::CString; use std::sync::Arc; @@ -76,6 +78,27 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat { ref module, .. } => { + //let module = module.take().unwrap(); + { + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + }, + _ => {}, + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 9cd4394108a4a..5fd525dd56e03 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -317,7 +317,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); for (mono_item, data) in cgus.iter().flat_map(|cgu| cgu.items().iter()) { if data.linkage != Linkage::External { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 3d6a212433463..0f77938999d9e 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -24,6 +24,7 @@ use rustc_incremental::{ use rustc_metadata::fs::copy_to_stdout; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::exported_symbols::SymbolExportInfo; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, CrateType, Lto, OutFileName, OutputFilenames, OutputType}; @@ -117,6 +118,7 @@ pub struct ModuleConfig { pub inline_threshold: Option, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub enzyme_print_activity: bool, } impl ModuleConfig { @@ -194,6 +196,7 @@ impl ModuleConfig { false ), + enzyme_print_activity: sess.opts.unstable_opts.enzyme_print_activity, sanitizer: if_regular!(sess.opts.unstable_opts.sanitizer, SanitizerSet::empty()), sanitizer_recover: if_regular!( sess.opts.unstable_opts.sanitizer_recover, @@ -385,6 +388,8 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, + typetrees: FxHashMap, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -393,10 +398,14 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut lto_module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + lto_module = unsafe { lto_module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. - vec![(WorkItem::LTO(module), 0)] + vec![(WorkItem::LTO(lto_module), 0)] } else { assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) @@ -985,6 +994,8 @@ pub(crate) enum Message { work_product: WorkProduct, }, + AddAutoDiffItems(Vec), + /// The frontend has finished generating everything for all codegen units. /// Sent from the main thread. CodegenComplete, @@ -1287,6 +1298,8 @@ fn start_executing_work( let mut needs_link = Vec::new(); let mut needs_fat_lto = Vec::new(); let mut needs_thin_lto = Vec::new(); + let mut autodiff_items = Vec::new(); + let mut typetrees = FxHashMap::::default(); let mut lto_import_only_modules = Vec::new(); let mut started_lto = false; @@ -1393,9 +1406,14 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + typetrees.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1508,7 +1526,16 @@ fn start_executing_work( } } - Message::CodegenDone { llvm_work_item, cost } => { + Message::CodegenDone { mut llvm_work_item, cost } => { + //// extract build typetrees + match &mut llvm_work_item { + WorkItem::Optimize(module) => { + let tt = B::typetrees(&mut module.module_llvm); + typetrees.extend(tt); + } + _ => {}, + } + // We keep the queue sorted by estimated processing cost, // so that more expensive items are processed earlier. This // is good for throughput as it gives the main thread more @@ -1549,6 +1576,10 @@ fn start_executing_work( codegen_state = Aborted; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } + Message::WorkItem { result, worker_id } => { free_worker(worker_id); @@ -2000,6 +2031,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 198e5696357af..a8641ba9fbb30 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -590,7 +590,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -659,6 +660,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and @@ -982,7 +987,7 @@ pub fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 2e0840f2d1bc3..58019ae43129f 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,10 +1,11 @@ -use rustc_ast::{ast, attr, MetaItemKind, NestedMetaItem}; +use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId, LOCAL_CRATE}; use rustc_hir::{lang_items, weak_lang_items::WEAK_LANG_ITEMS, LangItem}; +use rustc_middle::middle::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::Linkage; use rustc_middle::query::Providers; @@ -13,6 +14,7 @@ use rustc_session::{lint, parse::feature_err}; use rustc_span::symbol::Ident; use rustc_span::{sym, Span}; use rustc_target::spec::{abi, SanitizerSet}; +use std::str::FromStr; use crate::errors; use crate::target_features::from_target_feature; @@ -697,6 +699,162 @@ fn check_link_name_xor_ordinal( } } +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + let attrs = tcx.get_attrs(id, sym::autodiff_into); + + let attrs = attrs + .into_iter() + .filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .collect::>(); + + // check for exactly one autodiff attribute on extern block + let attr = match &attrs[..] { + &[] => return AutoDiffAttrs::inactive(), + &[elm] => elm, + x => { + tcx.sess + .struct_span_err(x[1].span, "autodiff attribute can only be applied once") + .span_label(x[1].span, "more than one") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.len() == 0 { + return AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + }; + } + + let mode = match &list[0] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, "attribute must contain autodiff mode") + .span_label(attr.span, "empty argument list") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + // parse mode + let mode = match mode.as_str() { + //map(|x| x.as_str()) { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + _ => { + tcx.sess + .struct_span_err(attr.span, "mode should be either forward or reverse") + .span_label(attr.span, "invalid mode") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let ret_symbol = match &list[1] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, "autodiff attribute must contain the return activity") + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.sess + .struct_span_err(attr.span, "unknown return activity") + .span_label(attr.span, "invalid return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let mut arg_activities: Vec = vec![]; + for arg in &list[2..] { + let arg_symbol = match arg { + NestedMetaItem::MetaItem(MetaItem { + path: ref p2, kind: MetaItemKind::Word, .. + }) => p2.segments.first().unwrap().ident, + _ => { + tcx.sess + .struct_span_err( + attr.span, + "autodiff attribute must contain the return activity", + ) + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.sess + .struct_span_err(attr.span, "unknown return activity") + .span_label(attr.span, "invalid input activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + } + } + + if mode == DiffMode::Forward { + if ret_activity == DiffActivity::Active { + tcx.sess + .struct_span_err(attr.span, "Forward Mode is incompatible with Active ret") + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 { + tcx.sess + .struct_span_err(attr.span, "Forward Mode is incompatible with Active args") + .span_label(attr.span, "invalid input activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + if mode == DiffMode::Reverse { + if ret_activity == DiffActivity::Duplicated + || ret_activity == DiffActivity::DuplicatedNoNeed + { + tcx.sess + .struct_span_err( + attr.span, + "Reverse Mode is only compatible with Active, None, or Const ret", + ) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 04e2b8796c46a..5f64dd3367661 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -19,4 +19,5 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists. fn declare_c_main(&self, fn_type: Self::Type) -> Option; + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index ecf5095d8a335..9c1be89580dc4 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -2,8 +2,10 @@ use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{FatalError, Handler}; use rustc_middle::dep_graph::WorkProduct; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; pub trait WriteBackendMethods: 'static + Sized + Clone { type Module: Send + Sync; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -58,6 +61,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result; fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index e808e4815fe0b..2ed334569995b 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -353,6 +353,13 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ ungated!(used, Normal, template!(Word, List: "compiler|linker"), WarnFollowing, @only_local: true), ungated!(link_ordinal, Normal, template!(List: "ordinal"), ErrorPreceding), + // Autodiff + ungated!( + autodiff_into, Normal, + template!(Word, List: r#""...""#), + DuplicatesOk, + ), + // Limits: ungated!(recursion_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), ungated!(type_length_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 57ca709267a7e..4439550d8d037 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -767,6 +767,7 @@ fn test_unstable_options_tracking_hash() { tracked!(debug_macros, true); tracked!(dep_info_omit_d_target, true); tracked!(dual_proc_macros, true); + tracked!(enzyme_print_activity, false); tracked!(dwarf_version, Some(5)); tracked!(emit_thin_lto, false); tracked!(export_executable_symbols, true); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 4390486b0deb1..e7db075aefa2f 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -96,6 +96,11 @@ extern "C" char *LLVMRustGetLastError(void) { return Ret; } +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + extern "C" void LLVMRustSetLastError(const char *Err) { free((void *)LastError); LastError = strdup(Err); diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 1d573a746b918..acb0a25f087eb 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -97,6 +97,7 @@ macro_rules! arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, + [] autodiff_item: rustc_middle::middle::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/middle/autodiff_attrs.rs b/compiler/rustc_middle/src/middle/autodiff_attrs.rs new file mode 100644 index 0000000000000..2412df725fe2b --- /dev/null +++ b/compiler/rustc_middle/src/middle/autodiff_attrs.rs @@ -0,0 +1,94 @@ +use crate::middle::typetree::TypeTree; +use std::str::FromStr; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum DiffActivity { + None, + Active, + Const, + Duplicated, + DuplicatedNoNeed, +} + +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "Const" => Ok(DiffActivity::Const), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct AutoDiffAttrs { + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => true, + } + } + + pub fn is_source(&self) -> bool { + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => true, + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, +} diff --git a/compiler/rustc_middle/src/middle/mod.rs b/compiler/rustc_middle/src/middle/mod.rs index 85c5af9ca13cb..43e60c2571cc0 100644 --- a/compiler/rustc_middle/src/middle/mod.rs +++ b/compiler/rustc_middle/src/middle/mod.rs @@ -1,3 +1,4 @@ +pub mod autodiff_attrs; pub mod codegen_fn_attrs; pub mod debugger_visualizer; pub mod dependency_format; @@ -32,6 +33,7 @@ pub mod privacy; pub mod region; pub mod resolve_bound_vars; pub mod stability; +pub mod typetree; pub fn provide(providers: &mut crate::query::Providers) { limits::provide(providers); diff --git a/compiler/rustc_middle/src/middle/typetree.rs b/compiler/rustc_middle/src/middle/typetree.rs new file mode 100644 index 0000000000000..4049d32540bd2 --- /dev/null +++ b/compiler/rustc_middle/src/middle/typetree.rs @@ -0,0 +1,39 @@ +use std::fmt; +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct TypeTree(pub Vec); + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index e20e9d9312c1b..7f33139d67b8d 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -190,6 +190,10 @@ impl EraseType for (&'_ T0, &'_ [T1]) { type Result = [u8; size_of::<(&'static (), &'static [()])>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} + macro_rules! trivial { ($($ty:ty),+ $(,)?) => { $( diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 062b03e71fdc1..8d85928374b5c 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -10,6 +10,7 @@ use crate::dep_graph; use crate::infer::canonical::{self, Canonical}; use crate::lint::LintExpectation; use crate::metadata::ModChild; +use crate::middle::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use crate::middle::codegen_fn_attrs::CodegenFnAttrs; use crate::middle::debugger_visualizer::DebuggerVisualizerFile; use crate::middle::exported_symbols::{ExportedSymbol, SymbolExportInfo}; @@ -1229,6 +1230,13 @@ rustc_queries! { separate_provide_extern } + /// The list autodiff extern functions in current crate + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1878,7 +1886,7 @@ rustc_queries! { separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 739d4fa886ec3..31d60e97cded7 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -176,6 +176,8 @@ pub struct ResolverGlobalCtxt { /// Mapping from ident span to path span for paths that don't exist as written, but that /// exist under `std`. For example, wrote `str::from_utf8` instead of `std::str::from_utf8`. pub confused_type_with_std_module: FxHashMap, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, pub doc_link_resolutions: FxHashMap, pub doc_link_traits_in_scope: FxHashMap>, pub all_macro_rules: FxHashMap>, diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index fe097424e8ad4..b75941e71989a 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -18,3 +18,4 @@ rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } rustc_target = { path = "../rustc_target" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 82fee7c8dfe58..baac8d98e8b1b 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -1244,6 +1244,7 @@ impl<'v> RootCollector<'_, 'v> { /// monomorphized copy of the start lang item based on /// the return type of `main`. This is not needed when /// the user writes their own `start` manually. + /// TODO: remove annotations after automatic differentation pass fn push_extra_entry_roots(&mut self) { let Some((main_def_id, EntryFnType::Main { .. })) = self.entry_fn else { return; diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 4009e28924068..fa39f35dc334e 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1282,12 +1282,12 @@ pub fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index a8a27e761cb3f..19aa8a308b3a2 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -233,6 +233,7 @@ impl CheckAttrVisitor<'_> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + sym::autodiff_into => self.check_autodiff(hir_id, attr, span, target), _ => {} } @@ -2394,6 +2395,20 @@ impl CheckAttrVisitor<'_> { self.abort.set(true); } } + + /// Checks if `#[autodiff]` is applied to an item other than a foreign module. + fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, _span: Span, _target: Target) { + //match target { + // Target::ForeignMod => {} + // _ => { + // self.tcx + // .sess + // .struct_span_err(attr.span, "attribute should be applied to an `extern` block") + // .span_label(span, "not an `extern` block") + // .emit(); + // } + //} + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index 501747df5c908..58e6d82595e7a 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1522,6 +1522,7 @@ impl<'a, 'tcx> Resolver<'a, 'tcx> { trait_impls: self.trait_impls, proc_macros, confused_type_with_std_module, + autodiff_map: Default::default(), doc_link_resolutions: self.doc_link_resolutions, doc_link_traits_in_scope: self.doc_link_traits_in_scope, all_macro_rules: self.all_macro_rules, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 30c8b9d67002c..e6401d2fbfbab 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -1537,6 +1537,8 @@ options! { "enables LTO for dylib crate type"), emit_stack_sizes: bool = (false, parse_bool, [UNTRACKED], "emit a section containing stack size metadata (default: no)"), + enzyme_print_activity: bool = (false, parse_bool, [TRACKED], + "print type trees for functions passed to enzyme"), emit_thin_lto: bool = (true, parse_bool, [TRACKED], "emit the bc module with thin LTO info (default: yes)"), export_executable_symbols: bool = (false, parse_bool, [TRACKED], diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 3f99d2a4b1ffb..a461745d9162c 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -437,6 +437,7 @@ symbols! { attributes, augmented_assignments, auto_traits, + autodiff_into, automatically_derived, avx, avx512_target_feature, @@ -1022,6 +1023,7 @@ symbols! { miri, misc, mmx_reg, + mode, modifiers, module, module_path, diff --git a/config.example.toml b/config.example.toml index 66fa91d4bad15..6050848cb3a05 100644 --- a/config.example.toml +++ b/config.example.toml @@ -142,6 +142,9 @@ change-id = 116998 # Whether or not to specify `-DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=YES` #allow-old-toolchain = false +# Whether to build enzyme +#enzyme = false + # Whether to include the Polly optimizer. #polly = false diff --git a/library/autodiff/Cargo.lock b/library/autodiff/Cargo.lock new file mode 100644 index 0000000000000..b11b872e7dbd9 --- /dev/null +++ b/library/autodiff/Cargo.lock @@ -0,0 +1,314 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "autodiff" +version = "0.1.0" +dependencies = [ + "macrotest", + "ndarray", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", + "trybuild", +] + +[[package]] +name = "basic-toml" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24c12265665aebaa236af9bbe266681bcc9c5666192119e3d8335cf083aca26f" +dependencies = [ + "serde", +] + +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "macrotest" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7489ae0986ce45414b7b3122c2e316661343ecf396b206e3e15f07c846616f10" +dependencies = [ + "diff", + "glob", + "prettyplease", + "serde", + "serde_json", + "syn 1.0.109", + "toml", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "prettyplease" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" +dependencies = [ + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "serde" +version = "1.0.190" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.190" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "serde_json" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "termcolor" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + +[[package]] +name = "trybuild" +version = "1.0.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196a58260a906cedb9bf6d8034b6379d0c11f552416960452f267402ceeddff1" +dependencies = [ + "basic-toml", + "glob", + "once_cell", + "serde", + "serde_derive", + "serde_json", + "termcolor", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/library/autodiff/Cargo.toml b/library/autodiff/Cargo.toml new file mode 100644 index 0000000000000..cbbff8d375e3d --- /dev/null +++ b/library/autodiff/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "autodiff" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + + +[profile.release] +lto = "fat" + +[profile.dev] +lto = "fat" + +[lib] +name = "autodiff" +proc-macro = true + +[dependencies] +quote = "1.0" +proc-macro2 = "1" +proc-macro-error = "1" +syn = { version = "1", features = ["extra-traits", "full", "visit", "visit-mut"]} + +[dev-dependencies] +macrotest = "1" +trybuild = "1" +ndarray = "0.15" diff --git a/library/autodiff/examples/array.rs b/library/autodiff/examples/array.rs new file mode 100644 index 0000000000000..60c6b63fd84cb --- /dev/null +++ b/library/autodiff/examples/array.rs @@ -0,0 +1,23 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active, Duplicated)] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} + +fn main() { + let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]; + let mut d_arr = [[[0.0; 2]; 2]; 2]; + + d_array(&arr, &mut d_arr, 1.0); + + dbg!(&d_arr); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/box.rs b/library/autodiff/examples/box.rs new file mode 100644 index 0000000000000..5d4f114830bf4 --- /dev/null +++ b/library/autodiff/examples/box.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(cos_box, Reverse, Active, Duplicated)] +fn sin(x: &Box) -> f32 { + f32::sin(**x) +} + +fn main() { + let x = Box::::new(3.14); + let mut df_dx = Box::::new(0.0); + cos_box(&x, &mut df_dx, 1.0); + + dbg!(&df_dx); + + assert!(*df_dx == f32::cos(*x)); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/broken_matvec.rs b/library/autodiff/examples/broken_matvec.rs new file mode 100644 index 0000000000000..0c4b2cfe6e927 --- /dev/null +++ b/library/autodiff/examples/broken_matvec.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +type Matrix = Vec>; +type Vector = Vec; + +#[autodiff(d_matvec, Forward, Const)] +fn matvec(#[dup] mat: &Matrix, vec: &Vector, #[dup] out: &mut Vector) { + for i in 0..mat.len() - 1 { + for j in 0..mat[0].len() - 1 { + out[i] += mat[i][j] * vec[j]; + } + } +} + +fn main() { + let mat = vec![vec![1.0, 1.0], vec![1.0, 1.0]]; + let mut d_mat = vec![vec![0.0, 0.0], vec![0.0, 0.0]]; + let inp = vec![1.0, 1.0]; + let mut out = vec![0.0, 0.0]; + let mut out_tang = vec![0.0, 1.0]; + + //matvec(&mat, &inp, &mut out); + d_matvec(&mat, &mut d_mat, &inp, &mut out, &mut out_tang); + + dbg!(&out); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/hessian_sin.rs b/library/autodiff/examples/hessian_sin.rs new file mode 100644 index 0000000000000..6b1e776476fd2 --- /dev/null +++ b/library/autodiff/examples/hessian_sin.rs @@ -0,0 +1,28 @@ +use autodiff::autodiff; + +fn sin(x: &Vec, y: &mut f32) { + *y = x.into_iter().map(|x| f32::sin(*x)).sum() +} + +#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)] +fn jac(x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +#[autodiff(jac, Forward, Const, Duplicated, Const, Const, Const)] +fn hessian(x: &Vec, y_x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +fn main() { + let inp = vec![3.1415 / 2., 1.0, 0.5]; + let mut d_inp = vec![0.0, 0.0, 0.0]; + let mut y = 0.0; + let tang = vec![1.0, 0.0, 0.0]; + hessian(&inp, &tang, &mut d_inp, &mut y, &1.0); + dbg!(&d_inp); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/ndarray.rs b/library/autodiff/examples/ndarray.rs new file mode 100644 index 0000000000000..34402c43cb3e6 --- /dev/null +++ b/library/autodiff/examples/ndarray.rs @@ -0,0 +1,25 @@ +use autodiff::autodiff; + +use ndarray::Array1; + +#[autodiff(d_collect, Reverse, Active)] +fn collect(#[dup] x: &Array1) -> f32 { + x[0] +} + +fn main() { + let a = Array1::zeros(19); + let mut d_a = Array1::zeros(19); + + d_collect(&a, &mut d_a, 1.0); + + dbg!(&d_a); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd.rs b/library/autodiff/examples/rosenbrock_fwd.rs new file mode 100644 index 0000000000000..a3ab7a47578d0 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd_iter.rs b/library/autodiff/examples/rosenbrock_fwd_iter.rs new file mode 100644 index 0000000000000..1648014392f19 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd_iter.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + (0..x.len() - 1) + .map(|i| { + let (a, b) = (x[i + 1] - x[i] * x[i], x[i] - 1.0); + 100.0 * a * a + b * b + }) + .sum() +} + +fn main() { + let x = [3.14f64, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_rev.rs b/library/autodiff/examples/rosenbrock_rev.rs new file mode 100644 index 0000000000000..b4ce00b5afe9d --- /dev/null +++ b/library/autodiff/examples/rosenbrock_rev.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Reverse, Active)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let mut df_dx = [0.0f64; 2]; + d_rosenbrock(&x, &mut df_dx, 1.0); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx[0] - 9373.54).abs() < 0.01); + assert!((df_dx[1] - (-1491.92)).abs() < 0.01); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sin.rs b/library/autodiff/examples/sin.rs new file mode 100644 index 0000000000000..1655b1e7ecd09 --- /dev/null +++ b/library/autodiff/examples/sin.rs @@ -0,0 +1,36 @@ +use autodiff::autodiff; + +#[autodiff(cos_inplace, Reverse, Const)] +fn sin_inplace(#[dup] x: &f32, #[dup] y: &mut f32) { + *y = x.sin(); +} + + +fn main() { + // Here we can use ==, even though we work on f32. + // Enzyme will recognize the sin function and replace it with llvm's cos function (see below). + // Calling f32::cos directly will also result in calling llvm's cos function. + let a = 3.1415; + let mut da = 0.0; + let mut y = 0.0; + cos_inplace(&a, &mut da, &mut y, &mut 1.0); + + dbg!(&a, &da, &y); + assert!(da - f32::cos(a) == 0.0); +} + +// Just for curious readers, this is the (inner) function that Enzyme does generate: +// define internal { float } @diffe_ZN3sin3sin17h18f17f71fe94e58fE(float %0, float %1) unnamed_addr #35 { +// %3 = call fast float @llvm.cos.f32(float %0) +// %4 = fmul fast float %1, %3 +// %5 = insertvalue { float } undef, float %4, 0 +// ret { float } %5 +// } + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sqrt.rs b/library/autodiff/examples/sqrt.rs new file mode 100644 index 0000000000000..d15c6f5ec2051 --- /dev/null +++ b/library/autodiff/examples/sqrt.rs @@ -0,0 +1,21 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} + +fn main() { + let mut d_b = 0.0; + + let (d_a, d_d) = d_sqrt(1.0, &1.0, &mut d_b, &1.0, 1.0, 1.0); + dbg!(d_a, d_b, d_d); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/struct.rs b/library/autodiff/examples/struct.rs new file mode 100644 index 0000000000000..1235307fdbcbf --- /dev/null +++ b/library/autodiff/examples/struct.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +use std::io; + +// Will be represented as {f32, i16, i16} when passed by reference +// will be represented as i64 if passed by value +struct Foo { + c1: i16, + a: f32, + c2: i16, +} + +#[autodiff(cos, Reverse, Active, Duplicated)] +fn sin(x: &Foo) -> f32 { + assert!(x.c1 < x.c2); + f32::sin(x.a) +} + +fn main() { + let mut s = String::new(); + println!("Please enter a value for c1"); + io::stdin().read_line(&mut s).unwrap(); + let c2 = s.trim_end().parse::().unwrap(); + dbg!(c2); + + let foo = Foo { c1: 4, a: 3.14, c2 }; + let mut df_dfoo = Foo { c1: 4, a: 0.0, c2 }; + + dbg!(df_dfoo.a); + dbg!(cos(&foo, &mut df_dfoo, 1.0)); + dbg!(df_dfoo.a); + dbg!(f32::cos(foo.a)); +} diff --git a/library/autodiff/examples/vec.rs b/library/autodiff/examples/vec.rs new file mode 100644 index 0000000000000..e82618fac4dac --- /dev/null +++ b/library/autodiff/examples/vec.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sum, Forward, Duplicated)] +fn sum(#[dup] x: &Vec>) -> f32 { + x.into_iter().map(|x| x.into_iter().map(|x| x.sqrt())).flatten().sum() +} + +fn main() { + let a = vec![vec![1.0, 2.0, 4.0, 8.0]]; + //let mut b = vec![vec![0.0, 0.0, 0.0, 0.0]]; + let b = vec![vec![1.0, 0.0, 0.0, 0.0]]; + + dbg!(&d_sum(&a, &b)); + + dbg!(&b); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples_broken/biquad.rs b/library/autodiff/examples_broken/biquad.rs new file mode 100644 index 0000000000000..7689b1cd1fc51 --- /dev/null +++ b/library/autodiff/examples_broken/biquad.rs @@ -0,0 +1,54 @@ +use autodiff::autodiff; + +#[derive(Debug)] +struct Biquad { + coeffs: [[f32; 5]; N], +} + +impl Biquad { + pub fn new() -> Self { + Biquad { coeffs: [[0.0; 5]; N] } + } + + pub fn process(&self, samples: &[f32], target: &[f32]) -> f32 { + // do some horrible inefficient biquad filtering + let mut samples = samples.to_vec(); + let mut samples_out = vec![0.0; samples.len()]; + + for coeff_set in self.coeffs { + for idx in 0..samples.len() { + samples_out[idx] = coeff_set[0] * samples[idx]; + + if idx > 0 { + samples_out[idx] += coeff_set[1] * samples[idx - 1] - + coeff_set[3] * samples_out[idx - 1]; + } + if idx > 1 { + samples_out[idx] += coeff_set[2] * samples[idx - 2] - + coeff_set[4] * samples_out[idx - 2]; + } + } + + (samples, samples_out) = (samples_out, samples); + } + + samples_out.into_iter().zip(target.into_iter()).map(|(a, b)| a - b).sum() + } + + #[autodiff(Self::process, Reverse, Active)] + pub fn deriv(#[dup] &self, params: &mut Self, samples: &[f32], target: &[f32], ret_adj: f32); +} + +fn main() { + let biquad = Biquad::<10>::new(); + let mut dbiquad = Biquad::<10>::new(); + + // create ramp and pulse train + let signal = (0..1024).map(|x| (x as f32) / 1024.0).collect::>(); + let target = (0..1024).map(|x| if x % 2 == 0 { 0.0 } else { 1.0 }).collect::>(); + + dbg!(&biquad.process(&signal, &target)); + biquad.deriv(&mut dbiquad, &signal, &target, 1.0); + + dbg!(&dbiquad); +} diff --git a/library/autodiff/examples_broken/broken_iter.rs b/library/autodiff/examples_broken/broken_iter.rs new file mode 100644 index 0000000000000..16d205f7373c8 --- /dev/null +++ b/library/autodiff/examples_broken/broken_iter.rs @@ -0,0 +1,20 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; +use std::ptr; + +#[autodiff(sin_vec, Reverse, Active)] +fn cos_vec(#[dup] x: &Vec) -> f32 { + // uses enum internally and breaks + let res = x.into_iter().collect::>(); + + *res[0] +} + +fn main() { + let x = vec![1.0, 1.0, 1.0]; + let mut d_x = vec![0.0; 3]; + + sin_vec(&x, &mut d_x, 1.0); + + dbg!(&d_x, &x); +} diff --git a/library/autodiff/examples_broken/broken_recursive.rs b/library/autodiff/examples_broken/broken_recursive.rs new file mode 100644 index 0000000000000..a1f3ff25eb511 --- /dev/null +++ b/library/autodiff/examples_broken/broken_recursive.rs @@ -0,0 +1,66 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +// TODO: As seen by the bloated code generated for the iterative version, +// we definetly have to disable unroll, slpvec, loop-vec before AD. +// We also should check if we have other opts that Julia, C++, Fortran etc. don't have +// and which could make our input code more "complex". +// We then however have to start doing whole-module opt after AD to re-include them, +// instead of just using enzyme to optimize the generated function. + +#[autodiff(d_power_recursive, Forward, DuplicatedNoNeed)] +fn power_recursive(#[dup] a: f64, n: i32) -> f64 { + if n == 0 { + return 1.0; + } + return a * power_recursive(a, n - 1); +} + +#[autodiff(d_power_iterative, Reverse, DuplicatedNoNeed)] +fn power_iterative(#[active] a: f64, n: i32) -> f64 { + let mut res = 1.0; + for _ in 0..n { + res *= a; + } + res +} + +fn main() { + // d/dx x^n = n * x^(n-1) + let n = 4; + let nf = n as f64; + let a = 1.337; + assert!(power_recursive(a, n) == power_iterative(a, n)); + let dpr = d_power_recursive(a, 1.0, n); + let dpi = d_power_iterative(a, n, 1.0); + let control = nf * a.powi(n - 1); + dbg!(dpr); + dbg!(dpi); + dbg!(control); + assert!(dpr == control); + assert!(dpi == control); +} + +// Again, for the curious. We can find n * x^(n-1) nicely in the LLVM-IR +// +// define internal double @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE(double %0, double %1, i32 %2) unnamed_addr #8 { +// => if (n == 0) goto 5: and return 0. Correct, since for n==0 we have 0 * x ^ (0-1) = 0 +// => if (n != 0) goto 7: +// %4 = icmp eq i32 %2, 0 +// br i1 %4, label %5, label %7 +// +// 5: ; preds = %7, %3 +// %6 = phi fast double [ %14, %7 ], [ 0.000000e+00, %3 ] +// ret double %6 +// +// 7: ; preds = %3 +// => reduce n by 1, +// %8 = add i32 %2, -1 +// %9 = call { double, double } @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE.1229(double %0, double %1, i32 %8) +// %10 = extractvalue { double, double } %9, 0 +// %11 = extractvalue { double, double } %9, 1 +// %12 = fmul fast double %11, %0 +// %13 = fmul fast double %1, %10 +// %14 = fadd fast double %12, %13 +// br label %5 +// } diff --git a/library/autodiff/examples_broken/broken_second_order.rs b/library/autodiff/examples_broken/broken_second_order.rs new file mode 100644 index 0000000000000..8b427d7dae36a --- /dev/null +++ b/library/autodiff/examples_broken/broken_second_order.rs @@ -0,0 +1,17 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +fn sin(x: &f32) -> f32 { + f32::sin(*x) +} + +#[autodiff(sin, Reverse, Active, Active)] +fn cos(x: &f32, adj: f32) -> f32; + +//#[autodiff(cos, Reverse, Active, Active, Const)] +//fn neg_sin(x: &f32, adj: f32, adj_sec: f32) -> f32; + +fn main() { + dbg!(&cos(&1.0, 1.0)); + //dbg!(&neg_sin(&1.0, 1.0, 1.0)); +} diff --git a/library/autodiff/src/gen.rs b/library/autodiff/src/gen.rs new file mode 100644 index 0000000000000..68aae56ea3311 --- /dev/null +++ b/library/autodiff/src/gen.rs @@ -0,0 +1,217 @@ +use crate::parser::{is_ref_mut, PrimalSig}; +use crate::parser::{Activity, DiffItem, Mode}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{parse_quote, FnArg, Ident, Pat, ReturnType, Type}; + +pub(crate) fn generate_header(item: &DiffItem) -> TokenStream { + let mode = match item.header.mode { + Mode::Forward => format_ident!("Forward"), + Mode::Reverse => format_ident!("Reverse"), + }; + let ret_act = item.header.ret_act.to_ident(); + let param_act = item.params.iter().map(|x| x.to_ident()); + + quote!(#[autodiff_into(#mode, #ret_act, #( #param_act, )*)]) +} + +pub(crate) fn primal_fnc(item: &mut DiffItem) -> TokenStream { + // construct body of primal if not given + let body = item.block.clone().map(|x| quote!(#x)).unwrap_or_else(|| { + let header_fnc = &item.header.name; + //let primal_wrapper = format_ident!("primal_{}", item.primal.ident); + //item.primal.ident = primal_wrapper.clone(); + let inputs = item.primal.inputs.iter().map(|x| only_ident(x)).collect::>(); + + quote!({ + #header_fnc(#(#inputs,)*) + }) + }); + + let sig = &item.primal; + let PrimalSig { ident, inputs, output } = sig; + + let ident = + if item.block.is_some() { ident.clone() } else { format_ident!("primal_{}", ident) }; + + let sig = quote!(fn #ident(#(#inputs,)*) #output); + + quote!( + #[autodiff_into] + #sig + #body + ) +} + +fn only_ident(arg: &FnArg) -> Ident { + match arg { + FnArg::Receiver(_) => format_ident!("self"), + FnArg::Typed(t) => match &*t.pat { + Pat::Ident(ident) => ident.ident.clone(), + _ => panic!(""), + }, + } +} + +fn only_type(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +fn as_ref_mut(arg: &FnArg, name: &str, mutable: bool) -> FnArg { + match arg { + FnArg::Receiver(_) => { + let name = format_ident!("{}_self", name); + if mutable { parse_quote!(#name: &mut Self) } else { parse_quote!(#name: &Self) } + } + FnArg::Typed(t) => { + let inner = match &*t.ty { + Type::Reference(t) => &t.elem, + _ => panic!(""), // should not be reachable, as we checked mutability before + }; + + let pat_name = match &*t.pat { + Pat::Ident(x) => &x.ident, + _ => panic!(""), + }; + + let name = format_ident!("{}_{}", name, pat_name); + if mutable { parse_quote!(#name: &mut #inner) } else { parse_quote!(#name: &#inner) } + } + } +} + +pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { + let mut res_inputs: Vec = Vec::new(); + let mut add_inputs: Vec = Vec::new(); + let out_type = match &item.primal.output { + ReturnType::Type(_, x) => Some(*x.clone()), + _ => None, + }; + + let mut outputs = if item.header.ret_act == Activity::Duplicated { + vec![out_type.clone().unwrap()] + } else { + vec![] + }; + + let PrimalSig { ident, inputs, .. } = &item.primal; + + for (input, activity) in inputs.iter().zip(item.params.iter()) { + res_inputs.push(input.clone()); + + match (item.header.mode, activity, is_ref_mut(&input)) { + (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "dual", false)); + add_inputs.push(as_ref_mut(&input, "dual", false)); + out_type.clone().map(|x| outputs.push(x)); + } + (Mode::Forward, Activity::Duplicated, None) => outputs.push(only_type(&input)), + (Mode::Reverse, Activity::Duplicated, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", false)); + add_inputs.push(as_ref_mut(&input, "grad", false)); + } + (Mode::Reverse, Activity::Active, None) => outputs.push(only_type(&input)), + _ => {} + } + } + + match (item.header.mode, item.header.ret_act) { + (Mode::Reverse, Activity::Active) => { + let t: FnArg = match &item.primal.output { + ReturnType::Type(_, ty) => parse_quote!(tang_y: #ty), + _ => panic!(""), + }; + res_inputs.push(t.clone()); + add_inputs.push(t); + } + _ => {} + } + + // for adjoint function -> take header if primal + // -> take ident of primal function + let adjoint_ident = if item.block.is_some() { + if let Some(ident) = item.header.name.get_ident() { + ident.clone() + } else { + abort!( + item.header.name, + "not a function name"; + help = "`#[autodiff]` function name should be a single word instead of path" + ); + } + } else { + item.primal.ident.clone() + }; + + let output = match outputs.len() { + 0 => quote!(), + 1 => { + let output = outputs.first().unwrap(); + + quote!(-> #output) + } + _ => quote!(-> (#(#outputs,)*)), + }; + + let sig = quote!(fn #adjoint_ident(#(#res_inputs,)*) #output); + let inputs = inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + let add_inputs = add_inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + + let call_ident = match item.block.is_some() { + false => { + let ident = format_ident!("primal_{}", ident); + if item.header.name.segments.first().unwrap().ident == "Self" { + quote!(Self::#ident) + } else { + quote!(#ident) + } + } + true => quote!(#ident), + }; + + let body = quote!({ + std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); + + std::hint::black_box(unsafe { std::mem::zeroed() }) + }); + let header = generate_header(&item); + + quote!( + #header + #sig + #body + ) +} diff --git a/library/autodiff/src/lib.rs b/library/autodiff/src/lib.rs new file mode 100644 index 0000000000000..b1d265fa9c59b --- /dev/null +++ b/library/autodiff/src/lib.rs @@ -0,0 +1,31 @@ +use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; +use quote::quote; + +mod gen; +mod parser; + +#[proc_macro_attribute] +#[proc_macro_error] +pub fn autodiff(args: TokenStream, input: TokenStream) -> TokenStream { + let mut params = parser::parse(args.into(), input.clone().into()); + let (primal, adjoint) = (gen::primal_fnc(&mut params), gen::adjoint_fnc(¶ms)); + + let res = quote!( + #primal + #adjoint + ); + + res.into() +} + +#[test] +pub fn expanding() { + macrotest::expand("tests/expand/*.rs"); +} + +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); +} diff --git a/library/autodiff/src/parser.rs b/library/autodiff/src/parser.rs new file mode 100644 index 0000000000000..d11eea24d5015 --- /dev/null +++ b/library/autodiff/src/parser.rs @@ -0,0 +1,464 @@ +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{ + parse::Parser, parse_quote, punctuated::Punctuated, Attribute, Block, FnArg, ForeignItemFn, + Ident, Item, Path, ReturnType, Signature, Token, Type, +}; + +#[derive(Debug)] +pub struct PrimalSig { + pub(crate) ident: Ident, + pub(crate) inputs: Vec, + pub(crate) output: ReturnType, +} + +#[derive(Debug)] +pub struct DiffItem { + pub(crate) header: Header, + pub(crate) params: Vec, + pub(crate) primal: PrimalSig, + pub(crate) block: Option>, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Mode { + Forward, + Reverse, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Activity { + Const, + Active, + Duplicated, + DuplicatedNoNeed, +} + +impl Activity { + fn from_header(name: Option<&Ident>) -> Activity { + if name.is_none() { + return Activity::Const; + } + + match name.unwrap().to_string().as_str() { + "Const" => Activity::Const, + "Active" => Activity::Active, + "Duplicated" => Activity::Duplicated, + "DuplicatedNoNeed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (Const|Active|Duplicated|DuplicatedNoNeed)" + ); + } + } + } + + fn from_inline(name: Attribute) -> Activity { + let name = name.path.segments.first().unwrap(); + match name.ident.to_string().as_str() { + "const" => Activity::Const, + "active" => Activity::Active, + "dup" => Activity::Duplicated, + "dup_noneed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (const|active|dup|dup_noneed)" + ); + } + } + } + + pub(crate) fn to_ident(&self) -> Ident { + format_ident!( + "{}", + match self { + Activity::Const => "Const", + Activity::Active => "Active", + Activity::Duplicated => "Duplicated", + Activity::DuplicatedNoNeed => "DuplicatedNoNeed", + } + ) + } +} + +#[derive(Debug)] +pub(crate) struct Header { + pub name: Path, + pub mode: Mode, + pub ret_act: Activity, +} + +impl Header { + fn from_params(name: &Path, mode: Option<&Ident>, ret_activity: Option<&Ident>) -> Self { + // parse mode and return activity + let mode = mode + .map(|x| match x.to_string().as_str() { + "forward" | "Forward" => Mode::Forward, + "reverse" | "Reverse" => Mode::Reverse, + _ => { + abort!( + mode, + "should be forward or reverse"; + help = "`#[autodiff]` modes should be either forward or reverse" + ); + } + }) + .unwrap_or(Mode::Forward); + let ret_act = Activity::from_header(ret_activity); + + // check for invalid mode and return activity combinations + match (mode, ret_act) { + (Mode::Forward, Activity::Active) => abort!( + ret_activity, + "active return for forward mode"; + help = "`#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode" + ), + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + ret_activity, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + + _ => {} + } + + Header { name: name.clone(), mode, ret_act } + } + + fn parse(args: TokenStream) -> (Header, Vec) { + let args_parsed: Vec<_> = + match Punctuated::::parse_terminated.parse(args.clone().into()) { + Ok(x) => x.into_iter().collect(), + Err(_) => abort!( + args, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + }; + + match &args_parsed[..] { + [name] => (Self::from_params(&name, None, None), vec![]), + [name, mode] => { + (Self::from_params(&name, Some(&mode.get_ident().unwrap()), None), vec![]) + } + [name, mode, ret_act, rem @ ..] => { + let params = Self::from_params( + &name, + Some(&mode.get_ident().unwrap()), + Some(&ret_act.get_ident().unwrap()), + ); + let rem = rem.into_iter() + .map(|x| x.get_ident().unwrap()) + .map(|x| Activity::from_header(Some(x))) + .map(|x| match (params.mode, x) { + (Mode::Forward, Activity::Active) => { + abort!( + args, + "active argument in forward mode"; + help = "`#[autodiff]` forward mode should be either Const, Duplicated" + ); + }, + (_, x) => x, + }) + .collect(); + + (params, rem) + } + _ => { + abort!( + args, + "please specify the autodiff function"; + help = "`#[autodiff]` needs a function name for primal or adjoint" + ); + } + } + } +} + +pub(crate) fn is_ref_mut(t: &FnArg) -> Option { + match t { + FnArg::Receiver(pat) => Some(pat.mutability.is_some()), + FnArg::Typed(pat) => match &*pat.ty { + Type::Reference(t) => Some(t.mutability.is_some()), + _ => None, + }, + } +} + +fn is_scalar(t: &Type) -> bool { + let t_f32: Type = parse_quote!(f32); + let t_f64: Type = parse_quote!(f64); + t == &t_f32 || t == &t_f64 +} + +fn ret_arg(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +pub(crate) fn reduce_params( + mut sig: Signature, + header_acts: Vec, + is_adjoint: bool, + header: &Header, +) -> (PrimalSig, Vec) { + let mut args = Vec::new(); + let mut ret = Vec::new(); + let mut acts = Vec::new(); + let mut last_arg: Option = None; + + let mut arg_it = sig.inputs.iter_mut(); + let mut header_acts_it = header_acts.iter(); + + while let Some(arg) = arg_it.next() { + // Compare current with last argument when parsing duplicated rules. This only + // happens when we parse the signature of adjoint/augmented primal function + if let Some(prev_arg) = last_arg.take() { + match (header.mode, is_ref_mut(&prev_arg), is_ref_mut(&arg)) { + (Mode::Forward, Some(false), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate tangent into second parameter for forward mode" + ), + (Mode::Forward, Some(true), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode" + ), + (Mode::Reverse, Some(false), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + (Mode::Reverse, Some(true), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + _ => {} + } + + continue; + } + + // parse current attribute macro + let attrs: Vec<_> = match arg { + FnArg::Typed(pat) => pat.attrs.drain(..).collect(), + FnArg::Receiver(pat) => pat.attrs.drain(..).collect(), + }; + let attr = attrs.first(); + let act: Activity = match (header_acts.is_empty(), attr) { + (false, None) => header_acts_it.next().map(|x| *x).unwrap_or(Activity::Const), + (true, Some(x)) => Activity::from_inline(x.clone()), + (true, None) => Activity::Const, + _ => { + abort!( + arg, + "inline activity"; + help = "`#[autodiff]` should have activities either specified in header or as inline attributes" + ); + } + }; + + // compare indirection with activity + match (header.mode, is_ref_mut(&arg), act) { + (Mode::Forward, None, Activity::Duplicated) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated types should be behind a reference" + ), + (Mode::Forward, Some(false), Activity::DuplicatedNoNeed) => abort!( + arg, + "should be mutable reference"; + help = "`#[autodiff]` parameter should be output for DuplicatedNoNeed activity" + ), + (Mode::Reverse, Some(_), Activity::Active) => abort!( + arg, + "type behind reference"; + help = "`#[autodiff]` active parameter should be concrete in reverse mode" + ), + (Mode::Reverse, None, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated parameters should be behind reference in reverse mode" + ), + (Mode::Reverse, Some(false), Activity::DuplicatedNoNeed) => abort!( + arg, + "use duplicated instead"; + help = "`#[autodiff]` input parameter cannot be declared as duplicatednoneed" + ), + (Mode::Forward, Some(false), Activity::Duplicated) + if header.ret_act != Activity::Const => + { + ret.push(ret_arg(&arg)) + } + (Mode::Reverse, None, Activity::Active) => ret.push(ret_arg(&arg)), + (Mode::Forward, Some(_), Activity::Duplicated | Activity::DuplicatedNoNeed) + | (Mode::Reverse, _, Activity::Duplicated | Activity::DuplicatedNoNeed) + if is_adjoint => + { + last_arg = Some(arg.clone()) + } + _ => {} + } + + args.push(arg.clone()); + acts.push(act); + } + + // if we have adjoint signature and are in forward mode + // if duplicated -> return type * (n + 1) times + // if duplicated_no_need -> return type * n times + // if const -> no return + + // if we have adjoint signature and are in reverse mode + // if active -> input type * n times + // construct return type based on mode + let ret = if is_adjoint { + let ret_typs = match &sig.output { + ReturnType::Type(_, ref x) => match &**x { + Type::Tuple(x) => x.elems.iter().cloned().collect(), + x => vec![x.clone()], + }, + ReturnType::Default => vec![], + }; + + match (header.mode, header.ret_act) { + (Mode::Forward, Activity::Duplicated) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len() + 1]; + + if list != ret_typs { + let ret = quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Forward, Activity::DuplicatedNoNeed) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len()]; + + if list != ret_typs { + let ret = quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Reverse, Activity::Active) => { + // tangent of output is latest in parameter list + let ret_typ = match (args.pop(), acts.pop()) { + (Some(x), Some(y)) => { + let x = ret_arg(&x); + if !is_scalar(&x) { + abort!( + x, + "output tangent not a floating point"; + help = "`#[autodiff]` the output tangent should be a floating point" + ); + } else if y != Activity::Const { + abort!( + x, + "output tangent not const"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent" + ); + } else { + parse_quote!(-> #x) + } + } + (None, None) => abort!( + sig, + "missing output tangent parameter"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should exist" + ), + _ => unreachable!(), + }; + + // check that the return tuple confirms with return types + if ret_typs != ret { + let ret = quote!((#(#ret,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ) + } + + ret_typ + } + (_, Activity::Const) if ret.len() > 0 => { + abort!( + ret[0], + "constant return but more than one return"; + help = "`#[autodiff]` adjoint should have a return type when active" + ) + } + _ => ReturnType::Default, + } + } else { + if header.ret_act != Activity::Const && sig.output == ReturnType::Default { + abort!( + sig, + "no return type"; + help = "`#[autodiff]` non-const return activity but no return type" + ) + } + + sig.output.clone() + }; + + let sig = if is_adjoint { + // header is used for calling if we are adjoint + format_ident!("{}", sig.ident) + } else { + sig.ident.clone() + }; + + (PrimalSig { ident: sig, inputs: args, output: ret }, acts) +} + +pub(crate) fn parse(args: TokenStream, input: TokenStream) -> DiffItem { + // first parse function + let (_attrs, _, sig, block) = match syn::parse2::(input) { + Ok(Item::Fn(item)) => (item.attrs, item.vis, item.sig, Some(item.block)), + Ok(Item::Verbatim(x)) => match syn::parse2::(x) { + Ok(item) => (item.attrs, item.vis, item.sig, None), + Err(err) => panic!("Could not parse item {}", err), + }, + Ok(item) => { + abort!( + item, + "item is not a function"; + help = "`#[autodiff]` can only be used on primal or adjoint functions" + ) + } + Err(err) => panic!("Could not parse item: {}", err), + }; + + // then parse attributes + let (header, param_attrs) = Header::parse(args); + + // reduce parameters to primal parameter set + let (primal, params) = reduce_params(sig, param_attrs, !block.is_some(), &header); + + DiffItem { header, primal, params, block } +} diff --git a/library/autodiff/tests/expand/forward_duplicated.expanded.rs b/library/autodiff/tests/expand/forward_duplicated.expanded.rs new file mode 100644 index 0000000000000..bf3890154ab8e --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Forward, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, dual_a: &Vec, b: &mut f32, grad_b: &mut f32) { + std::hint::black_box((square(a, b), dual_a, grad_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated.rs b/library/autodiff/tests/expand/forward_duplicated.rs new file mode 100644 index 0000000000000..9a0bfc6c13a47 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square, Forward, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs new file mode 100644 index 0000000000000..a3754de7ab70b --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs @@ -0,0 +1,15 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square2(a: &Vec, b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} +#[autodiff_into(Forward, Duplicated, Duplicated, Duplicated)] +fn d_square2( + a: &Vec, + dual_a: &Vec, + b: &Vec, + dual_b: &Vec, +) -> (f32, f32, f32) { + std::hint::black_box((square2(a, b), dual_a, dual_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.rs b/library/autodiff/tests/expand/forward_duplicated_return.rs new file mode 100644 index 0000000000000..3397e5309ea96 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square2, Forward, Duplicated)] +fn square2(#[dup] a: &Vec, #[dup] b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} diff --git a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs new file mode 100644 index 0000000000000..60c0d7f2f696b --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Reverse, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, grad_a: &mut Vec, b: &mut f32, grad_b: &f32) { + std::hint::black_box((square(a, b), grad_a, grad_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_duplicated.rs b/library/autodiff/tests/expand/reverse_duplicated.rs new file mode 100644 index 0000000000000..107a708bec848 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square, Reverse, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/reverse_return_array.expanded.rs b/library/autodiff/tests/expand/reverse_return_array.expanded.rs new file mode 100644 index 0000000000000..5b784157fea7b --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} +#[autodiff_into(Reverse, Active, Duplicated)] +fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) { + std::hint::black_box((array(arr), grad_arr, tang_y)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_array.rs b/library/autodiff/tests/expand/reverse_return_array.rs new file mode 100644 index 0000000000000..da080a6b3a860 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active)] +fn array(#[dup] arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs new file mode 100644 index 0000000000000..f49864fb7e9b9 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs @@ -0,0 +1,17 @@ +use autodiff::autodiff; +#[autodiff_into] +fn sqrt(a: f32, b: &f32, c: &f32, d: f32) -> f32 { + a * (b * b + c * c * d * d).sqrt() +} +#[autodiff_into(Reverse, Active, Active, Duplicated, Const, Active)] +fn d_sqrt( + a: f32, + b: &f32, + grad_b: &mut f32, + c: &f32, + d: f32, + tang_y: f32, +) -> (f32, f32) { + std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.rs b/library/autodiff/tests/expand/reverse_return_mixed.rs new file mode 100644 index 0000000000000..3260c3560d523 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.rs b/library/autodiff/tests/ui/active_in_forward_mode.rs new file mode 100644 index 0000000000000..10366b1b422b8 --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.stderr b/library/autodiff/tests/ui/active_in_forward_mode.stderr new file mode 100644 index 0000000000000..cd413564068ae --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.stderr @@ -0,0 +1,7 @@ +error: active argument in forward mode + --> tests/ui/active_in_forward_mode.rs:3:12 + | +3 | #[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` forward mode should be either Const, Duplicated diff --git a/library/autodiff/tests/ui/activities_inline_and_header.rs b/library/autodiff/tests/ui/activities_inline_and_header.rs new file mode 100644 index 0000000000000..1ecf37ec60a8f --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active, Active)] +fn sin(#[active] x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/activities_inline_and_header.stderr b/library/autodiff/tests/ui/activities_inline_and_header.stderr new file mode 100644 index 0000000000000..b4d50d02a26a4 --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.stderr @@ -0,0 +1,7 @@ +error: inline activity + --> tests/ui/activities_inline_and_header.rs:4:18 + | +4 | fn sin(#[active] x: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` should have activities either specified in header or as inline attributes diff --git a/library/autodiff/tests/ui/invalid_indirection.rs b/library/autodiff/tests/ui/invalid_indirection.rs new file mode 100644 index 0000000000000..627a7cb0fc6f9 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.rs @@ -0,0 +1,19 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Const)] +fn duplicated_without_reference(#[dup] x: f32) { +} + +#[autodiff(d_sin, Reverse, Const)] +fn active_with_reference(#[active] x: &f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_forward(#[dup] x: f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { +} + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_indirection.stderr b/library/autodiff/tests/ui/invalid_indirection.stderr new file mode 100644 index 0000000000000..cb27c542018e5 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.stderr @@ -0,0 +1,31 @@ +error: type not behind reference + --> tests/ui/invalid_indirection.rs:4:40 + | +4 | fn duplicated_without_reference(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated parameters should be behind reference in reverse mode + +error: type behind reference + --> tests/ui/invalid_indirection.rs:8:36 + | +8 | fn active_with_reference(#[active] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` active parameter should be concrete in reverse mode + +error: type not behind reference + --> tests/ui/invalid_indirection.rs:12:30 + | +12 | fn duplicated_forward(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated types should be behind a reference + +error: should be mutable reference + --> tests/ui/invalid_indirection.rs:16:45 + | +16 | fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` parameter should be output for DuplicatedNoNeed activity diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.rs b/library/autodiff/tests/ui/invalid_mutability_pairs.rs new file mode 100644 index 0000000000000..708ecc597a5be --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Duplicated)] +fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Forward, Duplicated)] +fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.stderr b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr new file mode 100644 index 0000000000000..37af0c2ad52ee --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr @@ -0,0 +1,55 @@ +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:4:48 + | +4 | fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:7:41 + | +7 | fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:10:43 + | +10 | fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:13:48 + | +13 | fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:16:36 + | +16 | fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:19:39 + | +19 | fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: use duplicated instead + --> tests/ui/invalid_mutability_pairs.rs:22:34 + | +22 | fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter cannot be declared as duplicatednoneed diff --git a/library/autodiff/tests/ui/invalid_return.rs b/library/autodiff/tests/ui/invalid_return.rs new file mode 100644 index 0000000000000..b3c8bce1166bf --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Active)] +fn sin1(x: f32) -> f32; + +#[autodiff(d_sin, Reverse, Duplicated)] +fn sin2(x: f32) -> f32; + +#[autodiff(d_sin, Reverse, DuplicatedNoNeed)] +fn sin3(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return.stderr b/library/autodiff/tests/ui/invalid_return.stderr new file mode 100644 index 0000000000000..4ddaccdba0f72 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.stderr @@ -0,0 +1,23 @@ +error: active return for forward mode + --> tests/ui/invalid_return.rs:3:28 + | +3 | #[autodiff(d_sin, Forward, Active)] + | ^^^^^^ + | + = help: `#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:6:28 + | +6 | #[autodiff(d_sin, Reverse, Duplicated)] + | ^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:9:28 + | +9 | #[autodiff(d_sin, Reverse, DuplicatedNoNeed)] + | ^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode diff --git a/library/autodiff/tests/ui/invalid_return_type.rs b/library/autodiff/tests/ui/invalid_return_type.rs new file mode 100644 index 0000000000000..7b91ccd2d650a --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.rs @@ -0,0 +1,16 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn active_but_no_return(#[active] x: f32) { +} + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + +#[autodiff(d_sin, Forward, Duplicated)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + +#[autodiff(d_sin, Forward, DuplicatedNoNeed)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return_type.stderr b/library/autodiff/tests/ui/invalid_return_type.stderr new file mode 100644 index 0000000000000..90e5e47a2a33d --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.stderr @@ -0,0 +1,31 @@ +error: no return type + --> tests/ui/invalid_return_type.rs:4:1 + | +4 | fn active_but_no_return(#[active] x: f32) { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` non-const return activity but no return type + +error: invalid output + --> tests/ui/invalid_return_type.rs:8:100 + | +8 | fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, Vec < f32 >, Tensor,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:11:121 + | +11 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + | ^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32, f32,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:14:121 + | +14 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32,) diff --git a/library/autodiff/tests/ui/no_function_name.rs b/library/autodiff/tests/ui/no_function_name.rs new file mode 100644 index 0000000000000..8222ca4aaf37d --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/no_function_name.stderr b/library/autodiff/tests/ui/no_function_name.stderr new file mode 100644 index 0000000000000..e98add3164c9f --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.stderr @@ -0,0 +1,8 @@ +error: please specify the autodiff function + --> tests/ui/no_function_name.rs:3:1 + | +3 | #[autodiff] + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` needs a function name for primal or adjoint + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/library/autodiff/tests/ui/not_a_function.rs b/library/autodiff/tests/ui/not_a_function.rs new file mode 100644 index 0000000000000..0a3c11725a086 --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +struct NotAFunction; + +fn main() {} diff --git a/library/autodiff/tests/ui/not_a_function.stderr b/library/autodiff/tests/ui/not_a_function.stderr new file mode 100644 index 0000000000000..c681841532a5e --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.stderr @@ -0,0 +1,7 @@ +error: item is not a function + --> tests/ui/not_a_function.rs:4:1 + | +4 | struct NotAFunction; + | ^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` can only be used on primal or adjoint functions diff --git a/library/autodiff/tests/ui/reverse_tangent.rs b/library/autodiff/tests/ui/reverse_tangent.rs new file mode 100644 index 0000000000000..603f7fd1789ce --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn tangent_missing() -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/reverse_tangent.stderr b/library/autodiff/tests/ui/reverse_tangent.stderr new file mode 100644 index 0000000000000..a7b4b6e3d97d6 --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.stderr @@ -0,0 +1,23 @@ +error: output tangent not a floating point + --> tests/ui/reverse_tangent.rs:4:58 + | +4 | fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the output tangent should be a floating point + +error: output tangent not const + --> tests/ui/reverse_tangent.rs:7:62 + | +7 | fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent + +error: missing output tangent parameter + --> tests/ui/reverse_tangent.rs:10:1 + | +10 | fn tangent_missing() -> f32; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should exist diff --git a/library/autodiff/tests/ui/wrong_mode.rs b/library/autodiff/tests/ui/wrong_mode.rs new file mode 100644 index 0000000000000..1b500711de109 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, WrongMode)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/wrong_mode.stderr b/library/autodiff/tests/ui/wrong_mode.stderr new file mode 100644 index 0000000000000..ca18d81abb306 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.stderr @@ -0,0 +1,7 @@ +error: should be forward or reverse + --> tests/ui/wrong_mode.rs:3:19 + | +3 | #[autodiff(d_sin, WrongMode)] + | ^^^^^^^^^ + | + = help: `#[autodiff]` modes should be either forward or reverse diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 125a6f57bfbaa..658b640ae4ce3 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1416,6 +1416,18 @@ pub(crate) mod builtin { }; } + /// Differentiate function + ///#[unstable( + /// feature = "autodiff", + /// issue = "29598", + /// reason = "autodiff is not stable enough" + ///)] + ///#[rustc_builtin_macro] + ///#[macro_export] + ///pub macro autodiff($item:item) { + /// /* compiler built-in */ + ///} + /// Parses a file as an expression or an item according to the context. /// /// **Warning**: For multi-file Rust projects, the `include!` macro is probably not what you diff --git a/src/bootstrap/configure.py b/src/bootstrap/configure.py index bfef3e672407d..6369a1a557a8f 100755 --- a/src/bootstrap/configure.py +++ b/src/bootstrap/configure.py @@ -70,6 +70,7 @@ def v(*args): # channel, etc. o("optimize-llvm", "llvm.optimize", "build optimized LLVM") o("llvm-assertions", "llvm.assertions", "build LLVM with assertions") +o("llvm-enzyme", "llvm.enzyme", "build LLVM with Enzyme") o("llvm-plugins", "llvm.plugins", "build LLVM with plugin interface") o("debug-assertions", "rust.debug-assertions", "build with debugging assertions") o("debug-assertions-std", "rust.debug-assertions-std", "build the standard library with debugging assertions") diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 441931e415cc6..7a53c4caffe6d 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1539,6 +1539,7 @@ pub struct Assemble { pub target_compiler: Compiler, } +#[allow(unreachable_code)] impl Step for Assemble { type Output = Compiler; const ONLY_HOSTS: bool = true; @@ -1599,6 +1600,24 @@ impl Step for Assemble { return target_compiler; } + // Build enzyme + let enzyme_install = if builder.config.llvm_enzyme { + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + } else { + None + }; + + if let Some(enzyme_install) = enzyme_install { + let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-16.so"); + + let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); + let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); + let dst_lib = libdir.join("libLLVMEnzyme-16.so"); + let target_dst_lib = target_libdir.join("libLLVMEnzyme-16.so"); + builder.copy(&src_lib, &dst_lib); + builder.copy(&src_lib, &target_dst_lib); + } + // Build the libraries for this compiler to link to (i.e., the libraries // it uses at runtime). NOTE: Crates the target compiler compiles don't // link to these. (FIXME: Is that correct? It seems to be correct most diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 24351118a5aa1..11e377be92e24 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -802,6 +802,72 @@ fn get_var(var_base: &str, host: &str, target: &str) -> Option { .or_else(|| env::var_os(var_base)) } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Enzyme { + pub target: TargetSelection, +} + +impl Step for Enzyme { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/tools/enzyme/enzyme") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Enzyme { target: run.target }); + } + + /// Compile Enzyme for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run() { + let out_dir = builder.enzyme_out(self.target); + return out_dir; + } + let target = self.target; + + let LlvmResult { llvm_config, .. } = builder.ensure(Llvm { target: self.target }); + + let out_dir = builder.enzyme_out(target); + let done_stamp = out_dir.join("enzyme-finished-building"); + if done_stamp.exists() { + return out_dir; + } + + builder.info(&format!("Building Enzyme for {}", target)); + let _time = helpers::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + builder.update_submodule(&Path::new("src").join("tools").join("enzyme")); + let mut cfg = cmake::Config::new(builder.src.join("src/tools/enzyme/enzyme/")); + // TODO: Find a nicer way to use Enzyme Debug builds + //cfg.profile("Debug"); + //cfg.define("CMAKE_BUILD_TYPE", "Debug"); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default(), &[]); + + // Re-use the same flags as llvm to control the level of debug information + // generated for lld. + let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_ENABLE_ASSERTIONS", "ON") + .define("ENZYME_EXTERNAL_SHARED_LIB", "OFF") + .define("LLVM_DIR", builder.llvm_out(target)); + + cfg.build(); + + t!(File::create(&done_stamp)); + out_dir + } +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct Lld { pub target: TargetSelection, diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index 90e09d12a9d50..65c94281e00d4 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1371,6 +1371,11 @@ impl<'a> Builder<'a> { } } + // TODO: adjust -14 ending for Enzyme + // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc + rustflags.arg("-l"); + rustflags.arg("LLVMEnzyme-16"); + let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { // If an explicit setting is given, use that diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index 5b5334b0a5572..a49835b93b662 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -173,6 +173,8 @@ pub struct Config { // llvm codegen options pub llvm_assertions: bool, pub llvm_tests: bool, + pub llvm_enzyme: bool, + pub llvm_enzyme_build: Option, pub llvm_plugins: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, @@ -676,24 +678,24 @@ macro_rules! define_config { A: serde::de::MapAccess<'de>, { $(let mut $field: Option<$field_ty> = None;)* - while let Some(key) = - match serde::de::MapAccess::next_key::(&mut map) { - Ok(val) => val, - Err(err) => { - return Err(err); + while let Some(key) = + match serde::de::MapAccess::next_key::(&mut map) { + Ok(val) => val, + Err(err) => { + return Err(err); + } } - } { match &*key { $($field_key => { if $field.is_some() { return Err(::duplicate_field( - $field_key, - )); + $field_key, + )); } $field = match serde::de::MapAccess::next_value::<$field_ty>( &mut map, - ) { + ) { Ok(val) => Some(val), Err(err) => { return Err(err); @@ -823,6 +825,7 @@ define_config! { release_debuginfo: Option = "release-debuginfo", assertions: Option = "assertions", tests: Option = "tests", + enzyme: Option = "enzyme", plugins: Option = "plugins", ccache: Option = "ccache", static_libstdcpp: Option = "static-libstdcpp", @@ -1356,6 +1359,7 @@ impl Config { // we'll infer default values for them later let mut llvm_assertions = None; let mut llvm_tests = None; + let mut llvm_enzyme = None; let mut llvm_plugins = None; let mut debug = None; let mut debug_assertions = None; @@ -1500,6 +1504,7 @@ impl Config { set(&mut config.ninja_in_file, llvm.ninja); llvm_assertions = llvm.assertions; llvm_tests = llvm.tests; + llvm_enzyme = llvm.enzyme; llvm_plugins = llvm.plugins; set(&mut config.llvm_optimize, llvm.optimize); set(&mut config.llvm_thin_lto, llvm.thin_lto); @@ -1565,6 +1570,7 @@ impl Config { check_ci_llvm!(llvm.polly); check_ci_llvm!(llvm.clang); check_ci_llvm!(llvm.build_config); + check_ci_llvm!(llvm.enzyme); check_ci_llvm!(llvm.plugins); } @@ -1658,6 +1664,7 @@ impl Config { config.llvm_assertions = llvm_assertions.unwrap_or(false); config.llvm_tests = llvm_tests.unwrap_or(false); + config.llvm_enzyme = llvm_enzyme.unwrap_or(false); config.llvm_plugins = llvm_plugins.unwrap_or(false); config.rust_optimize = optimize.unwrap_or(RustOptimize::Bool(true)); diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index d7f49a6d11b9c..ac74c632a1b1f 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -806,6 +806,10 @@ impl Build { self.out.join(&*target.triple).join("lld") } + fn enzyme_out(&self, target: TargetSelection) -> PathBuf { + self.out.join(&*target.triple).join("enzyme") + } + /// Output directory for all documentation for a target fn doc_out(&self, target: TargetSelection) -> PathBuf { self.out.join(&*target.triple).join("doc") diff --git a/src/test/ui/terminal-width/flag-human.rs b/src/test/ui/terminal-width/flag-human.rs new file mode 100644 index 0000000000000..4b94ebb01fc8e --- /dev/null +++ b/src/test/ui/terminal-width/flag-human.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 + +// This test checks that `-Z diagnostic-width` effects the human error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.rs b/src/test/ui/terminal-width/flag-json.rs new file mode 100644 index 0000000000000..3add1d7d9301e --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 --error-format=json + +// This test checks that `-Z diagnostic-width` effects the JSON error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.stderr b/src/test/ui/terminal-width/flag-json.stderr new file mode 100644 index 0000000000000..b21391d1640ef --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.stderr @@ -0,0 +1,40 @@ +{"message":"mismatched types","code":{"code":"E0308","explanation":"Expected type did not match the received type. + +Erroneous code examples: + +```compile_fail,E0308 +fn plus_one(x: i32) -> i32 { + x + 1 +} + +plus_one(\"Not a number\"); +// ^^^^^^^^^^^^^^ expected `i32`, found `&str` + +if \"Not a bool\" { +// ^^^^^^^^^^^^ expected `bool`, found `&str` +} + +let x: f32 = \"Not a float\"; +// --- ^^^^^^^^^^^^^ expected `f32`, found `&str` +// | +// expected due to this +``` + +This error occurs when an expression was used in a place where the compiler +expected an expression of a different type. It can occur in several cases, the +most common being when calling a function and passing an argument which has a +different type than the matching type in the function declaration. +"},"level":"error","spans":[{"file_name":"$DIR/flag-json.rs","byte_start":243,"byte_end":245,"line_start":7,"line_end":7,"column_start":17,"column_end":19,"is_primary":true,"text":[{"text":" let _: () = 42;","highlight_start":17,"highlight_end":19}],"label":"expected `()`, found integer","suggested_replacement":null,"suggestion_applicability":null,"expansion":null},{"file_name":"$DIR/flag-json.rs","byte_start":238,"byte_end":240,"line_start":7,"line_end":7,"column_start":12,"column_end":14,"is_primary":false,"text":[{"text":" let _: () = 42;","highlight_start":12,"highlight_end":14}],"label":"expected due to this","suggested_replacement":null,"suggestion_applicability":null,"expansion":null}],"children":[],"rendered":"error[E0308]: mismatched types + --> $DIR/flag-json.rs:7:17 + | +LL | ..._: () = 42; + | -- ^^ expected `()`, found integer + | | + | expected due to this + +"} +{"message":"aborting due to previous error","code":null,"level":"error","spans":[],"children":[],"rendered":"error: aborting due to previous error + +"} +{"message":"For more information about this error, try `rustc --explain E0308`.","code":null,"level":"failure-note","spans":[],"children":[],"rendered":"For more information about this error, try `rustc --explain E0308`. +"} diff --git a/src/tools/enzyme b/src/tools/enzyme new file mode 160000 index 0000000000000..86fc287c5a396 --- /dev/null +++ b/src/tools/enzyme @@ -0,0 +1 @@ +Subproject commit 86fc287c5a39632364af2c48bc3efb5ef1f6652d diff --git a/tests/rustdoc-ui/doctest/terminal-width.rs b/tests/rustdoc-ui/doctest/terminal-width.rs new file mode 100644 index 0000000000000..61961d5ec710e --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.rs @@ -0,0 +1,5 @@ +// compile-flags: -Zunstable-options --diagnostic-width=10 +#![deny(rustdoc::bare_urls)] + +/// This is a long line that contains a http://link.com +pub struct Foo; //~^ ERROR diff --git a/tests/rustdoc-ui/doctest/terminal-width.stderr b/tests/rustdoc-ui/doctest/terminal-width.stderr new file mode 100644 index 0000000000000..fed049d2b37bc --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.stderr @@ -0,0 +1,15 @@ +error: this URL is not a hyperlink + --> $DIR/diagnostic-width.rs:4:41 + | +LL | ... a http://link.com + | ^^^^^^^^^^^^^^^ help: use an automatic link instead: `` + | +note: the lint level is defined here + --> $DIR/diagnostic-width.rs:2:9 + | +LL | ...ny(rustdoc::bare_url... + | ^^^^^^^^^^^^^^^^^^ + = note: bare URLs are not automatically turned into clickable links + +error: aborting due to previous error + diff --git a/tests/ui/json/autodiff.rs b/tests/ui/json/autodiff.rs new file mode 100644 index 0000000000000..54f94c3765bf6 --- /dev/null +++ b/tests/ui/json/autodiff.rs @@ -0,0 +1,16 @@ +// Check autodiff attribute +// edition:2018 + +extern "C" fn rosenbrock(a: f32, b: f32, x: f32, y: f32) -> f32 { + let (z, w) = (a-x, y-x*x); + + z*z + b*w*w +} + +#[autodiff(rosenbrock, mode = "forward")] +extern "C" { + fn dx_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_x: &mut f32); + fn dy_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_y: &mut f32); +} + +fn main() {}