diff --git a/.noir-sync-commit b/.noir-sync-commit index 3a367e22506..cf76a4efc4b 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -c172880ae47ec4906cda662801bd4b7866c9586b +c44b62615f1c8ee657eedd82f2b80e2ec76c9078 diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp index 257c31e8081..9749ec073d3 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp @@ -344,7 +344,8 @@ void AvmTraceBuilder::pay_fee() FF current_balance = read_hint.leaf_preimage.value; const auto updated_balance = current_balance - tx_fee; - if (current_balance < tx_fee) { + // Comparison on Field gives inverted results, so we cast to uint128, which should be enough for fees. + if (static_cast(current_balance) < static_cast(tx_fee)) { info("Not enough balance for fee payer to pay for transaction (got ", current_balance, " needs ", tx_fee); throw std::runtime_error("Not enough balance for fee payer to pay for transaction"); } diff --git a/noir/noir-repo/.github/workflows/reports.yml b/noir/noir-repo/.github/workflows/reports.yml index ccf86b83200..4d8f036a64a 100644 --- a/noir/noir-repo/.github/workflows/reports.yml +++ b/noir/noir-repo/.github/workflows/reports.yml @@ -315,7 +315,7 @@ jobs: external_repo_compilation_and_execution_report: needs: [build-nargo] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 timeout-minutes: 15 strategy: fail-fast: false @@ -421,53 +421,9 @@ jobs: retention-days: 3 overwrite: true - upload_compilation_report: - name: Upload compilation report - needs: [generate_compilation_and_execution_report, external_repo_compilation_and_execution_report] - # We want this job to run even if one variation of the matrix in `external_repo_compilation_and_execution_report` fails - if: always() - runs-on: ubuntu-latest - permissions: - pull-requests: write - - steps: - - uses: actions/checkout@v4 - - - name: Download initial compilation report - uses: actions/download-artifact@v4 - with: - name: in_progress_compilation_report - - - name: Download matrix compilation reports - uses: actions/download-artifact@v4 - with: - pattern: compilation_report_* - path: ./reports - - - name: Merge compilation reports using jq - run: | - mv ./.github/scripts/merge-bench-reports.sh merge-bench-reports.sh - ./merge-bench-reports.sh compilation_report - - - name: Parse compilation report - id: compilation_report - uses: noir-lang/noir-bench-report@6ba151d7795042c4ff51864fbeb13c0a6a79246c - with: - report: compilation_report.json - header: | - Compilation Report - memory_report: false - - - name: Add memory report to sticky comment - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' - uses: marocchino/sticky-pull-request-comment@v2 - with: - header: compilation - message: ${{ steps.compilation_report.outputs.markdown }} - external_repo_memory_report: needs: [build-nargo] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 timeout-minutes: 30 strategy: fail-fast: false @@ -508,6 +464,7 @@ jobs: path: scripts sparse-checkout: | test_programs/memory_report.sh + test_programs/parse_memory.sh sparse-checkout-cone-mode: false - name: Checkout @@ -521,6 +478,7 @@ jobs: working-directory: ./test-repo/${{ matrix.project.path }} run: | mv /home/runner/work/noir/noir/scripts/test_programs/memory_report.sh ./memory_report.sh + mv /home/runner/work/noir/noir/scripts/test_programs/parse_memory.sh ./parse_memory.sh ./memory_report.sh 1 # Rename the memory report as the execution report is about to write to the same file cp memory_report.json compilation_memory_report.json @@ -568,14 +526,67 @@ jobs: retention-days: 3 overwrite: true + upload_compilation_report: + name: Upload compilation report + needs: [generate_compilation_and_execution_report, external_repo_compilation_and_execution_report] + # We want this job to run even if one variation of the matrix in `external_repo_compilation_and_execution_report` fails + if: always() + runs-on: ubuntu-22.04 + permissions: + pull-requests: write + # deployments permission to deploy GitHub pages website + deployments: write + # contents permission to update benchmark contents in gh-pages branch + contents: write + + steps: + - uses: actions/checkout@v4 + + - name: Download initial compilation report + uses: actions/download-artifact@v4 + with: + name: in_progress_compilation_report + + - name: Download matrix compilation reports + uses: actions/download-artifact@v4 + with: + pattern: compilation_report_* + path: ./reports + + - name: Merge compilation reports using jq + run: | + mv ./.github/scripts/merge-bench-reports.sh merge-bench-reports.sh + ./merge-bench-reports.sh compilation_report + jq ".compilation_reports | map({name: .artifact_name, value: (.time[:-1] | tonumber), unit: \"s\"}) " ./compilation_report.json > time_bench.json + + - name: Store benchmark result + continue-on-error: true + uses: benchmark-action/github-action-benchmark@4de1bed97a47495fc4c5404952da0499e31f5c29 + with: + name: "Compilation Time" + tool: "customSmallerIsBetter" + output-file-path: ./time_bench.json + github-token: ${{ secrets.GITHUB_TOKEN }} + # We want this to only run on master to avoid garbage data from PRs being added. + auto-push: ${{ github.ref == 'refs/heads/master' }} + alert-threshold: "120%" + comment-on-alert: true + fail-on-alert: false + alert-comment-cc-users: "@TomAFrench" + max-items-in-chart: 50 + upload_compilation_memory_report: name: Upload compilation memory report needs: [generate_memory_report, external_repo_memory_report] # We want this job to run even if one variation of the matrix in `external_repo_memory_report` fails if: always() - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: pull-requests: write + # deployments permission to deploy GitHub pages website + deployments: write + # contents permission to update benchmark contents in gh-pages branch + contents: write steps: - uses: actions/checkout@v4 @@ -595,33 +606,36 @@ jobs: run: | mv ./.github/scripts/merge-bench-reports.sh merge-bench-reports.sh ./merge-bench-reports.sh memory_report - # Rename the memory report as to not clash with the compilation memory report file name - cp memory_report.json execution_memory_report.json - - - name: Parse compilation memory report - id: compilation_mem_report - uses: noir-lang/noir-bench-report@6ba151d7795042c4ff51864fbeb13c0a6a79246c - with: - report: execution_memory_report.json - header: | - Compilation Memory Report - memory_report: true - - - name: Add execution memory report to sticky comment - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' - uses: marocchino/sticky-pull-request-comment@v2 - with: - header: compilation_memory - message: ${{ steps.compilation_mem_report.outputs.markdown }} + jq ".memory_reports | map({name: .artifact_name, value: (.peak_memory | tonumber), unit: \"MB\"}) " ./memory_report.json > memory_bench.json + + - name: Store benchmark result + continue-on-error: true + uses: benchmark-action/github-action-benchmark@4de1bed97a47495fc4c5404952da0499e31f5c29 + with: + name: "Compilation Memory" + tool: "customSmallerIsBetter" + output-file-path: ./memory_bench.json + github-token: ${{ secrets.GITHUB_TOKEN }} + # We want this to only run on master to avoid garbage data from PRs being added. + auto-push: ${{ github.ref == 'refs/heads/master' }} + alert-threshold: "120%" + comment-on-alert: true + fail-on-alert: false + alert-comment-cc-users: "@TomAFrench" + max-items-in-chart: 50 upload_execution_memory_report: name: Upload execution memory report needs: [generate_memory_report, external_repo_memory_report] # We want this job to run even if one variation of the matrix in `external_repo_memory_report` fails if: always() - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: pull-requests: write + # deployments permission to deploy GitHub pages website + deployments: write + # contents permission to update benchmark contents in gh-pages branch + contents: write steps: - uses: actions/checkout@v4 @@ -643,31 +657,37 @@ jobs: ./merge-bench-reports.sh memory_report # Rename the memory report as to not clash with the compilation memory report file name cp memory_report.json execution_memory_report.json + jq ".memory_reports | map({name: .artifact_name, value: (.peak_memory | tonumber), unit: \"MB\"}) " ./execution_memory_report.json > memory_bench.json - - name: Parse execution memory report - id: execution_mem_report - uses: noir-lang/noir-bench-report@6ba151d7795042c4ff51864fbeb13c0a6a79246c + - name: Store benchmark result + continue-on-error: true + uses: benchmark-action/github-action-benchmark@4de1bed97a47495fc4c5404952da0499e31f5c29 with: - report: execution_memory_report.json - header: | - Execution Memory Report - memory_report: true + name: "Execution Memory" + tool: "customSmallerIsBetter" + output-file-path: ./memory_bench.json + github-token: ${{ secrets.GITHUB_TOKEN }} + # We want this to only run on master to avoid garbage data from PRs being added. + auto-push: ${{ github.ref == 'refs/heads/master' }} + alert-threshold: "120%" + comment-on-alert: true + fail-on-alert: false + alert-comment-cc-users: "@TomAFrench" + max-items-in-chart: 50 - - name: Add execution memory report to sticky comment - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' - uses: marocchino/sticky-pull-request-comment@v2 - with: - header: execution_memory - message: ${{ steps.execution_mem_report.outputs.markdown }} upload_execution_report: name: Upload execution report needs: [generate_compilation_and_execution_report, external_repo_compilation_and_execution_report] # We want this job to run even if one variation of the matrix in `external_repo_compilation_and_execution_report` fails if: always() - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 permissions: pull-requests: write + # deployments permission to deploy GitHub pages website + deployments: write + # contents permission to update benchmark contents in gh-pages branch + contents: write steps: - uses: actions/checkout@v4 @@ -687,20 +707,20 @@ jobs: run: | mv ./.github/scripts/merge-bench-reports.sh merge-bench-reports.sh ./merge-bench-reports.sh execution_report - - - name: Parse execution report - id: execution_report - uses: noir-lang/noir-bench-report@6ba151d7795042c4ff51864fbeb13c0a6a79246c - with: - report: execution_report.json - header: | - Execution Report - execution_report: true - - - name: Add memory report to sticky comment - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' - uses: marocchino/sticky-pull-request-comment@v2 - with: - header: execution_time - message: ${{ steps.execution_report.outputs.markdown }} - + jq ".execution_reports | map({name: .artifact_name, value: (.time[:-1] | tonumber), unit: \"s\"}) " ./execution_report.json > time_bench.json + + - name: Store benchmark result + continue-on-error: true + uses: benchmark-action/github-action-benchmark@4de1bed97a47495fc4c5404952da0499e31f5c29 + with: + name: "Execution Time" + tool: "customSmallerIsBetter" + output-file-path: ./time_bench.json + github-token: ${{ secrets.GITHUB_TOKEN }} + # We want this to only run on master to avoid garbage data from PRs being added. + auto-push: ${{ github.ref == 'refs/heads/master' }} + alert-threshold: "120%" + comment-on-alert: true + fail-on-alert: false + alert-comment-cc-users: "@TomAFrench" + max-items-in-chart: 50 diff --git a/noir/noir-repo/.github/workflows/test-rust-workspace-msrv.yml b/noir/noir-repo/.github/workflows/test-rust-workspace-msrv.yml index f4fbbf79d89..38bc3cba153 100644 --- a/noir/noir-repo/.github/workflows/test-rust-workspace-msrv.yml +++ b/noir/noir-repo/.github/workflows/test-rust-workspace-msrv.yml @@ -52,7 +52,7 @@ jobs: tool: nextest@0.9.67 - name: Build and archive tests - run: cargo nextest archive --workspace --release --archive-file nextest-archive.tar.zst + run: cargo nextest archive --workspace --archive-file nextest-archive.tar.zst - name: Upload archive to workflow uses: actions/upload-artifact@v4 diff --git a/noir/noir-repo/.github/workflows/test-rust-workspace.yml b/noir/noir-repo/.github/workflows/test-rust-workspace.yml index 5d8abbc3e55..fe421361072 100644 --- a/noir/noir-repo/.github/workflows/test-rust-workspace.yml +++ b/noir/noir-repo/.github/workflows/test-rust-workspace.yml @@ -29,7 +29,7 @@ jobs: - uses: Swatinem/rust-cache@v2 with: - key: x86_64-unknown-linux-gnu + key: x86_64-unknown-linux-gnu-debug cache-on-failure: true save-if: ${{ github.event_name != 'merge_group' }} @@ -39,7 +39,7 @@ jobs: tool: nextest@0.9.67 - name: Build and archive tests - run: cargo nextest archive --workspace --release --archive-file nextest-archive.tar.zst + run: cargo nextest archive --workspace --archive-file nextest-archive.tar.zst - name: Upload archive to workflow uses: actions/upload-artifact@v4 diff --git a/noir/noir-repo/Cargo.lock b/noir/noir-repo/Cargo.lock index f8a16a2cfe2..f961c452862 100644 --- a/noir/noir-repo/Cargo.lock +++ b/noir/noir-repo/Cargo.lock @@ -3176,6 +3176,20 @@ dependencies = [ "rand", ] +[[package]] +name = "noir_inspector" +version = "1.0.0-beta.1" +dependencies = [ + "acir", + "clap", + "color-eyre", + "const_format", + "noirc_artifacts", + "noirc_artifacts_info", + "serde", + "serde_json", +] + [[package]] name = "noir_lsp" version = "1.0.0-beta.1" diff --git a/noir/noir-repo/Cargo.toml b/noir/noir-repo/Cargo.toml index ca899558420..58ca7665c0c 100644 --- a/noir/noir-repo/Cargo.toml +++ b/noir/noir-repo/Cargo.toml @@ -24,6 +24,7 @@ members = [ "tooling/noirc_abi_wasm", "tooling/acvm_cli", "tooling/profiler", + "tooling/inspector", # ACVM "acvm-repo/acir_field", "acvm-repo/acir", @@ -40,6 +41,7 @@ default-members = [ "tooling/nargo_cli", "tooling/acvm_cli", "tooling/profiler", + "tooling/inspector", ] resolver = "2" diff --git a/noir/noir-repo/compiler/noirc_driver/src/abi_gen.rs b/noir/noir-repo/compiler/noirc_driver/src/abi_gen.rs index 625a35c8d15..9838a8af210 100644 --- a/noir/noir-repo/compiler/noirc_driver/src/abi_gen.rs +++ b/noir/noir-repo/compiler/noirc_driver/src/abi_gen.rs @@ -110,9 +110,9 @@ pub(super) fn abi_type_from_hir_type(context: &Context, typ: &Type) -> AbiType { AbiType::String { length: size } } - Type::Struct(def, args) => { + Type::DataType(def, args) => { let struct_type = def.borrow(); - let fields = struct_type.get_fields(args); + let fields = struct_type.get_fields(args).unwrap_or_default(); let fields = vecmap(fields, |(name, typ)| (name, abi_type_from_hir_type(context, &typ))); // For the ABI, we always want to resolve the struct paths from the root crate diff --git a/noir/noir-repo/compiler/noirc_driver/src/lib.rs b/noir/noir-repo/compiler/noirc_driver/src/lib.rs index a7e7e2d4e2f..9b0172853c0 100644 --- a/noir/noir-repo/compiler/noirc_driver/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_driver/src/lib.rs @@ -549,11 +549,12 @@ fn compile_contract_inner( let structs = structs .into_iter() .map(|struct_id| { - let typ = context.def_interner.get_struct(struct_id); + let typ = context.def_interner.get_type(struct_id); let typ = typ.borrow(); - let fields = vecmap(typ.get_fields(&[]), |(name, typ)| { - (name, abi_type_from_hir_type(context, &typ)) - }); + let fields = + vecmap(typ.get_fields(&[]).unwrap_or_default(), |(name, typ)| { + (name, abi_type_from_hir_type(context, &typ)) + }); let path = context.fully_qualified_struct_path(context.root_crate_id(), typ.id); AbiType::Struct { path, fields } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index 5a81c79ae0d..a6117a8f2da 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -2,11 +2,13 @@ pub(crate) mod brillig_black_box; pub(crate) mod brillig_block; pub(crate) mod brillig_block_variables; pub(crate) mod brillig_fn; +pub(crate) mod brillig_globals; pub(crate) mod brillig_slice_ops; mod constant_allocation; mod variable_liveness; use acvm::FieldElement; +use fxhash::FxHashMap as HashMap; use self::{brillig_block::BrilligBlock, brillig_fn::FunctionContext}; use super::{ @@ -14,7 +16,7 @@ use super::{ artifact::{BrilligArtifact, BrilligParameter, GeneratedBrillig, Label}, BrilligContext, }, - Brillig, + Brillig, BrilligVariable, ValueId, }; use crate::{ errors::InternalError, @@ -25,6 +27,7 @@ use crate::{ pub(crate) fn convert_ssa_function( func: &Function, enable_debug_trace: bool, + globals: &HashMap, ) -> BrilligArtifact { let mut brillig_context = BrilligContext::new(enable_debug_trace); @@ -35,7 +38,13 @@ pub(crate) fn convert_ssa_function( brillig_context.call_check_max_stack_depth_procedure(); for block in function_context.blocks.clone() { - BrilligBlock::compile(&mut function_context, &mut brillig_context, block, &func.dfg); + BrilligBlock::compile( + &mut function_context, + &mut brillig_context, + block, + &func.dfg, + globals, + ); } let mut artifact = brillig_context.artifact(); @@ -53,6 +62,8 @@ pub(crate) fn gen_brillig_for( arguments, FunctionContext::return_values(func), func.id(), + true, + brillig.globals_memory_size, ); entry_point.name = func.name().to_string(); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index ec918c51ff1..97de1aea8c7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -3,7 +3,7 @@ use crate::brillig::brillig_ir::brillig_variable::{ type_to_heap_value_type, BrilligArray, BrilligVariable, SingleAddrVariable, }; -use crate::brillig::brillig_ir::registers::Stack; +use crate::brillig::brillig_ir::registers::RegisterAllocator; use crate::brillig::brillig_ir::{ BrilligBinaryOp, BrilligContext, ReservedRegisters, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, }; @@ -32,28 +32,41 @@ use super::brillig_fn::FunctionContext; use super::constant_allocation::InstructionLocation; /// Generate the compilation artifacts for compiling a function into brillig bytecode. -pub(crate) struct BrilligBlock<'block> { +pub(crate) struct BrilligBlock<'block, Registers: RegisterAllocator> { pub(crate) function_context: &'block mut FunctionContext, /// The basic block that is being converted pub(crate) block_id: BasicBlockId, /// Context for creating brillig opcodes - pub(crate) brillig_context: &'block mut BrilligContext, + pub(crate) brillig_context: &'block mut BrilligContext, /// Tracks the available variable during the codegen of the block pub(crate) variables: BlockVariables, /// For each instruction, the set of values that are not used anymore after it. pub(crate) last_uses: HashMap>, + + pub(crate) globals: &'block HashMap, + + pub(crate) building_globals: bool, } -impl<'block> BrilligBlock<'block> { +impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { /// Converts an SSA Basic block into a sequence of Brillig opcodes pub(crate) fn compile( function_context: &'block mut FunctionContext, - brillig_context: &'block mut BrilligContext, + brillig_context: &'block mut BrilligContext, block_id: BasicBlockId, dfg: &DataFlowGraph, + globals: &'block HashMap, ) { let live_in = function_context.liveness.get_live_in(&block_id); - let variables = BlockVariables::new(live_in.clone()); + + let mut live_in_no_globals = HashSet::default(); + for value in live_in { + if !dfg.is_global(*value) { + live_in_no_globals.insert(*value); + } + } + + let variables = BlockVariables::new(live_in_no_globals); brillig_context.set_allocated_registers( variables @@ -64,12 +77,44 @@ impl<'block> BrilligBlock<'block> { ); let last_uses = function_context.liveness.get_last_uses(&block_id).clone(); - let mut brillig_block = - BrilligBlock { function_context, block_id, brillig_context, variables, last_uses }; + let mut brillig_block = BrilligBlock { + function_context, + block_id, + brillig_context, + variables, + last_uses, + globals, + building_globals: false, + }; brillig_block.convert_block(dfg); } + pub(crate) fn compile_globals( + &mut self, + globals: &DataFlowGraph, + used_globals: &HashSet, + ) { + for (id, value) in globals.values_iter() { + if !used_globals.contains(&id) { + continue; + } + match value { + Value::NumericConstant { .. } => { + self.convert_ssa_value(id, globals); + } + Value::Instruction { instruction, .. } => { + self.convert_ssa_instruction(*instruction, globals); + } + _ => { + panic!( + "Expected either an instruction or a numeric constant for a global value" + ) + } + } + } + } + fn convert_block(&mut self, dfg: &DataFlowGraph) { // Add a label for this block let block_label = self.create_block_label_for_current_function(self.block_id); @@ -97,7 +142,7 @@ impl<'block> BrilligBlock<'block> { /// Making the assumption that the block ID passed in belongs to this /// function. fn create_block_label_for_current_function(&self, block_id: BasicBlockId) -> Label { - Self::create_block_label(self.function_context.function_id, block_id) + Self::create_block_label(self.function_context.function_id(), block_id) } /// Creates a unique label for a block using the function Id and the block ID. /// @@ -199,7 +244,11 @@ impl<'block> BrilligBlock<'block> { } /// Converts an SSA instruction into a sequence of Brillig opcodes. - fn convert_ssa_instruction(&mut self, instruction_id: InstructionId, dfg: &DataFlowGraph) { + pub(crate) fn convert_ssa_instruction( + &mut self, + instruction_id: InstructionId, + dfg: &DataFlowGraph, + ) { let instruction = &dfg[instruction_id]; self.brillig_context.set_call_stack(dfg.get_instruction_call_stack(instruction_id)); @@ -847,18 +896,24 @@ impl<'block> BrilligBlock<'block> { Instruction::Noop => (), }; - let dead_variables = self - .last_uses - .get(&instruction_id) - .expect("Last uses for instruction should have been computed"); - - for dead_variable in dead_variables { - self.variables.remove_variable( - dead_variable, - self.function_context, - self.brillig_context, - ); + if !self.building_globals { + let dead_variables = self + .last_uses + .get(&instruction_id) + .expect("Last uses for instruction should have been computed"); + + for dead_variable in dead_variables { + // Globals are reserved throughout the entirety of the program + if !dfg.is_global(*dead_variable) { + self.variables.remove_variable( + dead_variable, + self.function_context, + self.brillig_context, + ); + } + } } + self.brillig_context.set_call_stack(CallStack::new()); } @@ -1289,8 +1344,8 @@ impl<'block> BrilligBlock<'block> { result_variable: SingleAddrVariable, ) { let binary_type = type_of_binary_operation( - dfg[binary.lhs].get_type().as_ref(), - dfg[binary.rhs].get_type().as_ref(), + dfg[dfg.resolve(binary.lhs)].get_type().as_ref(), + dfg[dfg.resolve(binary.rhs)].get_type().as_ref(), binary.operator, ); @@ -1559,25 +1614,38 @@ impl<'block> BrilligBlock<'block> { } /// Converts an SSA `ValueId` into a `RegisterOrMemory`. Initializes if necessary. - fn convert_ssa_value(&mut self, value_id: ValueId, dfg: &DataFlowGraph) -> BrilligVariable { + pub(crate) fn convert_ssa_value( + &mut self, + value_id: ValueId, + dfg: &DataFlowGraph, + ) -> BrilligVariable { let value_id = dfg.resolve(value_id); let value = &dfg[value_id]; match value { Value::Global(_) => { - unreachable!("ICE: All globals should have been inlined"); + unreachable!("Expected global value to be resolve to its inner value"); } Value::Param { .. } | Value::Instruction { .. } => { // All block parameters and instruction results should have already been // converted to registers so we fetch from the cache. - - self.variables.get_allocation(self.function_context, value_id, dfg) + if dfg.is_global(value_id) { + *self.globals.get(&value_id).unwrap_or_else(|| { + panic!("ICE: Global value not found in cache {value_id}") + }) + } else { + self.variables.get_allocation(self.function_context, value_id, dfg) + } } Value::NumericConstant { constant, .. } => { // Constants might have been converted previously or not, so we get or create and // (re)initialize the value inside. if self.variables.is_allocated(&value_id) { self.variables.get_allocation(self.function_context, value_id, dfg) + } else if dfg.is_global(value_id) { + *self.globals.get(&value_id).unwrap_or_else(|| { + panic!("ICE: Global value not found in cache {value_id}") + }) } else { let new_variable = self.variables.define_variable( self.function_context, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs index bf0a1bc7347..4cf8e921483 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block_variables.rs @@ -7,7 +7,7 @@ use crate::{ get_bit_size_from_ssa_type, BrilligArray, BrilligVariable, BrilligVector, SingleAddrVariable, }, - registers::{RegisterAllocator, Stack}, + registers::RegisterAllocator, BrilligContext, }, ssa::ir::{ @@ -48,10 +48,10 @@ impl BlockVariables { } /// For a given SSA value id, define the variable and return the corresponding cached allocation. - pub(crate) fn define_variable( + pub(crate) fn define_variable( &mut self, function_context: &mut FunctionContext, - brillig_context: &mut BrilligContext, + brillig_context: &mut BrilligContext, value_id: ValueId, dfg: &DataFlowGraph, ) -> BrilligVariable { @@ -68,10 +68,10 @@ impl BlockVariables { } /// Defines a variable that fits in a single register and returns the allocated register. - pub(crate) fn define_single_addr_variable( + pub(crate) fn define_single_addr_variable( &mut self, function_context: &mut FunctionContext, - brillig_context: &mut BrilligContext, + brillig_context: &mut BrilligContext, value: ValueId, dfg: &DataFlowGraph, ) -> SingleAddrVariable { @@ -80,11 +80,11 @@ impl BlockVariables { } /// Removes a variable so it's not used anymore within this block. - pub(crate) fn remove_variable( + pub(crate) fn remove_variable( &mut self, value_id: &ValueId, function_context: &mut FunctionContext, - brillig_context: &mut BrilligContext, + brillig_context: &mut BrilligContext, ) { assert!(self.available_variables.remove(value_id), "ICE: Variable is not available"); let variable = function_context @@ -133,6 +133,14 @@ pub(crate) fn allocate_value( ) -> BrilligVariable { let typ = dfg.type_of_value(value_id); + allocate_value_with_type(brillig_context, typ) +} + +/// For a given value_id, allocates the necessary registers to hold it. +pub(crate) fn allocate_value_with_type( + brillig_context: &mut BrilligContext, + typ: Type, +) -> BrilligVariable { match typ { Type::Numeric(_) | Type::Reference(_) | Type::Function => { BrilligVariable::SingleAddr(SingleAddrVariable { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 3dea7b3e7f5..6e406e2b3cb 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -17,8 +17,11 @@ use fxhash::FxHashMap as HashMap; use super::{constant_allocation::ConstantAllocation, variable_liveness::VariableLiveness}; +#[derive(Default)] pub(crate) struct FunctionContext { - pub(crate) function_id: FunctionId, + /// A `FunctionContext` is necessary for using a Brillig block's code gen, but sometimes + /// such as with globals, we are not within a function and do not have a function id. + function_id: Option, /// Map from SSA values its allocation. Since values can be only defined once in SSA form, we insert them here on when we allocate them at their definition. pub(crate) ssa_value_allocations: HashMap, /// The block ids of the function in reverse post order. @@ -42,7 +45,7 @@ impl FunctionContext { let liveness = VariableLiveness::from_function(function, &constants); Self { - function_id: id, + function_id: Some(id), ssa_value_allocations: HashMap::default(), blocks: reverse_post_order, liveness, @@ -50,6 +53,10 @@ impl FunctionContext { } } + pub(crate) fn function_id(&self) -> FunctionId { + self.function_id.expect("ICE: function_id should already be set") + } + pub(crate) fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter { match typ { Type::Numeric(_) | Type::Reference(_) => { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs new file mode 100644 index 00000000000..9f9d271283d --- /dev/null +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs @@ -0,0 +1,43 @@ +use acvm::FieldElement; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; + +use super::{BrilligArtifact, BrilligBlock, BrilligVariable, FunctionContext, Label, ValueId}; +use crate::{ + brillig::{brillig_ir::BrilligContext, DataFlowGraph}, + ssa::ir::dfg::GlobalsGraph, +}; + +pub(crate) fn convert_ssa_globals( + enable_debug_trace: bool, + globals: GlobalsGraph, + used_globals: &HashSet, +) -> (BrilligArtifact, HashMap, usize) { + let mut brillig_context = BrilligContext::new_for_global_init(enable_debug_trace); + // The global space does not have globals itself + let empty_globals = HashMap::default(); + // We can use any ID here as this context is only going to be used for globals which does not differentiate + // by functions and blocks. The only Label that should be used in the globals context is `Label::globals_init()` + let mut function_context = FunctionContext::default(); + brillig_context.enter_context(Label::globals_init()); + + let block_id = DataFlowGraph::default().make_block(); + let mut brillig_block = BrilligBlock { + function_context: &mut function_context, + block_id, + brillig_context: &mut brillig_context, + variables: Default::default(), + last_uses: HashMap::default(), + globals: &empty_globals, + building_globals: true, + }; + + let globals_dfg = DataFlowGraph::from(globals); + brillig_block.compile_globals(&globals_dfg, used_globals); + + let globals_size = brillig_block.brillig_context.global_space_size(); + + brillig_context.return_instruction(); + + let artifact = brillig_context.artifact(); + (artifact, function_context.ssa_value_allocations, globals_size) +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs index 26c7151bf07..1ec2d165b12 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs @@ -2,12 +2,13 @@ use acvm::acir::brillig::MemoryAddress; use crate::brillig::brillig_ir::{ brillig_variable::{BrilligVariable, BrilligVector, SingleAddrVariable}, + registers::RegisterAllocator, BrilligBinaryOp, }; use super::brillig_block::BrilligBlock; -impl<'block> BrilligBlock<'block> { +impl<'block, Registers: RegisterAllocator> BrilligBlock<'block, Registers> { fn write_variables(&mut self, write_pointer: MemoryAddress, variables: &[BrilligVariable]) { for (index, variable) in variables.iter().enumerate() { self.brillig_context.store_instruction(write_pointer, variable.extract_register()); @@ -159,6 +160,7 @@ mod tests { use std::vec; use acvm::FieldElement; + use fxhash::FxHashMap as HashMap; use noirc_frontend::monomorphization::ast::InlineType; use crate::brillig::brillig_gen::brillig_block::BrilligBlock; @@ -173,6 +175,7 @@ mod tests { create_and_run_vm, create_context, create_entry_point_bytecode, }; use crate::brillig::brillig_ir::{BrilligContext, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE}; + use crate::brillig::ValueId; use crate::ssa::function_builder::FunctionBuilder; use crate::ssa::ir::function::RuntimeType; use crate::ssa::ir::map::Id; @@ -193,7 +196,8 @@ mod tests { fn create_brillig_block<'a>( function_context: &'a mut FunctionContext, brillig_context: &'a mut BrilligContext, - ) -> BrilligBlock<'a> { + globals: &'a HashMap, + ) -> BrilligBlock<'a, Stack> { let variables = BlockVariables::default(); BrilligBlock { function_context, @@ -201,6 +205,8 @@ mod tests { brillig_context, variables, last_uses: Default::default(), + globals, + building_globals: false, } } @@ -242,7 +248,9 @@ mod tests { // Allocate the results let target_vector = BrilligVector { pointer: context.allocate_register() }; - let mut block = create_brillig_block(&mut function_context, &mut context); + let brillig_globals = HashMap::default(); + let mut block = + create_brillig_block(&mut function_context, &mut context, &brillig_globals); if push_back { block.slice_push_back_operation( @@ -358,7 +366,9 @@ mod tests { bit_size: BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, }; - let mut block = create_brillig_block(&mut function_context, &mut context); + let brillig_globals = HashMap::default(); + let mut block = + create_brillig_block(&mut function_context, &mut context, &brillig_globals); if pop_back { block.slice_pop_back_operation( @@ -464,7 +474,9 @@ mod tests { // Allocate the results let target_vector = BrilligVector { pointer: context.allocate_register() }; - let mut block = create_brillig_block(&mut function_context, &mut context); + let brillig_globals = HashMap::default(); + let mut block = + create_brillig_block(&mut function_context, &mut context, &brillig_globals); block.slice_insert_operation( target_vector, @@ -604,7 +616,9 @@ mod tests { bit_size: BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, }; - let mut block = create_brillig_block(&mut function_context, &mut context); + let brillig_globals = HashMap::default(); + let mut block = + create_brillig_block(&mut function_context, &mut context, &brillig_globals); block.slice_remove_operation( target_vector, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs index 61ca20be2f5..64741393dd7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/constant_allocation.rs @@ -22,6 +22,7 @@ pub(crate) enum InstructionLocation { Terminator, } +#[derive(Default)] pub(crate) struct ConstantAllocation { constant_usage: HashMap>>, allocation_points: HashMap>>, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs index 6bcadc3910d..37a63466119 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/variable_liveness.rs @@ -114,6 +114,7 @@ fn compute_used_before_def( type LastUses = HashMap; /// A struct representing the liveness of variables throughout a function. +#[derive(Default)] pub(crate) struct VariableLiveness { cfg: ControlFlowGraph, post_order: PostOrder, diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 55e12c993fa..ad09f73e90f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -37,7 +37,7 @@ use acvm::{ }; use debug_show::DebugShow; -use super::ProcedureId; +use super::{GlobalSpace, ProcedureId}; /// The Brillig VM does not apply a limit to the memory address space, /// As a convention, we take use 32 bits. This means that we assume that @@ -95,6 +95,8 @@ pub(crate) struct BrilligContext { /// Whether this context can call procedures or not. /// This is used to prevent a procedure from calling another procedure. can_call_procedures: bool, + + globals_memory_size: Option, } /// Regular brillig context to codegen user defined functions @@ -108,9 +110,12 @@ impl BrilligContext { next_section: 1, debug_show: DebugShow::new(enable_debug_trace), can_call_procedures: true, + globals_memory_size: None, } } +} +impl BrilligContext { /// Splits a two's complement signed integer in the sign bit and the absolute value. /// For example, -6 i8 (11111010) is split to 00000110 (6, absolute value) and 1 (is_negative). pub(crate) fn absolute_value( @@ -209,10 +214,32 @@ impl BrilligContext { next_section: 1, debug_show: DebugShow::new(enable_debug_trace), can_call_procedures: false, + globals_memory_size: None, } } } +/// Special brillig context to codegen global values initialization +impl BrilligContext { + pub(crate) fn new_for_global_init(enable_debug_trace: bool) -> BrilligContext { + BrilligContext { + obj: BrilligArtifact::default(), + registers: GlobalSpace::new(), + context_label: Label::globals_init(), + current_section: 0, + next_section: 1, + debug_show: DebugShow::new(enable_debug_trace), + can_call_procedures: false, + globals_memory_size: None, + } + } + + pub(crate) fn global_space_size(&self) -> usize { + // `GlobalSpace::start()` is inclusive so we must add one to get the accurate total global memory size + (self.registers.max_memory_address() + 1) - GlobalSpace::start() + } +} + impl BrilligContext { /// Adds a brillig instruction to the brillig byte code fn push_opcode(&mut self, opcode: BrilligOpcode) { @@ -299,8 +326,13 @@ pub(crate) mod tests { returns: Vec, ) -> GeneratedBrillig { let artifact = context.artifact(); - let mut entry_point_artifact = - BrilligContext::new_entry_point_artifact(arguments, returns, FunctionId::test_new(0)); + let mut entry_point_artifact = BrilligContext::new_entry_point_artifact( + arguments, + returns, + FunctionId::test_new(0), + false, + 0, + ); entry_point_artifact.link_with(&artifact); while let Some(unresolved_fn_label) = entry_point_artifact.first_unresolved_function_call() { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index 3654a95a03f..4c48675d1e7 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -75,6 +75,8 @@ pub(crate) enum LabelType { Function(FunctionId, Option), /// Labels for intrinsic procedures Procedure(ProcedureId), + /// Label for initialization of globals + GlobalInit, } impl std::fmt::Display for LabelType { @@ -89,6 +91,7 @@ impl std::fmt::Display for LabelType { } LabelType::Entrypoint => write!(f, "Entrypoint"), LabelType::Procedure(procedure_id) => write!(f, "Procedure({:?})", procedure_id), + LabelType::GlobalInit => write!(f, "Globals Initialization"), } } } @@ -123,6 +126,10 @@ impl Label { pub(crate) fn procedure(procedure_id: ProcedureId) -> Self { Label { label_type: LabelType::Procedure(procedure_id), section: None } } + + pub(crate) fn globals_init() -> Self { + Label { label_type: LabelType::GlobalInit, section: None } + } } impl std::fmt::Display for Label { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_calls.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_calls.rs index da310873cff..4da3aa4d6d2 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_calls.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_calls.rs @@ -9,7 +9,8 @@ use super::{ BrilligBinaryOp, BrilligContext, ReservedRegisters, }; -impl BrilligContext { +impl BrilligContext { + // impl BrilligContext { pub(crate) fn codegen_call( &mut self, func_id: FunctionId, @@ -17,7 +18,7 @@ impl BrilligContext { returns: &[BrilligVariable], ) { let stack_size_register = SingleAddrVariable::new_usize(self.allocate_register()); - let previous_stack_pointer = self.registers.empty_stack_start(); + let previous_stack_pointer = self.registers.empty_registers_start(); let stack_size = previous_stack_pointer.unwrap_relative(); // Write the stack size self.const_instruction(stack_size_register, stack_size.into()); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index 2dbee48b277..030ed7133e8 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -22,23 +22,34 @@ impl BrilligContext { arguments: Vec, return_parameters: Vec, target_function: FunctionId, + globals_init: bool, + globals_memory_size: usize, ) -> BrilligArtifact { let mut context = BrilligContext::new(false); + context.globals_memory_size = Some(globals_memory_size); + context.codegen_entry_point(&arguments, &return_parameters); + if globals_init { + context.add_globals_init_instruction(); + } + context.add_external_call_instruction(target_function); context.codegen_exit_point(&arguments, &return_parameters); context.artifact() } - fn calldata_start_offset() -> usize { - ReservedRegisters::len() + MAX_STACK_SIZE + MAX_SCRATCH_SPACE + fn calldata_start_offset(&self) -> usize { + ReservedRegisters::len() + + MAX_STACK_SIZE + + MAX_SCRATCH_SPACE + + self.globals_memory_size.expect("The memory size of globals should be set") } - fn return_data_start_offset(calldata_size: usize) -> usize { - Self::calldata_start_offset() + calldata_size + fn return_data_start_offset(&self, calldata_size: usize) -> usize { + self.calldata_start_offset() + calldata_size } /// Adds the instructions needed to handle entry point parameters @@ -64,7 +75,7 @@ impl BrilligContext { // Set initial value of free memory pointer: calldata_start_offset + calldata_size + return_data_size self.const_instruction( SingleAddrVariable::new_usize(ReservedRegisters::free_memory_pointer()), - (Self::calldata_start_offset() + calldata_size + return_data_size).into(), + (self.calldata_start_offset() + calldata_size + return_data_size).into(), ); // Set initial value of stack pointer: ReservedRegisters.len() @@ -76,7 +87,7 @@ impl BrilligContext { // Copy calldata self.copy_and_cast_calldata(arguments); - let mut current_calldata_pointer = Self::calldata_start_offset(); + let mut current_calldata_pointer = self.calldata_start_offset(); // Initialize the variables with the calldata for (argument_variable, argument) in argument_variables.iter_mut().zip(arguments) { @@ -152,7 +163,7 @@ impl BrilligContext { fn copy_and_cast_calldata(&mut self, arguments: &[BrilligParameter]) { let calldata_size = Self::flattened_tuple_size(arguments); self.calldata_copy_instruction( - MemoryAddress::direct(Self::calldata_start_offset()), + MemoryAddress::direct(self.calldata_start_offset()), calldata_size, 0, ); @@ -172,11 +183,11 @@ impl BrilligContext { if bit_size < F::max_num_bits() { self.cast_instruction( SingleAddrVariable::new( - MemoryAddress::direct(Self::calldata_start_offset() + i), + MemoryAddress::direct(self.calldata_start_offset() + i), bit_size, ), SingleAddrVariable::new_field(MemoryAddress::direct( - Self::calldata_start_offset() + i, + self.calldata_start_offset() + i, )), ); } @@ -330,7 +341,7 @@ impl BrilligContext { let return_data_size = Self::flattened_tuple_size(return_parameters); // Return data has a reserved space after calldata - let return_data_offset = Self::return_data_start_offset(calldata_size); + let return_data_offset = self.return_data_start_offset(calldata_size); let mut return_data_index = return_data_offset; for (return_param, returned_variable) in return_parameters.iter().zip(&returned_variables) { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs index 2bf5364414c..d67da423d44 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs @@ -200,6 +200,13 @@ impl BrilligContext< self.obj.add_unresolved_external_call(BrilligOpcode::Call { location: 0 }, proc_label); } + pub(super) fn add_globals_init_instruction(&mut self) { + let globals_init_label = Label::globals_init(); + self.debug_show.add_external_call_instruction(globals_init_label.to_string()); + self.obj + .add_unresolved_external_call(BrilligOpcode::Call { location: 0 }, globals_init_label); + } + /// Adds a unresolved `Jump` instruction to the bytecode. pub(crate) fn jump_instruction(&mut self, target_label: Label) { self.debug_show.jump_instruction(target_label.to_string()); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs index dd7766f40aa..88b8a598b10 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs @@ -24,6 +24,8 @@ pub(crate) trait RegisterAllocator { fn ensure_register_is_allocated(&mut self, register: MemoryAddress); /// Creates a new register context from a set of registers allocated previously. fn from_preallocated_registers(preallocated_registers: Vec) -> Self; + /// Finds the first register that is available based upon the deallocation list + fn empty_registers_start(&self) -> MemoryAddress; } /// Every brillig stack frame/call context has its own view of register space. @@ -41,10 +43,6 @@ impl Stack { let offset = register.unwrap_relative(); offset >= Self::start() && offset < Self::end() } - - pub(crate) fn empty_stack_start(&self) -> MemoryAddress { - MemoryAddress::relative(self.storage.empty_registers_start(Self::start())) - } } impl RegisterAllocator for Stack { @@ -83,6 +81,10 @@ impl RegisterAllocator for Stack { ), } } + + fn empty_registers_start(&self) -> MemoryAddress { + MemoryAddress::relative(self.storage.empty_registers_start(Self::start())) + } } /// Procedure arguments and returns are passed through scratch space. @@ -109,7 +111,7 @@ impl RegisterAllocator for ScratchSpace { } fn end() -> usize { - ReservedRegisters::len() + MAX_STACK_SIZE + MAX_SCRATCH_SPACE + Self::start() + MAX_SCRATCH_SPACE } fn ensure_register_is_allocated(&mut self, register: MemoryAddress) { @@ -139,6 +141,87 @@ impl RegisterAllocator for ScratchSpace { ), } } + + fn empty_registers_start(&self) -> MemoryAddress { + MemoryAddress::direct(self.storage.empty_registers_start(Self::start())) + } +} + +/// Globals have a separate memory space +/// This memory space is initialized once at the beginning of a program +/// and is read-only. +pub(crate) struct GlobalSpace { + storage: DeallocationListAllocator, + max_memory_address: usize, +} + +impl GlobalSpace { + pub(super) fn new() -> Self { + Self { + storage: DeallocationListAllocator::new(Self::start()), + max_memory_address: Self::start(), + } + } + + fn is_within_bounds(register: MemoryAddress) -> bool { + let index = register.unwrap_direct(); + index >= Self::start() + } + + fn update_max_address(&mut self, register: MemoryAddress) { + let index = register.unwrap_direct(); + assert!(index >= Self::start(), "Global space malformed"); + if index > self.max_memory_address { + self.max_memory_address = index; + } + } + + pub(super) fn max_memory_address(&self) -> usize { + self.max_memory_address + } +} + +impl RegisterAllocator for GlobalSpace { + fn start() -> usize { + ScratchSpace::end() + } + + fn end() -> usize { + unreachable!("The global space is set by the program"); + } + + fn allocate_register(&mut self) -> MemoryAddress { + let allocated = MemoryAddress::direct(self.storage.allocate_register()); + self.update_max_address(allocated); + allocated + } + + fn deallocate_register(&mut self, register_index: MemoryAddress) { + self.storage.deallocate_register(register_index.unwrap_direct()); + } + + fn ensure_register_is_allocated(&mut self, register: MemoryAddress) { + self.update_max_address(register); + self.storage.ensure_register_is_allocated(register.unwrap_direct()); + } + + fn from_preallocated_registers(preallocated_registers: Vec) -> Self { + for register in &preallocated_registers { + assert!(Self::is_within_bounds(*register), "Register out of global space bounds"); + } + + Self { + storage: DeallocationListAllocator::from_preallocated_registers( + Self::start(), + vecmap(preallocated_registers, |r| r.unwrap_direct()), + ), + max_memory_address: Self::start(), + } + } + + fn empty_registers_start(&self) -> MemoryAddress { + MemoryAddress::direct(self.storage.empty_registers_start(Self::start())) + } } struct DeallocationListAllocator { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs index cb8c35cd8e0..b74c519f61a 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/mod.rs @@ -2,7 +2,8 @@ pub(crate) mod brillig_gen; pub(crate) mod brillig_ir; use acvm::FieldElement; -use brillig_ir::artifact::LabelType; +use brillig_gen::brillig_globals::convert_ssa_globals; +use brillig_ir::{artifact::LabelType, brillig_variable::BrilligVariable, registers::GlobalSpace}; use self::{ brillig_gen::convert_ssa_function, @@ -12,7 +13,11 @@ use self::{ }, }; use crate::ssa::{ - ir::function::{Function, FunctionId}, + ir::{ + dfg::DataFlowGraph, + function::{Function, FunctionId}, + value::ValueId, + }, ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; @@ -26,12 +31,19 @@ pub use self::brillig_ir::procedures::ProcedureId; pub struct Brillig { /// Maps SSA function labels to their brillig artifact ssa_function_to_brillig: HashMap>, + globals: BrilligArtifact, + globals_memory_size: usize, } impl Brillig { /// Compiles a function into brillig and store the compilation artifacts - pub(crate) fn compile(&mut self, func: &Function, enable_debug_trace: bool) { - let obj = convert_ssa_function(func, enable_debug_trace); + pub(crate) fn compile( + &mut self, + func: &Function, + enable_debug_trace: bool, + globals: &HashMap, + ) { + let obj = convert_ssa_function(func, enable_debug_trace, globals); self.ssa_function_to_brillig.insert(func.id(), obj); } @@ -46,6 +58,7 @@ impl Brillig { } // Procedures are compiled as needed LabelType::Procedure(procedure_id) => Some(Cow::Owned(compile_procedure(procedure_id))), + LabelType::GlobalInit => Some(Cow::Borrowed(&self.globals)), _ => unreachable!("ICE: Expected a function or procedure label"), } } @@ -71,9 +84,22 @@ impl Ssa { .collect::>(); let mut brillig = Brillig::default(); + + if brillig_reachable_function_ids.is_empty() { + return brillig; + } + + // Globals are computed once at compile time and shared across all functions, + // thus we can just fetch globals from the main function. + let globals = (*self.functions[&self.main_id].dfg.globals).clone(); + let (artifact, brillig_globals, globals_size) = + convert_ssa_globals(enable_debug_trace, globals, &self.used_global_values); + brillig.globals = artifact; + brillig.globals_memory_size = globals_size; + for brillig_function_id in brillig_reachable_function_ids { let func = &self.functions[&brillig_function_id]; - brillig.compile(func, enable_debug_trace); + brillig.compile(func, enable_debug_trace, &brillig_globals); } brillig diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs index ed515bbe98c..3c2857f7399 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa.rs @@ -150,15 +150,23 @@ pub(crate) fn optimize_into_acir( /// Run all SSA passes. fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result { Ok(builder - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions") + .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)") .run_pass(Ssa::defunctionalize, "Defunctionalization") + .run_pass(Ssa::inline_simple_functions, "Inlining simple functions") + // BUG: Enabling this mem2reg causes an integration test failure in aztec-package; see: + // https://github.com/AztecProtocol/aztec-packages/pull/11294#issuecomment-2622809518 + //.run_pass(Ssa::mem2reg, "Mem2Reg (1st)") .run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs") + .run_pass( + |ssa| ssa.preprocess_functions(options.inliner_aggressiveness), + "Preprocessing Functions", + ) .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)") // Run mem2reg with the CFG separated into blocks - .run_pass(Ssa::mem2reg, "Mem2Reg (1st)") + .run_pass(Ssa::mem2reg, "Mem2Reg (2nd)") .run_pass(Ssa::simplify_cfg, "Simplifying (1st)") .run_pass(Ssa::as_slice_optimization, "`as_slice` optimization") - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions") + .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (2nd)") .try_run_pass( Ssa::evaluate_static_assert_and_assert_constant, "`static_assert` and `assert_constant`", @@ -169,11 +177,11 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result Result, // Map keeping track of values stored at memory locations memory_slots: HashMap, - // Map of values resulting from array get instructions - // to the actual array values - array_elements: HashMap, - // Map of brillig call ids to sets of the value ids descending + // Value currently affecting every instruction (i.e. being + // considered a parent of every value id met) because + // of its involvement in an EnableSideEffectsIf condition + side_effects_condition: Option, + // Map of Brillig call ids to sets of the value ids descending // from their arguments and results tainted: BTreeMap, + // Map of argument value ids to the Brillig call ids employing them + call_arguments: HashMap>, + // Maintains count of calls being tracked + tracking_count: usize, + // Map of block indices to Brillig call ids that should not be + // followed after meeting them + search_limits: HashMap, } /// Structure keeping track of value ids descending from Brillig calls' @@ -116,8 +128,12 @@ struct BrilligTaintedIds { arguments: HashSet, // Results status results: Vec, - // Initial result value ids + // Indices of the array elements in the results vector + array_elements: HashMap>, + // Initial result value ids, along with element ids for arrays root_results: HashSet, + // The flag signaling that the call should be now tracked + tracking: bool, } #[derive(Clone, Debug)] @@ -128,17 +144,60 @@ enum ResultStatus { } impl BrilligTaintedIds { - fn new(arguments: &[ValueId], results: &[ValueId]) -> Self { + fn new(function: &Function, arguments: &[ValueId], results: &[ValueId]) -> Self { + // Exclude numeric constants + let arguments: Vec = arguments + .iter() + .filter(|value| function.dfg.get_numeric_constant(**value).is_none()) + .copied() + .map(|value| function.dfg.resolve(value)) + .collect(); + let results: Vec = results + .iter() + .filter(|value| function.dfg.get_numeric_constant(**value).is_none()) + .copied() + .map(|value| function.dfg.resolve(value)) + .collect(); + + let mut results_status: Vec = vec![]; + let mut array_elements: HashMap> = HashMap::new(); + + for result in &results { + match function.dfg.try_get_array_length(*result) { + // If the result value is an array, create an empty descendant set for + // every element to be accessed further on and record the indices + // of the resulting sets for future reference + Some(length) => { + array_elements.insert(*result, vec![]); + for _ in 0..length { + array_elements[result].push(results_status.len()); + results_status + .push(ResultStatus::Unconstrained { descendants: HashSet::new() }); + } + } + // Otherwise initialize a descendant set with the current value + None => { + results_status.push(ResultStatus::Unconstrained { + descendants: HashSet::from([*result]), + }); + } + } + } + BrilligTaintedIds { arguments: HashSet::from_iter(arguments.iter().copied()), - results: results - .iter() - .map(|result| ResultStatus::Unconstrained { descendants: HashSet::from([*result]) }) - .collect(), + results: results_status, + array_elements, root_results: HashSet::from_iter(results.iter().copied()), + tracking: false, } } + /// Check if the call being tracked is a simple wrapper of another call + fn is_wrapper(&self, other: &BrilligTaintedIds) -> bool { + other.root_results == self.arguments + } + /// Add children of a given parent to the tainted value set /// (for arguments one set is enough, for results we keep them /// separate as the forthcoming check considers the call covered @@ -147,12 +206,11 @@ impl BrilligTaintedIds { if self.arguments.intersection(parents).next().is_some() { self.arguments.extend(children); } - for result_status in &mut self.results.iter_mut() { - match result_status { + + for result in &mut self.results.iter_mut() { + match result { // Skip updating results already found covered - ResultStatus::Constrained => { - continue; - } + ResultStatus::Constrained => {} ResultStatus::Unconstrained { descendants } => { if descendants.intersection(parents).next().is_some() { descendants.extend(children); @@ -162,6 +220,20 @@ impl BrilligTaintedIds { } } + /// Update children of all the results (helper function for + /// chained Brillig call handling) + fn update_results_children(&mut self, children: &[ValueId]) { + for result in &mut self.results.iter_mut() { + match result { + // Skip updating results already found covered + ResultStatus::Constrained => {} + ResultStatus::Unconstrained { descendants } => { + descendants.extend(children); + } + } + } + } + /// If Brillig call is properly constrained by the given ids, return true fn check_constrained(&self) -> bool { // If every result has now been constrained, @@ -181,9 +253,7 @@ impl BrilligTaintedIds { for (i, result_status) in self.results.iter().enumerate() { match result_status { // Skip checking already covered results - ResultStatus::Constrained => { - continue; - } + ResultStatus::Constrained => {} ResultStatus::Unconstrained { descendants } => { if descendants.intersection(constrained_values).next().is_some() { results_involved.push(i); @@ -205,6 +275,21 @@ impl BrilligTaintedIds { results_involved.iter().for_each(|i| self.results[*i] = ResultStatus::Constrained); } } + + /// When an ArrayGet instruction occurs, place the resulting ValueId into + /// the corresponding sets of the call's array element result values + fn process_array_get(&mut self, array: ValueId, index: usize, element_results: &[ValueId]) { + if let Some(element_indices) = self.array_elements.get(&array) { + if let Some(result_index) = element_indices.get(index) { + if let Some(ResultStatus::Unconstrained { descendants }) = + self.results.get_mut(*result_index) + { + descendants.extend(element_results); + self.root_results.extend(element_results); + } + } + } + } } impl DependencyContext { @@ -231,9 +316,57 @@ impl DependencyContext { ) { trace!("processing instructions of block {} of function {}", block, function.id()); - for instruction in function.dfg[block].instructions() { + // First, gather information on all Brillig calls in the block + // to be able to follow their arguments first appearing in the + // flow graph before the calls themselves + function.dfg[block].instructions().iter().enumerate().for_each( + |(block_index, instruction)| { + if let Instruction::Call { func, arguments } = &function.dfg[*instruction] { + if let Value::Function(callee) = &function.dfg[*func] { + if all_functions[&callee].runtime().is_brillig() { + let results = function.dfg.instruction_results(*instruction); + let current_tainted = + BrilligTaintedIds::new(function, arguments, results); + + // Record arguments/results for each Brillig call for the check. + // + // Do not track Brillig calls acting as simple wrappers over + // another registered Brillig call, update the tainted sets of + // the wrapped call instead + let mut wrapped_call_found = false; + for (_, tainted_call) in self.tainted.iter_mut() { + if current_tainted.is_wrapper(tainted_call) { + tainted_call.update_results_children(results); + wrapped_call_found = true; + break; + } + } + + if !wrapped_call_found { + // Record the current call, remember the argument values involved + self.tainted.insert(*instruction, current_tainted); + arguments.iter().for_each(|value| { + self.call_arguments + .entry(*value) + .or_default() + .push(*instruction); + }); + + // Set the constraint search limit for the call + self.search_limits.insert( + block_index + BRILLIG_CONSTRAINT_SEARCH_DEPTH, + *instruction, + ); + } + } + } + } + }, + ); + + //Then, go over the instructions + for (block_index, instruction) in function.dfg[block].instructions().iter().enumerate() { let mut arguments = Vec::new(); - let mut results = Vec::new(); // Collect non-constant instruction arguments function.dfg[*instruction].for_each_value(|value_id| { @@ -242,137 +375,172 @@ impl DependencyContext { } }); - // Collect non-constant instruction results - for value_id in function.dfg.instruction_results(*instruction).iter() { - if function.dfg.get_numeric_constant(*value_id).is_none() { - results.push(function.dfg.resolve(*value_id)); + // Start tracking calls when their argument value ids first appear, + // or when their instruction id comes up (in case there were + // no non-constant arguments) + for argument in &arguments { + if let Some(calls) = self.call_arguments.get(argument) { + for call in calls { + if let Some(tainted_ids) = self.tainted.get_mut(call) { + tainted_ids.tracking = true; + self.tracking_count += 1; + } + } } } + if let Some(tainted_ids) = self.tainted.get_mut(instruction) { + tainted_ids.tracking = true; + self.tracking_count += 1; + } - // Process instructions - match &function.dfg[*instruction] { - // For memory operations, we have to link up the stored value as a parent - // of one loaded from the same memory slot - Instruction::Store { address, value } => { - self.memory_slots.insert(*address, function.dfg.resolve(*value)); + // Stop tracking calls when their search limit is hit + if let Some(call) = self.search_limits.get(&block_index) { + if let Some(tainted_ids) = self.tainted.get_mut(call) { + tainted_ids.tracking = false; + self.tracking_count -= 1; } - Instruction::Load { address } => { - // Recall the value stored at address as parent for the results - if let Some(value_id) = self.memory_slots.get(address) { - self.update_children(&[*value_id], &results); - } else { - panic!("load instruction {} has attempted to access previously unused memory location", - instruction); + } + + // We can skip over instructions while nothing is being tracked + if self.tracking_count > 0 { + let mut results = Vec::new(); + + // Collect non-constant instruction results + for value_id in function.dfg.instruction_results(*instruction).iter() { + if function.dfg.get_numeric_constant(*value_id).is_none() { + results.push(function.dfg.resolve(*value_id)); } } - // Check the constrain instruction arguments against those - // involved in Brillig calls, remove covered calls - Instruction::Constrain(value_id1, value_id2, _) - | Instruction::ConstrainNotEqual(value_id1, value_id2, _) => { - self.clear_constrained( - &[function.dfg.resolve(*value_id1), function.dfg.resolve(*value_id2)], - function, - ); - } - // Consider range check to also be constraining - Instruction::RangeCheck { value, .. } => { - self.clear_constrained(&[function.dfg.resolve(*value)], function); - } - Instruction::Call { func: func_id, .. } => { - // For functions, we remove the first element of arguments, - // as .for_each_value() used previously also includes func_id - arguments.remove(0); - match &function.dfg[*func_id] { - Value::Intrinsic(intrinsic) => match intrinsic { - Intrinsic::ApplyRangeConstraint | Intrinsic::AssertConstant => { - // Consider these intrinsic arguments constrained - self.clear_constrained(&arguments, function); - } - Intrinsic::AsWitness | Intrinsic::IsUnconstrained => { - // These intrinsics won't affect the dependency graph + match &function.dfg[*instruction] { + // For memory operations, we have to link up the stored value as a parent + // of one loaded from the same memory slot + Instruction::Store { address, value } => { + self.memory_slots.insert(*address, function.dfg.resolve(*value)); + } + Instruction::Load { address } => { + // Recall the value stored at address as parent for the results + if let Some(value_id) = self.memory_slots.get(address) { + self.update_children(&[*value_id], &results); + } else { + panic!("load instruction {} has attempted to access previously unused memory location", + instruction); + } + } + // Record the condition to set as future parent for the following values + Instruction::EnableSideEffectsIf { condition: value } => { + self.side_effects_condition = + match function.dfg.get_numeric_constant(*value) { + None => Some(function.dfg.resolve(*value)), + Some(_) => None, } - Intrinsic::ArrayLen - | Intrinsic::ArrayRefCount - | Intrinsic::ArrayAsStrUnchecked - | Intrinsic::AsSlice - | Intrinsic::BlackBox(..) - | Intrinsic::DerivePedersenGenerators - | Intrinsic::Hint(..) - | Intrinsic::SlicePushBack - | Intrinsic::SlicePushFront - | Intrinsic::SlicePopBack - | Intrinsic::SlicePopFront - | Intrinsic::SliceRefCount - | Intrinsic::SliceInsert - | Intrinsic::SliceRemove - | Intrinsic::StaticAssert - | Intrinsic::StrAsBytes - | Intrinsic::ToBits(..) - | Intrinsic::ToRadix(..) - | Intrinsic::FieldLessThan => { - // Record all the function arguments as parents of the results - self.update_children(&arguments, &results); + } + // Check the constrain instruction arguments against those + // involved in Brillig calls, remove covered calls + Instruction::Constrain(value_id1, value_id2, _) + | Instruction::ConstrainNotEqual(value_id1, value_id2, _) => { + self.clear_constrained( + &[function.dfg.resolve(*value_id1), function.dfg.resolve(*value_id2)], + function, + ); + } + // Consider range check to also be constraining + Instruction::RangeCheck { value, .. } => { + self.clear_constrained(&[function.dfg.resolve(*value)], function); + } + Instruction::Call { func: func_id, .. } => { + // For functions, we remove the first element of arguments, + // as .for_each_value() used previously also includes func_id + arguments.remove(0); + + match &function.dfg[*func_id] { + Value::Intrinsic(intrinsic) => match intrinsic { + Intrinsic::ApplyRangeConstraint | Intrinsic::AssertConstant => { + // Consider these intrinsic arguments constrained + self.clear_constrained(&arguments, function); + } + Intrinsic::AsWitness | Intrinsic::IsUnconstrained => { + // These intrinsics won't affect the dependency graph + } + Intrinsic::ArrayLen + | Intrinsic::ArrayRefCount + | Intrinsic::ArrayAsStrUnchecked + | Intrinsic::AsSlice + | Intrinsic::BlackBox(..) + | Intrinsic::DerivePedersenGenerators + | Intrinsic::Hint(..) + | Intrinsic::SlicePushBack + | Intrinsic::SlicePushFront + | Intrinsic::SlicePopBack + | Intrinsic::SlicePopFront + | Intrinsic::SliceRefCount + | Intrinsic::SliceInsert + | Intrinsic::SliceRemove + | Intrinsic::StaticAssert + | Intrinsic::StrAsBytes + | Intrinsic::ToBits(..) + | Intrinsic::ToRadix(..) + | Intrinsic::FieldLessThan => { + // Record all the function arguments as parents of the results + self.update_children(&arguments, &results); + } + }, + Value::Function(callee) => match all_functions[&callee].runtime() { + // Only update tainted sets for non-Brillig calls, as + // the chained Brillig case should already be covered + RuntimeType::Acir(..) => { + self.update_children(&arguments, &results); + } + RuntimeType::Brillig(..) => {} + }, + Value::ForeignFunction(..) => { + panic!("should not be able to reach foreign function from non-Brillig functions, {func_id} in function {}", function.name()); } - }, - Value::Function(callee) => match all_functions[callee].runtime() { - RuntimeType::Brillig(_) => { - // Record arguments/results for each Brillig call for the check - self.tainted.insert( - *instruction, - BrilligTaintedIds::new(&arguments, &results), + Value::Instruction { .. } + | Value::NumericConstant { .. } + | Value::Param { .. } + | Value::Global(_) => { + panic!( + "calling non-function value with ID {func_id} in function {}", + function.name() ); } - RuntimeType::Acir(..) => { - // Record all the function arguments as parents of the results - self.update_children(&arguments, &results); - } - }, - Value::ForeignFunction(..) => { - panic!("should not be able to reach foreign function from non-Brillig functions, {func_id} in function {}", function.name()); - } - Value::Instruction { .. } - | Value::NumericConstant { .. } - | Value::Param { .. } - | Value::Global(_) => { - panic!( - "calling non-function value with ID {func_id} in function {}", - function.name() - ); } } - } - // For array get operations, we link the resulting values to - // the corresponding array value ids - // (this is required later because for now we consider array elements - // being constrained as valid as the whole arrays being constrained) - Instruction::ArrayGet { array, .. } => { - for result in &results { - self.array_elements.insert(*result, function.dfg.resolve(*array)); + // For array get operations, we check the Brillig calls for + // results involving the array in question, to properly + // populate the array element tainted sets + Instruction::ArrayGet { array, index } => { + self.process_array_get(function, *array, *index, &results); + // Record all the used arguments as parents of the results + self.update_children(&arguments, &results); } - // Record all the used arguments as parents of the results - self.update_children(&arguments, &results); - } - Instruction::ArraySet { .. } - | Instruction::Binary(..) - | Instruction::Cast(..) - | Instruction::IfElse { .. } - | Instruction::Not(..) - | Instruction::Truncate { .. } => { - // Record all the used arguments as parents of the results - self.update_children(&arguments, &results); + Instruction::ArraySet { .. } + | Instruction::Binary(..) + | Instruction::Cast(..) + | Instruction::IfElse { .. } + | Instruction::Not(..) + | Instruction::Truncate { .. } => { + // Record all the used arguments as parents of the results + self.update_children(&arguments, &results); + } + // These instructions won't affect the dependency graph + Instruction::Allocate { .. } + | Instruction::DecrementRc { .. } + | Instruction::IncrementRc { .. } + | Instruction::MakeArray { .. } + | Instruction::Noop => {} } - // These instructions won't affect the dependency graph - Instruction::Allocate { .. } - | Instruction::DecrementRc { .. } - | Instruction::EnableSideEffectsIf { .. } - | Instruction::IncrementRc { .. } - | Instruction::Noop - | Instruction::MakeArray { .. } => {} } } - trace!("Number tainted Brillig calls: {}", self.tainted.len()); + if !self.tainted.is_empty() { + trace!( + "number of Brillig calls in function {} left unchecked: {}", + function, + self.tainted.len() + ); + } } /// Every Brillig call not properly constrained should remain in the tainted set @@ -382,6 +550,7 @@ impl DependencyContext { .tainted .keys() .map(|brillig_call| { + trace!("tainted structure for {}: {:?}", brillig_call, self.tainted[brillig_call]); SsaReport::Bug(InternalBug::UncheckedBrilligCall { call_stack: function.dfg.get_instruction_call_stack(*brillig_call), }) @@ -389,7 +558,7 @@ impl DependencyContext { .collect(); trace!( - "making {} under constrained reports for function {}", + "making {} reports on underconstrained Brillig calls for function {}", warnings.len(), function.name() ); @@ -398,9 +567,17 @@ impl DependencyContext { /// Update sets of value ids that can be traced back to the Brillig calls being tracked fn update_children(&mut self, parents: &[ValueId], children: &[ValueId]) { - let parents: HashSet<_> = HashSet::from_iter(parents.iter().copied()); + let mut parents: HashSet<_> = HashSet::from_iter(parents.iter().copied()); + + // Also include the current EnableSideEffectsIf condition in parents + // (as it would affect every following statement) + self.side_effects_condition.map(|v| parents.insert(v)); + + // Don't update sets for the calls not yet being tracked for (_, tainted_ids) in self.tainted.iter_mut() { - tainted_ids.update_children(&parents, children); + if tainted_ids.tracking { + tainted_ids.update_children(&parents, children); + } } } @@ -408,28 +585,44 @@ impl DependencyContext { /// by given values after recording partial constraints, if so stop tracking them fn clear_constrained(&mut self, constrained_values: &[ValueId], function: &Function) { // Remove numeric constants - let constrained_values = - constrained_values.iter().filter(|v| function.dfg.get_numeric_constant(**v).is_none()); - - // For now, consider array element constraints to be array constraints - // TODO(https://github.com/noir-lang/noir/issues/6698): - // This probably has to be further looked into, to ensure _every_ element - // of an array result of a Brillig call has been constrained let constrained_values: HashSet<_> = constrained_values - .map(|v| { - if let Some(parent_array) = self.array_elements.get(v) { - *parent_array - } else { - *v - } - }) + .iter() + .filter(|v| function.dfg.get_numeric_constant(**v).is_none()) + .copied() .collect(); - self.tainted.iter_mut().for_each(|(_, tainted_ids)| { - tainted_ids.store_partial_constraints(&constrained_values); - }); + // Skip untracked calls + for (_, tainted_ids) in self.tainted.iter_mut() { + if tainted_ids.tracking { + tainted_ids.store_partial_constraints(&constrained_values); + } + } + self.tainted.retain(|_, tainted_ids| !tainted_ids.check_constrained()); } + + /// Process ArrayGet instruction for tracked Brillig calls + fn process_array_get( + &mut self, + function: &Function, + array: ValueId, + index: ValueId, + element_results: &[ValueId], + ) { + use acvm::acir::AcirField; + + // Only allow numeric constant indices + if let Some(value) = function.dfg.get_numeric_constant(index) { + if let Some(index) = value.try_to_u32() { + // Skip untracked calls + for (_, tainted_ids) in self.tainted.iter_mut() { + if tainted_ids.tracking { + tainted_ids.process_array_get(array, index as usize, element_results); + } + } + } + } + } } #[derive(Default)] @@ -499,7 +692,7 @@ impl Context { function: &Function, ) -> Vec { let mut warnings = Vec::new(); - // Find brillig-generated values in the set + // Find Brillig-generated values in the set let intersection = all_brillig_generated_values.intersection(current_set).copied(); // Go through all Brillig outputs in the set @@ -1018,4 +1211,189 @@ mod test { let ssa_level_warnings = ssa.check_for_missing_brillig_constraints(); assert_eq!(ssa_level_warnings.len(), 2); } + + #[test] + #[traced_test] + /// Test EnableSideEffectsIf conditions affecting the dependency graph + /// (SSA a bit convoluted to work around simplification breaking the flow + /// of the parsed test code) + fn test_enable_side_effects_if_affecting_following_statements() { + let program = r#" + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = call f1(v0, v1) -> Field + v5 = add v0, v1 + v6 = eq v3, v5 + v7 = add u1 1, u1 0 + enable_side_effects v6 + v8 = add v7, u1 1 + enable_side_effects u1 1 + constrain v8 == u1 2 + return v3 + } + + brillig(inline) fn foo f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "#; + + let mut ssa = Ssa::from_str(program).unwrap(); + let ssa_level_warnings = ssa.check_for_missing_brillig_constraints(); + assert_eq!(ssa_level_warnings.len(), 0); + } + + #[test] + #[traced_test] + /// Test call result array elements being underconstrained + fn test_brillig_result_array_missing_element_constraint() { + let program = r#" + acir(inline) fn main f0 { + b0(v0: u32): + v16 = call f1(v0) -> [u32; 3] + v17 = array_get v16, index u32 0 -> u32 + constrain v17 == v0 + v19 = array_get v16, index u32 2 -> u32 + constrain v19 == v0 + return v17 + } + + brillig(inline) fn into_array f1 { + b0(v0: u32): + v4 = make_array [v0, v0, v0] : [u32; 3] + return v4 + } + "#; + + let mut ssa = Ssa::from_str(program).unwrap(); + let ssa_level_warnings = ssa.check_for_missing_brillig_constraints(); + assert_eq!(ssa_level_warnings.len(), 1); + } + + #[test] + #[traced_test] + /// Test call result array elements being constrained properly + fn test_brillig_result_array_all_elements_constrained() { + let program = r#" + acir(inline) fn main f0 { + b0(v0: u32): + v16 = call f1(v0) -> [u32; 3] + v17 = array_get v16, index u32 0 -> u32 + constrain v17 == v0 + v20 = array_get v16, index u32 1 -> u32 + constrain v20 == v0 + v19 = array_get v16, index u32 2 -> u32 + constrain v19 == v0 + return v17 + } + + brillig(inline) fn into_array f1 { + b0(v0: u32): + v4 = make_array [v0, v0, v0] : [u32; 3] + return v4 + } + "#; + + let mut ssa = Ssa::from_str(program).unwrap(); + let ssa_level_warnings = ssa.check_for_missing_brillig_constraints(); + assert_eq!(ssa_level_warnings.len(), 0); + } + + #[test] + #[traced_test] + /// Test chained (wrapper) Brillig calls not producing a false positive + fn test_chained_brillig_calls_constrained() { + /* + struct Animal { + legs: Field, + eyes: u8, + tag: Tag, + } + + struct Tag { + no: Field, + } + + unconstrained fn foo(x: Field) -> Animal { + Animal { + legs: 4, + eyes: 2, + tag: Tag { no: x } + } + } + + unconstrained fn bar(x: Animal) -> Animal { + Animal { + legs: x.legs, + eyes: x.eyes, + tag: Tag { no: x.tag.no + 1 } + } + } + + fn main(x: Field) -> pub Animal { + let dog = bar(foo(x)); + assert(dog.legs == 4); + assert(dog.eyes == 2); + assert(dog.tag.no == x + 1); + + dog + } + */ + + let program = r#" + acir(inline) fn main f0 { + b0(v0: Field): + v27, v28, v29 = call f2(v0) -> (Field, u8, Field) + v30, v31, v32 = call f1(v27, v28, v29) -> (Field, u8, Field) + constrain v30 == Field 4 + constrain v31 == u8 2 + v35 = add v0, Field 1 + constrain v32 == v35 + return v30, v31, v32 + } + + brillig(inline) fn foo f2 { + b0(v0: Field): + return Field 4, u8 2, v0 + } + + brillig(inline) fn bar f1 { + b0(v0: Field, v1: u8, v2: Field): + v7 = add v2, Field 1 + return v0, v1, v7 + } + + "#; + + let mut ssa = Ssa::from_str(program).unwrap(); + let ssa_level_warnings = ssa.check_for_missing_brillig_constraints(); + assert_eq!(ssa_level_warnings.len(), 0); + } + + #[test] + #[traced_test] + /// Test for the argument descendants coming before Brillig calls themselves being + /// registered as such + fn test_brillig_argument_descendants_preceding_call() { + let program = r#" + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = add v0, v1 + v5 = call f1(v0, v1) -> Field + constrain v3 == v5 + return v3 + } + + brillig(inline) fn foo f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "#; + + let mut ssa = Ssa::from_str(program).unwrap(); + let ssa_level_warnings = ssa.check_for_missing_brillig_constraints(); + assert_eq!(ssa_level_warnings.len(), 0); + } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index 28242b223ac..ad4cd079c3b 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -109,6 +109,7 @@ pub(crate) struct DataFlowGraph { /// The GlobalsGraph contains the actual global data. /// Global data is expected to only be numeric constants or array constants (which are represented by Instruction::MakeArray). /// The global's data will shared across functions and should be accessible inside of a function's DataFlowGraph. +#[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub(crate) struct GlobalsGraph { /// Storage for all of the global values @@ -116,19 +117,39 @@ pub(crate) struct GlobalsGraph { /// All of the instructions in the global value space. /// These are expected to all be Instruction::MakeArray instructions: DenseMap, + #[serde_as(as = "HashMap")] + results: HashMap>, + #[serde(skip)] + constants: HashMap<(FieldElement, NumericType), ValueId>, } impl GlobalsGraph { pub(crate) fn from_dfg(dfg: DataFlowGraph) -> Self { - Self { values: dfg.values, instructions: dfg.instructions } + Self { + values: dfg.values, + instructions: dfg.instructions, + results: dfg.results, + constants: dfg.constants, + } } /// Iterate over every Value in this DFG in no particular order, including unused Values - pub(crate) fn values_iter(&self) -> impl ExactSizeIterator { + pub(crate) fn values_iter(&self) -> impl DoubleEndedIterator { self.values.iter() } } +impl From for DataFlowGraph { + fn from(value: GlobalsGraph) -> Self { + DataFlowGraph { + values: value.values, + instructions: value.instructions, + results: value.results, + ..Default::default() + } + } +} + impl DataFlowGraph { /// Runtime type of the function. pub(crate) fn runtime(&self) -> RuntimeType { @@ -173,12 +194,12 @@ impl DataFlowGraph { /// The pairs are order by id, which is not guaranteed to be meaningful. pub(crate) fn basic_blocks_iter( &self, - ) -> impl ExactSizeIterator { + ) -> impl DoubleEndedIterator { self.blocks.iter() } /// Iterate over every Value in this DFG in no particular order, including unused Values - pub(crate) fn values_iter(&self) -> impl ExactSizeIterator { + pub(crate) fn values_iter(&self) -> impl DoubleEndedIterator { self.values.iter() } @@ -233,17 +254,18 @@ impl DataFlowGraph { pub(crate) fn insert_instruction_and_results_without_simplification( &mut self, - instruction_data: Instruction, + instruction: Instruction, block: BasicBlockId, ctrl_typevars: Option>, call_stack: CallStackId, ) -> InsertInstructionResult { - if !self.is_handled_by_runtime(&instruction_data) { - panic!("Attempted to insert instruction not handled by runtime: {instruction_data:?}"); + if !self.is_handled_by_runtime(&instruction) { + // Panicking to raise attention. If we're not supposed to simplify it immediately, + // pushing the instruction would just cause a potential panic later on. + panic!("Attempted to insert instruction not handled by runtime: {instruction:?}"); } - let id = self.insert_instruction_without_simplification( - instruction_data, + instruction, block, ctrl_typevars, call_stack, @@ -280,7 +302,10 @@ impl DataFlowGraph { existing_id: Option, ) -> InsertInstructionResult { if !self.is_handled_by_runtime(&instruction) { - panic!("Attempted to insert instruction not handled by runtime: {instruction:?}"); + // BUG: With panicking it fails to build the `token_contract`; see: + // https://github.com/AztecProtocol/aztec-packages/pull/11294#issuecomment-2624379102 + // panic!("Attempted to insert instruction not handled by runtime: {instruction:?}"); + return InsertInstructionResult::InstructionRemoved; } match instruction.simplify(self, block, ctrl_typevars.clone(), call_stack) { @@ -386,6 +411,9 @@ impl DataFlowGraph { if let Some(id) = self.constants.get(&(constant, typ)) { return *id; } + if let Some(id) = self.globals.constants.get(&(constant, typ)) { + return *id; + } let id = self.values.insert(Value::NumericConstant { constant, typ }); self.constants.insert((constant, typ), id); id @@ -484,7 +512,7 @@ impl DataFlowGraph { /// Should `value` be a numeric constant then this function will return the exact number of bits required, /// otherwise it will return the minimum number of bits based on type information. pub(crate) fn get_value_max_num_bits(&self, value: ValueId) -> u32 { - match self[value] { + match self[self.resolve(value)] { Value::Instruction { instruction, .. } => { let value_bit_size = self.type_of_value(value).bit_size(); if let Instruction::Cast(original_value, _) = self[instruction] { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs index b59b0c18a10..516cd8e318e 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -12,7 +12,7 @@ use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir(InlineType), @@ -143,10 +143,7 @@ impl Function { } pub(crate) fn is_no_predicates(&self) -> bool { - match self.runtime() { - RuntimeType::Acir(inline_type) => matches!(inline_type, InlineType::NoPredicates), - RuntimeType::Brillig(_) => false, - } + self.runtime().is_no_predicates() } /// Retrieves the entry block of a function. diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 171ca30f5f4..5806e62bf95 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -10,7 +10,7 @@ use fxhash::FxHasher64; use iter_extended::vecmap; use noirc_frontend::hir_def::types::Type as HirType; -use crate::ssa::{ir::function::RuntimeType, opt::flatten_cfg::value_merger::ValueMerger}; +use crate::ssa::opt::flatten_cfg::value_merger::ValueMerger; use super::{ basic_block::BasicBlockId, @@ -506,7 +506,7 @@ impl Instruction { } } - pub(crate) fn can_eliminate_if_unused(&self, function: &Function) -> bool { + pub(crate) fn can_eliminate_if_unused(&self, function: &Function, flattened: bool) -> bool { use Instruction::*; match self { Binary(binary) => { @@ -539,8 +539,7 @@ impl Instruction { // pass where this check is done, but does mean that we cannot perform mem2reg // after the DIE pass. Store { .. } => { - matches!(function.runtime(), RuntimeType::Acir(_)) - && function.reachable_blocks().len() == 1 + flattened && function.runtime().is_acir() && function.reachable_blocks().len() == 1 } Constrain(..) diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 992c633ffcd..6ee7aa0192c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -650,7 +650,12 @@ fn constant_to_radix( ) -> SimplifyResult { let bit_size = u32::BITS - (radix - 1).leading_zeros(); let radix_big = BigUint::from(radix); - assert_eq!(BigUint::from(2u128).pow(bit_size), radix_big, "ICE: Radix must be a power of 2"); + let radix_range = BigUint::from(2u128)..=BigUint::from(256u128); + if !radix_range.contains(&radix_big) || BigUint::from(2u128).pow(bit_size) != radix_big { + // NOTE: expect an error to be thrown later in + // acir::generated_acir::radix_le_decompose + return SimplifyResult::None; + } let big_integer = BigUint::from_bytes_be(&field.to_be_bytes()); // Decompose the integer into its radix digits in little endian form. diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/map.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/map.rs index 1d637309191..b6da107957c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/map.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/map.rs @@ -190,7 +190,7 @@ impl DenseMap { /// Gets an iterator to a reference to each element in the dense map paired with its id. /// /// The id-element pairs are ordered by the numeric values of the ids. - pub(crate) fn iter(&self) -> impl ExactSizeIterator, &T)> { + pub(crate) fn iter(&self) -> impl DoubleEndedIterator, &T)> { let ids_iter = (0..self.storage.len() as u32).map(|idx| Id::new(idx)); ids_iter.zip(self.storage.iter()) } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 85f8dcaba48..e9c465d264f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -20,13 +20,16 @@ use super::{ impl Display for Ssa { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - for (id, global_value) in self.globals.dfg.values_iter() { + let globals = (*self.functions[&self.main_id].dfg.globals).clone(); + let globals_dfg = DataFlowGraph::from(globals); + + for (id, global_value) in globals_dfg.values_iter() { match global_value { Value::NumericConstant { constant, typ } => { writeln!(f, "g{} = {typ} {constant}", id.to_u32())?; } Value::Instruction { instruction, .. } => { - display_instruction(&self.globals.dfg, *instruction, true, f)?; + display_instruction(&globals_dfg, *instruction, true, f)?; } Value::Global(_) => { panic!("Value::Global should only be in the function dfg"); @@ -35,7 +38,7 @@ impl Display for Ssa { }; } - if self.globals.dfg.values_iter().len() > 0 { + if globals_dfg.values_iter().next().is_some() { writeln!(f)?; } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 186f10c53e6..4afddbef41a 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -8,12 +8,13 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use acvm::FieldElement; use iter_extended::vecmap; +use noirc_frontend::monomorphization::ast::InlineType; use crate::ssa::{ function_builder::FunctionBuilder, ir::{ basic_block::BasicBlockId, - function::{Function, FunctionId, Signature}, + function::{Function, FunctionId, RuntimeType, Signature}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, value::{Value, ValueId}, @@ -43,12 +44,15 @@ struct ApplyFunction { dispatches_to_multiple_functions: bool, } +type Variants = BTreeMap<(Signature, RuntimeType), Vec>; +type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>; + /// Performs defunctionalization on all functions /// This is done by changing all functions as value to be a number (FieldElement) /// And creating apply functions that dispatch to the correct target by runtime comparisons with constants #[derive(Debug, Clone)] struct DefunctionalizationContext { - apply_functions: HashMap, + apply_functions: ApplyFunctions, } impl Ssa { @@ -104,7 +108,7 @@ impl DefunctionalizationContext { }; // Find the correct apply function - let apply_function = self.get_apply_function(&signature); + let apply_function = self.get_apply_function(signature, func.runtime()); // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); @@ -152,19 +156,21 @@ impl DefunctionalizationContext { } /// Returns the apply function for the given signature - fn get_apply_function(&self, signature: &Signature) -> ApplyFunction { - *self.apply_functions.get(signature).expect("Could not find apply function") + fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction { + *self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function") } } /// Collects all functions used as values that can be called by their signatures -fn find_variants(ssa: &Ssa) -> BTreeMap> { - let mut dynamic_dispatches: BTreeSet = BTreeSet::new(); +fn find_variants(ssa: &Ssa) -> Variants { + let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); let mut functions_as_values: BTreeSet = BTreeSet::new(); for function in ssa.functions.values() { functions_as_values.extend(find_functions_as_values(function)); - dynamic_dispatches.extend(find_dynamic_dispatches(function)); + dynamic_dispatches.extend( + find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())), + ); } let mut signature_to_functions_as_value: BTreeMap> = BTreeMap::new(); @@ -174,16 +180,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap> { signature_to_functions_as_value.entry(signature).or_default().push(function_id); } - let mut variants = BTreeMap::new(); + let mut variants: Variants = BTreeMap::new(); - for dispatch_signature in dynamic_dispatches { - let mut target_fns = vec![]; - for (target_signature, functions) in &signature_to_functions_as_value { - if &dispatch_signature == target_signature { - target_fns.extend(functions); - } - } - variants.insert(dispatch_signature, target_fns); + for (dispatch_signature, caller_runtime) in dynamic_dispatches { + let target_fns = + signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default(); + variants.insert((dispatch_signature, caller_runtime), target_fns); } variants @@ -247,10 +249,10 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet { fn create_apply_functions( ssa: &mut Ssa, - variants_map: BTreeMap>, -) -> HashMap { + variants_map: BTreeMap<(Signature, RuntimeType), Vec>, +) -> ApplyFunctions { let mut apply_functions = HashMap::default(); - for (signature, variants) in variants_map.into_iter() { + for ((signature, runtime), variants) in variants_map.into_iter() { assert!( !variants.is_empty(), "ICE: at least one variant should exist for a dynamic call {signature:?}" @@ -258,11 +260,12 @@ fn create_apply_functions( let dispatches_to_multiple_functions = variants.len() > 1; let id = if dispatches_to_multiple_functions { - create_apply_function(ssa, signature.clone(), variants) + create_apply_function(ssa, signature.clone(), runtime, variants) } else { variants[0] }; - apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions }); + apply_functions + .insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions }); } apply_functions } @@ -275,6 +278,7 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement { fn create_apply_function( ssa: &mut Ssa, signature: Signature, + caller_runtime: RuntimeType, function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); @@ -282,6 +286,14 @@ fn create_apply_function( ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); + + // We want to push for apply functions to be inlined more aggressively; + // they are expected to be optimized away by constants visible at the call site. + let runtime = match caller_runtime { + RuntimeType::Acir(_) => RuntimeType::Acir(InlineType::InlineAlways), + RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways), + }; + function_builder.set_runtime(runtime); let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); @@ -339,22 +351,156 @@ fn create_apply_function( }) } -/// Crates a return block, if no previous return exists, it will create a final return -/// Else, it will create a bypass return block that points to the previous return block +/// If no previous return target exists, it will create a final return, +/// otherwise returns the existing return block to jump to. fn build_return_block( builder: &mut FunctionBuilder, previous_block: BasicBlockId, passed_types: &[Type], target: Option, ) -> BasicBlockId { + if let Some(return_block) = target { + return return_block; + } let return_block = builder.insert_block(); builder.switch_to_block(return_block); - let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone())); - match target { - None => builder.terminate_with_return(params), - Some(target) => builder.terminate_with_jmp(target, params), - } + builder.terminate_with_return(params); builder.switch_to_block(previous_block); return_block } + +#[cfg(test)] +mod tests { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn apply_inherits_caller_runtime() { + // Extracted from `execution_success/brillig_fns_as_values` with `--force-brillig` + let src = " + brillig(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(f2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f1(f3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(Field 2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f1(Field 3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: Field, v1: u32): + v3 = call f4(v0, v1) -> u32 + return v3 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline_always) fn apply f4 { + b0(v0: Field, v1: u32): + v4 = eq v0, Field 2 + jmpif v4 then: b2, else: b1 + b1(): + constrain v0 == Field 3 + v7 = call f3(v1) -> u32 + jmp b3(v7) + b2(): + v9 = call f2(v1) -> u32 + jmp b3(v9) + b3(v2: u32): + return v2 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn apply_created_per_caller_runtime() { + let src = " + acir(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(f2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f4(f3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + acir(inline) fn wrapper_acir f4 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + acir(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + let applies = ssa.functions.values().filter(|f| f.name() == "apply").collect::>(); + assert_eq!(applies.len(), 2); + assert!(applies.iter().any(|f| f.runtime().is_acir())); + assert!(applies.iter().any(|f| f.runtime().is_brillig())); + } +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs index eed1af8251b..f02b0975e9d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -22,9 +22,34 @@ use super::rc::{pop_rc_for, RcInstruction}; impl Ssa { /// Performs Dead Instruction Elimination (DIE) to remove any instructions with /// unused results. + /// + /// This step should come after the flattening of the CFG and mem2reg. #[tracing::instrument(level = "trace", skip(self))] - pub(crate) fn dead_instruction_elimination(mut self) -> Ssa { - self.functions.par_iter_mut().for_each(|(_, func)| func.dead_instruction_elimination(true)); + pub(crate) fn dead_instruction_elimination(self) -> Ssa { + self.dead_instruction_elimination_inner(true) + } + + fn dead_instruction_elimination_inner(mut self, flattened: bool) -> Ssa { + let mut used_global_values: HashSet<_> = self + .functions + .par_iter_mut() + .flat_map(|(_, func)| func.dead_instruction_elimination(true, flattened)) + .collect(); + + let globals = &self.functions[&self.main_id].dfg.globals; + // Check which globals are used across all functions + for (id, value) in globals.values_iter().rev() { + if used_global_values.contains(&id) { + if let Value::Instruction { instruction, .. } = &value { + let instruction = &globals[*instruction]; + instruction.for_each_value(|value_id| { + used_global_values.insert(value_id); + }); + } + } + } + + self.used_global_values = used_global_values; self } @@ -37,8 +62,18 @@ impl Function { /// instructions that reference results from an instruction in another block are evaluated first. /// If we did not iterate blocks in this order we could not safely say whether or not the results /// of its instructions are needed elsewhere. - pub(crate) fn dead_instruction_elimination(&mut self, insert_out_of_bounds_checks: bool) { - let mut context = Context::default(); + /// + /// Returns the set of globals that were used in this function. + /// After processing all functions, the union of these sets enables determining the unused globals. + pub(crate) fn dead_instruction_elimination( + &mut self, + insert_out_of_bounds_checks: bool, + flattened: bool, + ) -> HashSet { + let mut context = Context { flattened, ..Default::default() }; + + context.mark_function_parameter_arrays_as_used(self); + for call_data in &self.dfg.data_bus.call_data { context.mark_used_instruction_results(&self.dfg, call_data.array_id); } @@ -58,11 +93,12 @@ impl Function { // instructions (we don't want to remove those checks, or instructions that are // dependencies of those checks) if inserted_out_of_bounds_checks { - self.dead_instruction_elimination(false); - return; + return self.dead_instruction_elimination(false, flattened); } context.remove_rc_instructions(&mut self.dfg); + + context.used_values.into_iter().filter(|value| self.dfg.is_global(*value)).collect() } } @@ -76,6 +112,15 @@ struct Context { /// they technically contain side-effects but we still want to remove them if their /// `value` parameter is not used elsewhere. rc_instructions: Vec<(InstructionId, BasicBlockId)>, + + /// The elimination of certain unused instructions assumes that the DIE pass runs after + /// the flattening of the CFG, but if that's not the case then we should not eliminate + /// them just yet. + flattened: bool, + + // When tracking mutations we consider arrays with the same type as all being possibly mutated. + // This we consider to span all blocks of the functions. + mutated_array_types: HashSet, } impl Context { @@ -105,14 +150,18 @@ impl Context { let block = &function.dfg[block_id]; self.mark_terminator_values_as_used(function, block); - let instructions_len = block.instructions().len(); + // Lend the shared array type to the tracker. + let mut mutated_array_types = std::mem::take(&mut self.mutated_array_types); + let mut rc_tracker = RcTracker::new(&mut mutated_array_types); + rc_tracker.mark_terminator_arrays_as_used(function, block); - let mut rc_tracker = RcTracker::default(); + let instructions_len = block.instructions().len(); // Indexes of instructions that might be out of bounds. // We'll remove those, but before that we'll insert bounds checks for them. let mut possible_index_out_of_bounds_indexes = Vec::new(); + // Going in reverse so we know if a result of an instruction was used. for (instruction_index, instruction_id) in block.instructions().iter().rev().enumerate() { let instruction = &function.dfg[*instruction_id]; @@ -162,6 +211,9 @@ impl Context { .instructions_mut() .retain(|instruction| !self.instructions_to_remove.contains(instruction)); + // Take the mutated array back. + std::mem::swap(&mut self.mutated_array_types, &mut mutated_array_types); + false } @@ -172,7 +224,7 @@ impl Context { fn is_unused(&self, instruction_id: InstructionId, function: &Function) -> bool { let instruction = &function.dfg[instruction_id]; - if instruction.can_eliminate_if_unused(function) { + if instruction.can_eliminate_if_unused(function, self.flattened) { let results = function.dfg.instruction_results(instruction_id); results.iter().all(|result| !self.used_values.contains(result)) } else if let Instruction::Call { func, arguments } = instruction { @@ -195,15 +247,32 @@ impl Context { /// Inspects a value and marks all instruction results as used. fn mark_used_instruction_results(&mut self, dfg: &DataFlowGraph, value_id: ValueId) { let value_id = dfg.resolve(value_id); - if matches!(&dfg[value_id], Value::Instruction { .. } | Value::Param { .. }) { + if matches!(&dfg[value_id], Value::Instruction { .. } | Value::Param { .. }) + || dfg.is_global(value_id) + { self.used_values.insert(value_id); } } - fn remove_rc_instructions(self, dfg: &mut DataFlowGraph) { + /// Mark any array parameters to the function itself as possibly mutated. + fn mark_function_parameter_arrays_as_used(&mut self, function: &Function) { + for parameter in function.parameters() { + let typ = function.dfg.type_of_value(*parameter); + if typ.contains_an_array() { + let typ = typ.get_contained_array(); + // Want to store the array type which is being referenced, + // because it's the underlying array that the `inc_rc` is associated with. + self.mutated_array_types.insert(typ.clone()); + } + } + } + + /// Go through the RC instructions collected when we figured out which values were unused; + /// for each RC that refers to an unused value, remove the RC as well. + fn remove_rc_instructions(&self, dfg: &mut DataFlowGraph) { let unused_rc_values_by_block: HashMap> = - self.rc_instructions.into_iter().fold(HashMap::default(), |mut acc, (rc, block)| { - let value = match &dfg[rc] { + self.rc_instructions.iter().fold(HashMap::default(), |mut acc, (rc, block)| { + let value = match &dfg[*rc] { Instruction::IncrementRc { value } => *value, Instruction::DecrementRc { value } => *value, other => { @@ -214,7 +283,7 @@ impl Context { }; if !self.used_values.contains(&value) { - acc.entry(block).or_default().insert(rc); + acc.entry(*block).or_default().insert(*rc); } acc }); @@ -356,15 +425,16 @@ impl Context { ) -> bool { use Instruction::*; if let IncrementRc { value } | DecrementRc { value } = instruction { - if let Value::Instruction { instruction, .. } = &dfg[*value] { - return match &dfg[*instruction] { - MakeArray { .. } => true, - Call { func, .. } => { - matches!(&dfg[*func], Value::Intrinsic(_) | Value::ForeignFunction(_)) - } - _ => false, - }; - } + let Some(instruction) = dfg.get_local_or_global_instruction(*value) else { + return false; + }; + return match instruction { + MakeArray { .. } => true, + Call { func, .. } => { + matches!(&dfg[*func], Value::Intrinsic(_) | Value::ForeignFunction(_)) + } + _ => false, + }; } false } @@ -528,8 +598,8 @@ fn apply_side_effects( (lhs, rhs) } -#[derive(Default)] -struct RcTracker { +/// Per block RC tracker. +struct RcTracker<'a> { // We can track IncrementRc instructions per block to determine whether they are useless. // IncrementRc and DecrementRc instructions are normally side effectual instructions, but we remove // them if their value is not used anywhere in the function. However, even when their value is used, their existence @@ -538,22 +608,44 @@ struct RcTracker { // with the same value but no array set in between. // If we see an inc/dec RC pair within a block we can safely remove both instructions. rcs_with_possible_pairs: HashMap>, + // Tracks repeated RC instructions: if there are two `inc_rc` for the same value in a row, the 2nd one is redundant. rc_pairs_to_remove: HashSet, // We also separately track all IncrementRc instructions and all array types which have been mutably borrowed. // If an array is the same type as one of those non-mutated array types, we can safely remove all IncrementRc instructions on that array. inc_rcs: HashMap>, - mutated_array_types: HashSet, // The SSA often creates patterns where after simplifications we end up with repeat // IncrementRc instructions on the same value. We track whether the previous instruction was an IncrementRc, // and if the current instruction is also an IncrementRc on the same value we remove the current instruction. // `None` if the previous instruction was anything other than an IncrementRc previous_inc_rc: Option, + // Mutated arrays shared across the blocks of the function. + mutated_array_types: &'a mut HashSet, } -impl RcTracker { +impl<'a> RcTracker<'a> { + fn new(mutated_array_types: &'a mut HashSet) -> Self { + Self { + rcs_with_possible_pairs: Default::default(), + rc_pairs_to_remove: Default::default(), + inc_rcs: Default::default(), + previous_inc_rc: Default::default(), + mutated_array_types, + } + } + + fn mark_terminator_arrays_as_used(&mut self, function: &Function, block: &BasicBlock) { + block.unwrap_terminator().for_each_value(|value| { + let typ = function.dfg.type_of_value(value); + if matches!(&typ, Type::Array(_, _) | Type::Slice(_)) { + self.mutated_array_types.insert(typ); + } + }); + } + fn track_inc_rcs_to_remove(&mut self, instruction_id: InstructionId, function: &Function) { let instruction = &function.dfg[instruction_id]; + // Deduplicate IncRC instructions. if let Instruction::IncrementRc { value } = instruction { if let Some(previous_value) = self.previous_inc_rc { if previous_value == *value { @@ -562,6 +654,7 @@ impl RcTracker { } self.previous_inc_rc = Some(*value); } else { + // Reset the deduplication. self.previous_inc_rc = None; } @@ -569,6 +662,8 @@ impl RcTracker { // when we see a DecrementRc and check whether it was possibly mutated when we see an IncrementRc. match instruction { Instruction::IncrementRc { value } => { + // Get any RC instruction recorded further down the block for this array; + // if it exists and not marked as mutated, then both RCs can be removed. if let Some(inc_rc) = pop_rc_for(*value, function, &mut self.rcs_with_possible_pairs) { @@ -577,7 +672,7 @@ impl RcTracker { self.rc_pairs_to_remove.insert(instruction_id); } } - + // Remember that this array was RC'd by this instruction. self.inc_rcs.entry(*value).or_default().insert(instruction_id); } Instruction::DecrementRc { value } => { @@ -590,12 +685,12 @@ impl RcTracker { } Instruction::ArraySet { array, .. } => { let typ = function.dfg.type_of_value(*array); + // We mark all RCs that refer to arrays with a matching type as the one being set, as possibly mutated. if let Some(dec_rcs) = self.rcs_with_possible_pairs.get_mut(&typ) { for dec_rc in dec_rcs { dec_rc.possibly_mutated = true; } } - self.mutated_array_types.insert(typ); } Instruction::Store { value, .. } => { @@ -606,6 +701,9 @@ impl RcTracker { } } Instruction::Call { arguments, .. } => { + // Treat any array-type arguments to calls as possible sources of mutation. + // During the preprocessing of functions in isolation we don't want to + // get rid of IncRCs arrays that can potentially be mutated outside. for arg in arguments { let typ = function.dfg.type_of_value(*arg); if matches!(&typ, Type::Array(..) | Type::Slice(..)) { @@ -617,6 +715,7 @@ impl RcTracker { } } + /// Get all RC instructions which work on arrays whose type has not been marked as mutated. fn get_non_mutated_arrays(&self, dfg: &DataFlowGraph) -> HashSet { self.inc_rcs .keys() @@ -815,16 +914,6 @@ mod test { #[test] fn keep_inc_rc_on_borrowed_array_set() { - // brillig(inline) fn main f0 { - // b0(v0: [u32; 2]): - // inc_rc v0 - // v3 = array_set v0, index u32 0, value u32 1 - // inc_rc v0 - // inc_rc v0 - // inc_rc v0 - // v4 = array_get v3, index u32 1 - // return v4 - // } let src = " brillig(inline) fn main f0 { b0(v0: [u32; 2]): @@ -878,7 +967,7 @@ mod test { } #[test] - fn remove_inc_rcs_that_are_never_mutably_borrowed() { + fn does_not_remove_inc_rcs_that_are_never_mutably_borrowed() { let src = " brillig(inline) fn main f0 { b0(v0: [Field; 2]): @@ -900,7 +989,9 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: [Field; 2]): + inc_rc v0 v2 = array_get v0, index u32 0 -> Field + inc_rc v0 return v2 } "; @@ -909,6 +1000,36 @@ mod test { assert_normalized_ssa_equals(ssa, expected); } + #[test] + fn do_not_remove_inc_rcs_for_arrays_in_terminator() { + let src = " + brillig(inline) fn main f0 { + b0(v0: [Field; 2]): + inc_rc v0 + inc_rc v0 + inc_rc v0 + v2 = array_get v0, index u32 0 -> Field + inc_rc v0 + return v0, v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + brillig(inline) fn main f0 { + b0(v0: [Field; 2]): + inc_rc v0 + v2 = array_get v0, index u32 0 -> Field + inc_rc v0 + return v0, v2 + } + "; + + let ssa = ssa.dead_instruction_elimination(); + assert_normalized_ssa_equals(ssa, expected); + } + #[test] fn do_not_remove_inc_rc_if_used_as_call_arg() { // We do not want to remove inc_rc instructions on values @@ -941,4 +1062,53 @@ mod test { let ssa = ssa.dead_instruction_elimination(); assert_normalized_ssa_equals(ssa, src); } + + #[test] + fn do_not_remove_mutable_reference_params() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v2 = allocate -> &mut Field + store v0 at v2 + call f1(v2) + v4 = load v2 -> Field + v5 = eq v4, v1 + constrain v4 == v1 + return + } + acir(inline) fn Add10 f1 { + b0(v0: &mut Field): + v1 = load v0 -> Field + v2 = load v0 -> Field + v4 = add v2, Field 10 + store v4 at v0 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + // Even though these ACIR functions only have 1 block, we have not inlined and flattened anything yet. + let ssa = ssa.dead_instruction_elimination_inner(false); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field, v1: Field): + v2 = allocate -> &mut Field + store v0 at v2 + call f1(v2) + v4 = load v2 -> Field + constrain v4 == v1 + return + } + acir(inline) fn Add10 f1 { + b0(v0: &mut Field): + v1 = load v0 -> Field + v3 = add v1, Field 10 + store v3 at v0 + return + } + "; + assert_normalized_ssa_equals(ssa, expected); + } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index b5cbc90e30d..7f96df1384b 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -2,9 +2,10 @@ //! The purpose of this pass is to inline the instructions of each function call //! within the function caller. If all function calls are known, there will only //! be a single function remaining when the pass finishes. -use std::collections::{BTreeSet, HashSet, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque}; use acvm::acir::AcirField; +use im::HashMap; use iter_extended::{btree_map, vecmap}; use crate::ssa::{ @@ -12,14 +13,13 @@ use crate::ssa::{ ir::{ basic_block::BasicBlockId, call_stack::CallStackId, - dfg::InsertInstructionResult, + dfg::{GlobalsGraph, InsertInstructionResult}, function::{Function, FunctionId, RuntimeType}, instruction::{Instruction, InstructionId, TerminatorInstruction}, value::{Value, ValueId}, }, ssa_gen::Ssa, }; -use fxhash::FxHashMap as HashMap; /// An arbitrary limit to the maximum number of recursive call /// frames at any point in time. @@ -46,54 +46,90 @@ impl Ssa { /// This step should run after runtime separation, since it relies on the runtime of the called functions being final. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn inline_functions(self, aggressiveness: i64) -> Ssa { - let inline_sources = get_functions_to_inline_into(&self, false, aggressiveness); - Self::inline_functions_inner(self, &inline_sources, false) + let inline_infos = compute_inline_infos(&self, false, aggressiveness); + Self::inline_functions_inner(self, &inline_infos, false) } - // Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points + /// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa { - let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness); - Self::inline_functions_inner(self, &inline_sources, true) + let inline_infos = compute_inline_infos(&self, true, aggressiveness); + Self::inline_functions_inner(self, &inline_infos, true) } fn inline_functions_inner( mut self, - inline_sources: &BTreeSet, + inline_infos: &InlineInfos, inline_no_predicates_functions: bool, ) -> Ssa { - // Note that we clear all functions other than those in `inline_sources`. - // If we decide to do partial inlining then we should change this to preserve those functions which still exist. - self.functions = btree_map(inline_sources, |entry_point| { - let should_inline_call = - |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { - let function = &ssa.functions[&called_func_id]; - - match function.runtime() { - RuntimeType::Acir(inline_type) => { - // If the called function is acir, we inline if it's not an entry point - - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be preserved. - let preserve_function = - !inline_no_predicates_functions && function.is_no_predicates(); - !inline_type.is_entry_point() && !preserve_function - } - RuntimeType::Brillig(_) => { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive - ssa.functions[entry_point].runtime().is_brillig() - && !inline_sources.contains(&called_func_id) - } - } - }; + let inline_targets = + inline_infos.iter().filter_map(|(id, info)| info.is_inline_target().then_some(*id)); + + let should_inline_call = |callee: &Function| -> bool { + match callee.runtime() { + RuntimeType::Acir(_) => { + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be preserved. + let preserve_function = + !inline_no_predicates_functions && callee.is_no_predicates(); + !preserve_function + } + RuntimeType::Brillig(_) => { + // We inline inline if the function called wasn't ruled out as too costly or recursive. + InlineInfo::should_inline(inline_infos, callee.id()) + } + } + }; + + // NOTE: Functions are processed independently of each other, with the final mapping replacing the original, + // instead of inlining the "leaf" functions, moving up towards the entry point. + self.functions = btree_map(inline_targets, |entry_point| { + let function = &self.functions[&entry_point]; + let new_function = function.inlined(&self, &should_inline_call); + (entry_point, new_function) + }); + self + } + + pub(crate) fn inline_simple_functions(mut self: Ssa) -> Ssa { + let should_inline_call = |callee: &Function| { + if let RuntimeType::Acir(_) = callee.runtime() { + // Functions marked to not have predicates should be preserved. + if callee.is_no_predicates() { + return false; + } + } + + let entry_block_id = callee.entry_block(); + let entry_block = &callee.dfg[entry_block_id]; - let new_function = - InlineContext::new(&self, *entry_point).inline_all(&self, &should_inline_call); - (*entry_point, new_function) + // Only inline functions with a single block + if entry_block.successors().next().is_some() { + return false; + } + + // Only inline functions with 0 or 1 instructions + entry_block.instructions().len() <= 1 + }; + + self.functions = btree_map(self.functions.iter(), |(id, function)| { + (*id, function.inlined(&self, &should_inline_call)) }); + self } } +impl Function { + /// Create a new function which has the functions called by this one inlined into its body. + pub(super) fn inlined( + &self, + ssa: &Ssa, + should_inline_call: &impl Fn(&Function) -> bool, + ) -> Function { + InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call) + } +} + /// The context for the function inlining pass. /// /// This works using an internal FunctionBuilder to build a new main function from scratch. @@ -115,6 +151,9 @@ struct InlineContext { /// inline into. The same goes for ValueIds, InstructionIds, and for storing other data like /// parameter to argument mappings. struct PerFunctionContext<'function> { + /// The function that we are inlining calls into. + entry_function: &'function Function, + /// The source function is the function we're currently inlining into the function being built. source_function: &'function Function, @@ -139,10 +178,12 @@ struct PerFunctionContext<'function> { /// True if we're currently working on the entry point function. inlining_entry: bool, - globals: &'function Function, + globals: &'function GlobalsGraph, } /// Utility function to find out the direct calls of a function. +/// +/// Returns the function IDs from all `Call` instructions without deduplication. fn called_functions_vec(func: &Function) -> Vec { let mut called_function_ids = Vec::new(); for block_id in func.reachable_blocks() { @@ -160,32 +201,65 @@ fn called_functions_vec(func: &Function) -> Vec { called_function_ids } -/// Utility function to find out the deduplicated direct calls of a function. +/// Utility function to find out the deduplicated direct calls made from a function. fn called_functions(func: &Function) -> BTreeSet { called_functions_vec(func).into_iter().collect() } +/// Information about a function to aid the decision about whether to inline it or not. +/// The final decision depends on what we're inlining it into. +#[derive(Default, Debug)] +pub(super) struct InlineInfo { + is_brillig_entry_point: bool, + is_acir_entry_point: bool, + is_recursive: bool, + pub(super) should_inline: bool, + weight: i64, + cost: i64, +} + +impl InlineInfo { + /// Functions which are to be retained, not inlined. + pub(super) fn is_inline_target(&self) -> bool { + self.is_brillig_entry_point + || self.is_acir_entry_point + || self.is_recursive + || !self.should_inline + } + + pub(super) fn should_inline(inline_infos: &InlineInfos, called_func_id: FunctionId) -> bool { + inline_infos.get(&called_func_id).map(|info| info.should_inline).unwrap_or_default() + } +} + +type InlineInfos = BTreeMap; + /// The functions we should inline into (and that should be left in the final program) are: /// - main /// - Any Brillig function called from Acir /// - Some Brillig functions depending on aggressiveness and some metrics /// - Any Acir functions with a [fold inline type][InlineType::Fold], -fn get_functions_to_inline_into( +/// +/// The returned `InlineInfos` won't have every function in it, only the ones which the algorithm visited. +pub(super) fn compute_inline_infos( ssa: &Ssa, inline_no_predicates_functions: bool, aggressiveness: i64, -) -> BTreeSet { - let mut brillig_entry_points = BTreeSet::default(); - let mut acir_entry_points = BTreeSet::default(); - - if matches!(ssa.main().runtime(), RuntimeType::Brillig(_)) { - brillig_entry_points.insert(ssa.main_id); - } else { - acir_entry_points.insert(ssa.main_id); - } +) -> InlineInfos { + let mut inline_infos = InlineInfos::default(); + + inline_infos.insert( + ssa.main_id, + InlineInfo { + is_acir_entry_point: ssa.main().runtime().is_acir(), + is_brillig_entry_point: ssa.main().runtime().is_brillig(), + ..Default::default() + }, + ); + // Handle ACIR functions. for (func_id, function) in ssa.functions.iter() { - if matches!(function.runtime(), RuntimeType::Brillig(_)) { + if function.runtime().is_brillig() { continue; } @@ -193,83 +267,216 @@ fn get_functions_to_inline_into( // to not have predicates should be preserved. let preserve_function = !inline_no_predicates_functions && function.is_no_predicates(); if function.runtime().is_entry_point() || preserve_function { - acir_entry_points.insert(*func_id); + inline_infos.entry(*func_id).or_default().is_acir_entry_point = true; } - for called_function_id in called_functions(function) { - if matches!(ssa.functions[&called_function_id].runtime(), RuntimeType::Brillig(_)) { - brillig_entry_points.insert(called_function_id); + // Any Brillig function called from ACIR is an entry into the Brillig VM. + for called_func_id in called_functions(function) { + if ssa.functions[&called_func_id].runtime().is_brillig() { + inline_infos.entry(called_func_id).or_default().is_brillig_entry_point = true; } } } - let times_called = compute_times_called(ssa); + let callers = compute_callers(ssa); + let times_called = compute_times_called(&callers); - let brillig_functions_to_retain: BTreeSet<_> = compute_functions_to_retain( + mark_brillig_functions_to_retain( ssa, - &brillig_entry_points, - ×_called, inline_no_predicates_functions, aggressiveness, + ×_called, + &mut inline_infos, ); - acir_entry_points - .into_iter() - .chain(brillig_entry_points) - .chain(brillig_functions_to_retain) + inline_infos +} + +/// Compute the time each function is called from any other function. +fn compute_times_called( + callers: &BTreeMap>, +) -> HashMap { + callers + .iter() + .map(|(callee, callers)| { + let total_calls = callers.values().sum(); + (*callee, total_calls) + }) .collect() } -fn compute_times_called(ssa: &Ssa) -> HashMap { +/// Compute for each function the set of functions that call it, and how many times they do so. +fn compute_callers(ssa: &Ssa) -> BTreeMap> { ssa.functions .iter() - .flat_map(|(_caller_id, function)| { - let called_functions_vec = called_functions_vec(function); - called_functions_vec.into_iter() + .flat_map(|(caller_id, function)| { + let called_functions = called_functions_vec(function); + called_functions.into_iter().map(|callee_id| (*caller_id, callee_id)) }) - .chain(std::iter::once(ssa.main_id)) - .fold(HashMap::default(), |mut map, func_id| { - *map.entry(func_id).or_insert(0) += 1; - map + .fold( + // Make sure an entry exists even for ones that don't get called. + ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(), + |mut acc, (caller_id, callee_id)| { + let callers = acc.entry(callee_id).or_default(); + *callers.entry(caller_id).or_default() += 1; + acc + }, + ) +} + +/// Compute for each function the set of functions called by it, and how many times it does so. +fn compute_callees(ssa: &Ssa) -> BTreeMap> { + ssa.functions + .iter() + .flat_map(|(caller_id, function)| { + let called_functions = called_functions_vec(function); + called_functions.into_iter().map(|callee_id| (*caller_id, callee_id)) }) + .fold( + // Make sure an entry exists even for ones that don't call anything. + ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(), + |mut acc, (caller_id, callee_id)| { + let callees = acc.entry(caller_id).or_default(); + *callees.entry(callee_id).or_default() += 1; + acc + }, + ) } -fn should_retain_recursive( +/// Compute something like a topological order of the functions, starting with the ones +/// that do not call any other functions, going towards the entry points. When cycles +/// are detected, take the one which are called by the most to break the ties. +/// +/// This can be used to simplify the most often called functions first. +/// +/// Returns the functions paired with their own as well as transitive weight, +/// which accumulates the weight of all the functions they call, as well as own. +pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, (usize, usize))> { + let mut order = Vec::new(); + let mut visited = HashSet::new(); + + // Call graph which we'll repeatedly prune to find the "leaves". + let mut callees = compute_callees(ssa); + let callers = compute_callers(ssa); + + // Number of times a function is called, used to break cycles in the call graph by popping the next candidate. + let mut times_called = compute_times_called(&callers).into_iter().collect::>(); + times_called.sort_by_key(|(id, cnt)| { + // Sort by called the *least* by others, as these are less likely to cut the graph when removed. + let called_desc = -(*cnt as i64); + // Sort entries first (last to be popped). + let is_entry_asc = -called_desc.signum(); + // Finally break ties by ID. + (is_entry_asc, called_desc, *id) + }); + + // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call. + let own_weights = ssa + .functions + .iter() + .map(|(id, f)| (*id, compute_function_own_weight(f))) + .collect::>(); + let mut weights = own_weights.clone(); + + // Seed the queue with functions that don't call anything. + let mut queue = callees + .iter() + .filter_map(|(id, callees)| callees.is_empty().then_some(*id)) + .collect::>(); + + loop { + while let Some(id) = queue.pop_front() { + // Pull the current weight of yet-to-be emitted callees (a nod to mutual recursion). + for (callee, cnt) in &callees[&id] { + if *callee != id { + weights[&id] = weights[&id].saturating_add(cnt.saturating_mul(weights[callee])); + } + } + // Own weight plus the weights accumulated from callees. + let weight = weights[&id]; + let own_weight = own_weights[&id]; + + // Emit the function. + order.push((id, (own_weight, weight))); + visited.insert(id); + + // Update the callers of this function. + for (caller, cnt) in &callers[&id] { + // Update the weight of the caller with the weight of this function. + weights[caller] = weights[caller].saturating_add(cnt.saturating_mul(weight)); + // Remove this function from the callees of the caller. + let callees = callees.get_mut(caller).unwrap(); + callees.remove(&id); + // If the caller doesn't call any other function, enqueue it, + // unless it's the entry function, which is never called by anything, so it should be last. + if callees.is_empty() && !visited.contains(caller) && !callers[caller].is_empty() { + queue.push_back(*caller); + } + } + } + // If we ran out of the queue, maybe there is a cycle; take the next most called function. + while let Some((id, _)) = times_called.pop() { + if !visited.contains(&id) { + queue.push_back(id); + break; + } + } + if times_called.is_empty() && queue.is_empty() { + assert_eq!(order.len(), callers.len()); + return order; + } + } +} + +/// Traverse the call graph starting from a given function, marking function to be retained if they are: +/// * recursive functions, or +/// * the cost of inlining outweighs the cost of not doing so +fn mark_functions_to_retain_recursive( ssa: &Ssa, - func: FunctionId, - times_called: &HashMap, - should_retain_function: &mut HashMap, - mut explored_functions: im::HashSet, inline_no_predicates_functions: bool, aggressiveness: i64, + times_called: &HashMap, + inline_infos: &mut InlineInfos, + mut explored_functions: im::HashSet, + func: FunctionId, ) { - // We have already decided on this function - if should_retain_function.get(&func).is_some() { + // Check if we have set any of the fields this method touches. + let decided = |inline_infos: &InlineInfos| { + inline_infos + .get(&func) + .map(|info| info.is_recursive || info.should_inline || info.weight != 0) + .unwrap_or_default() + }; + + // Check if we have already decided on this function + if decided(inline_infos) { return; } - // Recursive, this function won't be inlined + + // If recursive, this function won't be inlined if explored_functions.contains(&func) { - should_retain_function.insert(func, (true, 0)); + inline_infos.entry(func).or_default().is_recursive = true; return; } explored_functions.insert(func); - // Decide on dependencies first - let called_functions = called_functions(&ssa.functions[&func]); - for function in called_functions.iter() { - should_retain_recursive( + // Decide on dependencies first, so we know their weight. + let called_functions = called_functions_vec(&ssa.functions[&func]); + for callee in &called_functions { + mark_functions_to_retain_recursive( ssa, - *function, - times_called, - should_retain_function, - explored_functions.clone(), inline_no_predicates_functions, aggressiveness, + times_called, + inline_infos, + explored_functions.clone(), + *callee, ); } + // We could have decided on this function while deciding on dependencies - // If the function is recursive - if should_retain_function.get(&func).is_some() { + // if the function is recursive. + if decided(inline_infos) { return; } @@ -277,13 +484,18 @@ fn should_retain_recursive( // We compute the weight (roughly the number of instructions) of the function after inlining // And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns) // We then can compute an approximation of the cost of inlining vs the cost of retaining the function - // We do this computation using saturating i64s to avoid overflows - let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, called_function| { - let (should_retain, weight) = should_retain_function[called_function]; - if should_retain { - acc + // We do this computation using saturating i64s to avoid overflows, + // and because we want to calculate a difference which can be negative. + + // Total weight of functions called by this one, unless we decided not to inline them. + // Callees which appear multiple times would be inlined multiple times. + let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, callee| { + let info = &inline_infos[callee]; + // If the callee is not going to be inlined then we can ignore its cost. + if info.should_inline { + acc.saturating_add(info.weight) } else { - acc.saturating_add(weight) + acc } }); @@ -296,54 +508,50 @@ fn should_retain_recursive( let inline_cost = times_called.saturating_mul(this_function_weight); let retain_cost = times_called.saturating_mul(interface_cost) + this_function_weight; + let net_cost = inline_cost.saturating_sub(retain_cost); let runtime = ssa.functions[&func].runtime(); // We inline if the aggressiveness is higher than inline cost minus the retain cost // If aggressiveness is infinite, we'll always inline // If aggressiveness is 0, we'll inline when the inline cost is lower than the retain cost // If aggressiveness is minus infinity, we'll never inline (other than in the mandatory cases) - let should_inline = ((inline_cost.saturating_sub(retain_cost)) < aggressiveness) + let should_inline = (net_cost < aggressiveness) || runtime.is_inline_always() || (runtime.is_no_predicates() && inline_no_predicates_functions); - should_retain_function.insert(func, (!should_inline, this_function_weight)); + let info = inline_infos.entry(func).or_default(); + info.should_inline = should_inline; + info.weight = this_function_weight; + info.cost = net_cost; } -fn compute_functions_to_retain( +/// Mark Brillig functions that should not be inlined because they are recursive or expensive. +fn mark_brillig_functions_to_retain( ssa: &Ssa, - entry_points: &BTreeSet, - times_called: &HashMap, inline_no_predicates_functions: bool, aggressiveness: i64, -) -> BTreeSet { - let mut should_retain_function = HashMap::default(); + times_called: &HashMap, + inline_infos: &mut InlineInfos, +) { + let brillig_entry_points = inline_infos + .iter() + .filter_map(|(id, info)| info.is_brillig_entry_point.then_some(*id)) + .collect::>(); - for entry_point in entry_points.iter() { - should_retain_recursive( + for entry_point in brillig_entry_points { + mark_functions_to_retain_recursive( ssa, - *entry_point, - times_called, - &mut should_retain_function, - im::HashSet::default(), inline_no_predicates_functions, aggressiveness, + times_called, + inline_infos, + im::HashSet::default(), + entry_point, ); } - - should_retain_function - .into_iter() - .filter_map( - |(func_id, (should_retain, _))| { - if should_retain { - Some(func_id) - } else { - None - } - }, - ) - .collect() } +/// Compute a weight of a function based on the number of instructions in its reachable blocks. fn compute_function_own_weight(func: &Function) -> usize { let mut weight = 0; for block_id in func.reachable_blocks() { @@ -354,6 +562,7 @@ fn compute_function_own_weight(func: &Function) -> usize { weight } +/// Compute interface cost of a function based on the number of inputs and outputs. fn compute_function_interface_cost(func: &Function) -> usize { func.parameters().len() + func.returns().len() } @@ -377,11 +586,12 @@ impl InlineContext { fn inline_all( mut self, ssa: &Ssa, - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Function { let entry_point = &ssa.functions[&self.entry_point]; - let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals); + let globals = &entry_point.dfg.globals; + let mut context = PerFunctionContext::new(&mut self, entry_point, entry_point, globals); context.inlining_entry = true; for (_, value) in entry_point.dfg.globals.values_iter() { @@ -420,7 +630,7 @@ impl InlineContext { ssa: &Ssa, id: FunctionId, arguments: &[ValueId], - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Vec { self.recursion_level += 1; @@ -428,11 +638,13 @@ impl InlineContext { if self.recursion_level > RECURSION_LIMIT { panic!( - "Attempted to recur more than {RECURSION_LIMIT} times during inlining function '{}': {}", source_function.name(), source_function + "Attempted to recur more than {RECURSION_LIMIT} times during inlining function '{}':\n{}", source_function.name(), source_function ); } - let mut context = PerFunctionContext::new(self, source_function, &ssa.globals); + let entry_point = &ssa.functions[&self.entry_point]; + let globals = &source_function.dfg.globals; + let mut context = PerFunctionContext::new(self, entry_point, source_function, globals); let parameters = source_function.parameters(); assert_eq!(parameters.len(), arguments.len()); @@ -454,11 +666,13 @@ impl<'function> PerFunctionContext<'function> { /// the arguments of the destination function. fn new( context: &'function mut InlineContext, + entry_function: &'function Function, source_function: &'function Function, - globals: &'function Function, + globals: &'function GlobalsGraph, ) -> Self { Self { context, + entry_function, source_function, blocks: HashMap::default(), values: HashMap::default(), @@ -480,18 +694,20 @@ impl<'function> PerFunctionContext<'function> { let new_value = match &self.source_function.dfg[id] { value @ Value::Instruction { instruction, .. } => { - // TODO: Inlining the global into the function is only a temporary measure - // until Brillig gen with globals is working end to end if self.source_function.dfg.is_global(id) { - let Instruction::MakeArray { elements, typ } = &self.globals.dfg[*instruction] - else { - panic!("Only expect Instruction::MakeArray for a global"); - }; - let elements = elements - .iter() - .map(|element| self.translate_value(*element)) - .collect::>(); - return self.context.builder.insert_make_array(elements, typ.clone()); + if self.context.builder.current_function.dfg.runtime().is_acir() { + let Instruction::MakeArray { elements, typ } = &self.globals[*instruction] + else { + panic!("Only expect Instruction::MakeArray for a global"); + }; + let elements = elements + .iter() + .map(|element| self.translate_value(*element)) + .collect::>(); + return self.context.builder.insert_make_array(elements, typ.clone()); + } else { + return id; + } } unreachable!("All Value::Instructions should already be known during inlining after creating the original inlined instruction. Unknown value {id} = {value:?}") } @@ -499,11 +715,16 @@ impl<'function> PerFunctionContext<'function> { unreachable!("All Value::Params should already be known from previous calls to translate_block. Unknown value {id} = {value:?}") } Value::NumericConstant { constant, typ } => { - // TODO: Inlining the global into the function is only a temporary measure - // until Brillig gen with globals is working end to end. - // The dfg indexes a global's inner value directly, so we will need to check here + // The dfg indexes a global's inner value directly, so we need to check here // whether we have a global. - self.context.builder.numeric_constant(*constant, *typ) + // We also only keep a global and do not inline it in a Brillig runtime. + if self.source_function.dfg.is_global(id) + && self.context.builder.current_function.dfg.runtime().is_brillig() + { + id + } else { + self.context.builder.numeric_constant(*constant, *typ) + } } Value::Function(function) => self.context.builder.import_function(*function), Value::Intrinsic(intrinsic) => self.context.builder.import_intrinsic_id(*intrinsic), @@ -572,7 +793,7 @@ impl<'function> PerFunctionContext<'function> { fn inline_blocks( &mut self, ssa: &Ssa, - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) -> Vec { let mut seen_blocks = HashSet::new(); let mut block_queue = VecDeque::new(); @@ -639,7 +860,7 @@ impl<'function> PerFunctionContext<'function> { &mut self, ssa: &Ssa, block_id: BasicBlockId, - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) { let mut side_effects_enabled: Option = None; @@ -648,19 +869,29 @@ impl<'function> PerFunctionContext<'function> { match &self.source_function.dfg[*id] { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(func_id) => { - if should_inline_call(self, ssa, func_id) { - self.inline_function(ssa, *id, func_id, arguments, should_inline_call); - - // This is only relevant during handling functions with `InlineType::NoPredicates` as these - // can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`, - // resulting in predicates not being applied properly. - // - // Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects` - // within the function being inlined whilst the source function has not encountered one yet. - // In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the - // function being inlined will be to turn off predicates rather than to create one. - if let Some(condition) = side_effects_enabled { - self.context.builder.insert_enable_side_effects_if(condition); + if let Some(callee) = self.should_inline_call(ssa, func_id) { + if should_inline_call(callee) { + self.inline_function( + ssa, + *id, + func_id, + arguments, + should_inline_call, + ); + + // This is only relevant during handling functions with `InlineType::NoPredicates` as these + // can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`, + // resulting in predicates not being applied properly. + // + // Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects` + // within the function being inlined whilst the source function has not encountered one yet. + // In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the + // function being inlined will be to turn off predicates rather than to create one. + if let Some(condition) = side_effects_enabled { + self.context.builder.insert_enable_side_effects_if(condition); + } + } else { + self.push_instruction(*id); } } else { self.push_instruction(*id); @@ -677,6 +908,38 @@ impl<'function> PerFunctionContext<'function> { } } + fn should_inline_call<'a>( + &self, + ssa: &'a Ssa, + called_func_id: FunctionId, + ) -> Option<&'a Function> { + // Do not inline self-recursive functions on the top level. + // Inlining a self-recursive function works when there is something to inline into + // by importing all the recursive blocks, but for the entry function there is no wrapper. + if self.entry_function.id() == called_func_id { + return None; + } + + let callee = &ssa.functions[&called_func_id]; + + match callee.runtime() { + RuntimeType::Acir(inline_type) => { + // If the called function is acir, we inline if it's not an entry point + if inline_type.is_entry_point() { + return None; + } + } + RuntimeType::Brillig(_) => { + if self.entry_function.runtime().is_acir() { + // We never inline a brillig function into an ACIR function. + return None; + } + } + } + + Some(callee) + } + /// Inline a function call and remember the inlined return values in the values map fn inline_function( &mut self, @@ -684,7 +947,7 @@ impl<'function> PerFunctionContext<'function> { call_id: InstructionId, function: FunctionId, arguments: &[ValueId], - should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool, + should_inline_call: &impl Fn(&Function) -> bool, ) { let old_results = self.source_function.dfg.instruction_results(call_id); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); @@ -874,6 +1137,8 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { + use std::cmp::max; + use acvm::{acir::AcirField, FieldElement}; use noirc_frontend::monomorphization::ast::InlineType; @@ -886,8 +1151,12 @@ mod test { map::Id, types::{NumericType, Type}, }, + opt::assert_normalized_ssa_equals, + Ssa, }; + use super::compute_bottom_up_order; + #[test] fn basic_inlining() { // fn foo { @@ -1158,26 +1427,25 @@ mod test { #[test] #[should_panic( - expected = "Attempted to recur more than 1000 times during inlining function 'main': acir(inline) fn main f0 {" + expected = "Attempted to recur more than 1000 times during inlining function 'foo':\nacir(inline) fn foo f1 {" )] fn unconditional_recursion() { - // fn main f1 { - // b0(): - // call f1() - // return - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let main = builder.import_function(main_id); - let results = builder.insert_call(main, Vec::new(), vec![]).to_vec(); - builder.terminate_with_return(results); - - let ssa = builder.finish(); - assert_eq!(ssa.functions.len(), 1); + let src = " + acir(inline) fn main f0 { + b0(): + call f1() + return + } + acir(inline) fn foo f1 { + b0(): + call f1() + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + assert_eq!(ssa.functions.len(), 2); - let inlined = ssa.inline_functions(i64::MAX); - assert_eq!(inlined.functions.len(), 0); + let _ = ssa.inline_functions(i64::MAX); } #[test] @@ -1267,4 +1535,167 @@ mod test { // No inlining has happened assert_eq!(inlined.functions.len(), 2); } + + #[test] + fn bottom_up_order_and_weights() { + let src = " + brillig(inline) fn main f0 { + b0(v0: u32, v1: u1): + v3 = call f2(v0) -> u1 + v4 = eq v3, v1 + constrain v3 == v1 + return + } + brillig(inline) fn is_even f1 { + b0(v0: u32): + v3 = eq v0, u32 0 + jmpif v3 then: b2, else: b1 + b1(): + v5 = call f3(v0) -> u32 + v7 = call f2(v5) -> u1 + jmp b3(v7) + b2(): + jmp b3(u1 1) + b3(v1: u1): + return v1 + } + brillig(inline) fn is_odd f2 { + b0(v0: u32): + v3 = eq v0, u32 0 + jmpif v3 then: b2, else: b1 + b1(): + v5 = call f3(v0) -> u32 + v7 = call f1(v5) -> u1 + jmp b3(v7) + b2(): + jmp b3(u1 0) + b3(v1: u1): + return v1 + } + brillig(inline) fn decrement f3 { + b0(v0: u32): + v2 = sub v0, u32 1 + return v2 + } + "; + // main + // | + // V + // is_odd <-> is_even + // | | + // V V + // decrement + + let ssa = Ssa::from_str(src).unwrap(); + let order = compute_bottom_up_order(&ssa); + + assert_eq!(order.len(), 4); + let (ids, ws): (Vec<_>, Vec<_>) = order.into_iter().map(|(id, w)| (id.to_u32(), w)).unzip(); + let (ows, tws): (Vec<_>, Vec<_>) = ws.into_iter().unzip(); + + // Check order + assert_eq!(ids[0], 3, "decrement: first, it doesn't call anything"); + assert_eq!(ids[1], 1, "is_even: called by is_odd; removing first avoids cutting the graph"); + assert_eq!(ids[2], 2, "is_odd: called by is_odd and main"); + assert_eq!(ids[3], 0, "main: last, it's the entry"); + + // Check own weights + assert_eq!(ows, [2, 7, 7, 4]); + + // Check transitive weights + assert_eq!(tws[0], ows[0], "decrement"); + assert_eq!( + tws[1], + ows[1] + // own + tws[0] + // pushed from decrement + (ows[2] + tws[0]), // pulled from is_odd at the time is_even is emitted + "is_even" + ); + assert_eq!( + tws[2], + ows[2] + // own + tws[0] + // pushed from decrement + tws[1], // pushed from is_even + "is_odd" + ); + assert_eq!( + tws[3], + ows[3] + // own + tws[2], // pushed from is_odd + "main" + ); + assert!(tws[3] > max(tws[1], tws[2]), "ideally 'main' has the most weight"); + } + + #[test] + fn inline_simple_functions_with_zero_instructions() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = call f1(v0) -> Field + v3 = call f1(v0) -> Field + v4 = add v2, v3 + return v4 + } + + acir(inline) fn foo f1 { + b0(v0: Field): + return v0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field): + v1 = add v0, v0 + return v1 + } + acir(inline) fn foo f1 { + b0(v0: Field): + return v0 + } + "; + + let ssa = ssa.inline_simple_functions(); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inline_simple_functions_with_one_instruction() { + let src = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = call f1(v0) -> Field + v3 = call f1(v0) -> Field + v4 = add v2, v3 + return v4 + } + + acir(inline) fn foo f1 { + b0(v0: Field): + v2 = add v0, Field 1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(v0: Field): + v2 = add v0, Field 1 + v3 = add v0, Field 1 + v4 = add v2, v3 + return v4 + } + acir(inline) fn foo f1 { + b0(v0: Field): + v2 = add v0, Field 1 + return v2 + } + "; + + let ssa = ssa.inline_simple_functions(); + assert_normalized_ssa_equals(ssa, expected); + } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 125cf3a12ca..1e2e783d516 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -16,6 +16,7 @@ use crate::ssa::{ function::Function, function_inserter::FunctionInserter, instruction::{binary::eval_constant_binary_op, BinaryOp, Instruction, InstructionId}, + post_order::PostOrder, types::Type, value::ValueId, }, @@ -36,7 +37,7 @@ impl Ssa { } impl Function { - fn loop_invariant_code_motion(&mut self) { + pub(super) fn loop_invariant_code_motion(&mut self) { Loops::find_all(self).hoist_loop_invariants(self); } } @@ -272,8 +273,10 @@ impl<'f> LoopInvariantContext<'f> { /// correct new value IDs based upon the `FunctionInserter` internal map. /// Leaving out this mapping could lead to instructions with values that do not exist. fn map_dependent_instructions(&mut self) { - let blocks = self.inserter.function.reachable_blocks(); - for block in blocks { + let mut block_order = PostOrder::with_function(self.inserter.function).into_vec(); + block_order.reverse(); + + for block in block_order { for instruction_id in self.inserter.function.dfg[block].take_instructions() { self.inserter.push_instruction(instruction_id, block); } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs index f97d36f0844..44796e2531e 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -17,6 +17,7 @@ mod loop_invariant; mod make_constrain_not_equal; mod mem2reg; mod normalize_value_ids; +mod preprocess_fns; mod rc; mod remove_bit_shifts; mod remove_enable_side_effects; diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs new file mode 100644 index 00000000000..ae20c9b8b4a --- /dev/null +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -0,0 +1,70 @@ +//! Pre-process functions before inlining them into others. + +use crate::ssa::{ + ir::function::{Function, RuntimeType}, + Ssa, +}; + +use super::inlining::{self, InlineInfo}; + +impl Ssa { + /// Run pre-processing steps on functions in isolation. + pub(crate) fn preprocess_functions(mut self, aggressiveness: i64) -> Ssa { + // Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them. + let bottom_up = inlining::compute_bottom_up_order(&self); + + // Preliminary inlining decisions. + let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness); + + let should_inline_call = |callee: &Function| -> bool { + match callee.runtime() { + RuntimeType::Acir(_) => { + // Functions marked to not have predicates should be preserved. + !callee.is_no_predicates() + } + RuntimeType::Brillig(_) => { + // We inline inline if the function called wasn't ruled out as too costly or recursive. + InlineInfo::should_inline(&inline_infos, callee.id()) + } + } + }; + + for (id, (own_weight, transitive_weight)) in bottom_up { + let function = &self.functions[&id]; + + // Skip preprocessing heavy functions that gained most of their weight from transitive accumulation, which tend to be near the entry. + // These can be processed later by the regular SSA passes. + let is_heavy = transitive_weight > own_weight * 10; + + // Functions which are inline targets will be processed in later passes. + // Here we want to treat the functions which will be inlined into them. + let is_target = + inline_infos.get(&id).map(|info| info.is_inline_target()).unwrap_or_default(); + + if is_heavy || is_target { + continue; + } + + // Start with an inline pass. + let mut function = function.inlined(&self, &should_inline_call); + // Help unrolling determine bounds. + function.as_slice_optimization(); + // Prepare for unrolling + function.loop_invariant_code_motion(); + // We might not be able to unroll all loops without fully inlining them, so ignore errors. + let _ = function.unroll_loops_iteratively(); + // Reduce the number of redundant stores/loads after unrolling + function.mem2reg(); + // Try to reduce the number of blocks. + function.simplify_function(); + // Remove leftover instructions. + function.dead_instruction_elimination(true, false); + + // Put it back into the SSA, so the next functions can pick it up. + self.functions.insert(id, function); + } + + // Remove any functions that have been inlined into others already. + self.remove_unreachable_functions() + } +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs index 41023b5f376..9b80b3a4d23 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_unreachable.rs @@ -19,9 +19,13 @@ impl Ssa { pub(crate) fn remove_unreachable_functions(mut self) -> Self { let mut used_functions = HashSet::default(); - for function_id in self.functions.keys() { - if self.is_entry_point(*function_id) { - collect_reachable_functions(&self, *function_id, &mut used_functions); + for (id, function) in self.functions.iter() { + // XXX: `self.is_entry_point(*id)` could leave Brillig functions that nobody calls in the SSA. + let is_entry_point = function.id() == self.main_id + || function.runtime().is_acir() && function.runtime().is_entry_point(); + + if is_entry_point { + collect_reachable_functions(&self, *id, &mut used_functions); } } @@ -78,3 +82,54 @@ fn used_functions(func: &Function) -> BTreeSet { used_function_ids } + +#[cfg(test)] +mod tests { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn remove_unused_brillig() { + let src = " + brillig(inline) fn main f0 { + b0(v0: u32): + v2 = call f1(v0) -> u32 + v4 = add v0, u32 1 + v5 = eq v2, v4 + constrain v2 == v4 + return + } + brillig(inline) fn increment f1 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn increment_acir f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.remove_unreachable_functions(); + + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v2 = call f1(v0) -> u32 + v4 = add v0, u32 1 + v5 = eq v2, v4 + constrain v2 == v4 + return + } + brillig(inline) fn increment f1 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 79181b7e74e..eb0bbd8c532 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -24,7 +24,10 @@ use acvm::{acir::AcirField, FieldElement}; use im::HashSet; use crate::{ - brillig::brillig_gen::convert_ssa_function, + brillig::{ + brillig_gen::{brillig_globals::convert_ssa_globals, convert_ssa_function}, + brillig_ir::brillig_variable::BrilligVariable, + }, errors::RuntimeError, ssa::{ ir::{ @@ -54,41 +57,40 @@ impl Ssa { /// fewer SSA instructions, but that can still result in more Brillig opcodes. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn unroll_loops_iteratively( - mut self: Ssa, + mut self, max_bytecode_increase_percent: Option, ) -> Result { - for (_, function) in self.functions.iter_mut() { - // Take a snapshot of the function to compare byte size increase, - // but only if the setting indicates we have to, otherwise skip it. - let orig_func_and_max_incr_pct = max_bytecode_increase_percent - .filter(|_| function.runtime().is_brillig()) - .map(|max_incr_pct| (function.clone(), max_incr_pct)); - - // Try to unroll loops first: - let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops(); - - // Keep unrolling until no more errors are found - while !unroll_errors.is_empty() { - let prev_unroll_err_count = unroll_errors.len(); - - // Simplify the SSA before retrying - simplify_between_unrolls(function); - - // Unroll again - let (new_unrolled, new_errors) = function.try_unroll_loops(); - unroll_errors = new_errors; - has_unrolled |= new_unrolled; - - // If we didn't manage to unroll any more loops, exit - if unroll_errors.len() >= prev_unroll_err_count { - return Err(unroll_errors.swap_remove(0)); - } - } + let mut global_cache = None; + + for function in self.functions.values_mut() { + let is_brillig = function.runtime().is_brillig(); + + // Take a snapshot in case we have to restore it. + let orig_function = + (max_bytecode_increase_percent.is_some() && is_brillig).then(|| function.clone()); + + // We must be able to unroll ACIR loops at this point, so exit on failure to unroll. + let has_unrolled = function.unroll_loops_iteratively()?; + + // Check if the size increase is acceptable + // This is here now instead of in `Function::unroll_loops_iteratively` because we'd need + // more finessing to convince the borrow checker that it's okay to share a read-only reference + // to the globals and a mutable reference to the function at the same time, both part of the `Ssa`. + if has_unrolled && is_brillig { + if let Some(max_incr_pct) = max_bytecode_increase_percent { + if global_cache.is_none() { + let globals = (*function.dfg.globals).clone(); + // DIE is run at the end of our SSA optimizations, so we mark all globals as in use here. + let used_globals = &globals.values_iter().map(|(id, _)| id).collect(); + let (_, brillig_globals, _) = + convert_ssa_globals(false, globals, used_globals); + global_cache = Some(brillig_globals); + } + let brillig_globals = global_cache.as_ref().unwrap(); - if has_unrolled { - if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct { - let new_size = brillig_bytecode_size(function); - let orig_size = brillig_bytecode_size(&orig_function); + let orig_function = orig_function.expect("took snapshot to compare"); + let new_size = brillig_bytecode_size(function, brillig_globals); + let orig_size = brillig_bytecode_size(&orig_function, brillig_globals); if !is_new_size_ok(orig_size, new_size, max_incr_pct) { *function = orig_function; } @@ -100,6 +102,38 @@ impl Ssa { } impl Function { + /// Try to unroll loops in the function. + /// + /// Returns an `Err` if it cannot be done, for example because the loop bounds + /// cannot be determined at compile time. This can happen during pre-processing, + /// but it should still leave the function in a partially unrolled, but valid state. + /// + /// If successful, returns a flag indicating whether any loops have been unrolled. + pub(super) fn unroll_loops_iteratively(&mut self) -> Result { + // Try to unroll loops first: + let (mut has_unrolled, mut unroll_errors) = self.try_unroll_loops(); + + // Keep unrolling until no more errors are found + while !unroll_errors.is_empty() { + let prev_unroll_err_count = unroll_errors.len(); + + // Simplify the SSA before retrying + simplify_between_unrolls(self); + + // Unroll again + let (new_unrolled, new_errors) = self.try_unroll_loops(); + unroll_errors = new_errors; + has_unrolled |= new_unrolled; + + // If we didn't manage to unroll any more loops, exit + if unroll_errors.len() >= prev_unroll_err_count { + return Err(unroll_errors.swap_remove(0)); + } + } + + Ok(has_unrolled) + } + // Loop unrolling in brillig can lead to a code explosion currently. // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. // Brillig also generally prefers smaller code rather than faster code, @@ -310,11 +344,13 @@ impl Loop { // simplified to a simple jump. return None; } - assert_eq!( - instructions.len(), - 1, - "The header should just compare the induction variable and jump" - ); + + if instructions.len() != 1 { + // The header should just compare the induction variable and jump. + // If that's not the case, this might be a `loop` and not a `for` loop. + return None; + } + match &function.dfg[instructions[0]] { Instruction::Binary(Binary { lhs: _, operator: BinaryOp::Lt, rhs }) => { function.dfg.get_numeric_constant(*rhs) @@ -750,7 +786,13 @@ fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result usize { +fn brillig_bytecode_size( + function: &Function, + globals: &HashMap, +) -> usize { // We need to do some SSA passes in order for the conversion to be able to go ahead, // otherwise we can hit `unreachable!()` instructions in `convert_ssa_instruction`. // Creating a clone so as not to modify the originals. @@ -990,9 +1035,9 @@ fn brillig_bytecode_size(function: &Function) -> usize { simplify_between_unrolls(&mut temp); // This is to try to prevent hitting ICE. - temp.dead_instruction_elimination(false); + temp.dead_instruction_elimination(false, true); - convert_ssa_function(&temp, false).byte_code.len() + convert_ssa_function(&temp, false, globals).byte_code.len() } /// Decide if the new bytecode size is acceptable, compared to the original. diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs index 6c7608a2f16..05743ffd7ca 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/ast.rs @@ -7,9 +7,28 @@ use crate::ssa::ir::{function::RuntimeType, instruction::BinaryOp, types::Type}; #[derive(Debug)] pub(crate) struct ParsedSsa { + pub(crate) globals: Vec, pub(crate) functions: Vec, } +#[derive(Debug)] +pub(crate) struct ParsedGlobal { + pub(crate) name: Identifier, + pub(crate) value: ParsedGlobalValue, +} + +#[derive(Debug)] +pub(crate) enum ParsedGlobalValue { + NumericConstant(ParsedNumericConstant), + MakeArray(ParsedMakeArray), +} + +#[derive(Debug)] +pub(crate) struct ParsedMakeArray { + pub(crate) elements: Vec, + pub(crate) typ: Type, +} + #[derive(Debug)] pub(crate) struct ParsedFunction { pub(crate) runtime_type: RuntimeType, @@ -145,6 +164,12 @@ pub(crate) enum ParsedTerminator { #[derive(Debug, Clone)] pub(crate) enum ParsedValue { - NumericConstant { constant: FieldElement, typ: Type }, + NumericConstant(ParsedNumericConstant), Variable(Identifier), } + +#[derive(Debug, Clone)] +pub(crate) struct ParsedNumericConstant { + pub(crate) value: FieldElement, + pub(crate) typ: Type, +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index fcaaf74f533..37d2cd720f9 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -1,18 +1,22 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use acvm::acir::circuit::ErrorSelector; use crate::ssa::{ function_builder::FunctionBuilder, ir::{ - basic_block::BasicBlockId, function::FunctionId, instruction::ConstrainError, + basic_block::BasicBlockId, + call_stack::CallStackId, + dfg::GlobalsGraph, + function::{Function, FunctionId}, + instruction::{ConstrainError, Instruction}, value::ValueId, }, }; use super::{ - ast::AssertMessage, Identifier, ParsedBlock, ParsedFunction, ParsedInstruction, ParsedSsa, - ParsedTerminator, ParsedValue, RuntimeType, Ssa, SsaError, + ast::AssertMessage, Identifier, ParsedBlock, ParsedFunction, ParsedGlobal, ParsedGlobalValue, + ParsedInstruction, ParsedSsa, ParsedTerminator, ParsedValue, RuntimeType, Ssa, SsaError, Type, }; impl ParsedSsa { @@ -24,7 +28,7 @@ impl ParsedSsa { struct Translator { builder: FunctionBuilder, - /// Maps function names to their IDs + /// Maps internal function names (e.g. "f1") to their IDs functions: HashMap, /// Maps block names to their IDs @@ -37,6 +41,17 @@ struct Translator { /// will recreate the SSA step by step, which can result in a new ID layout. variables: HashMap>, + /// The function that will hold the actual SSA globals. + globals_function: Function, + + /// The types of globals in the parsed SSA, in the order they were defined. + global_types: Vec, + + /// Maps names (e.g. "g0") in the parsed SSA to global IDs. + global_values: HashMap, + + globals_graph: Arc, + error_selector_counter: u64, } @@ -72,13 +87,26 @@ impl Translator { functions.insert(function.internal_name.clone(), function_id); } + // Does not matter what ID we use here. + let globals = Function::new("globals".to_owned(), main_id); + let mut translator = Self { builder, functions, variables: HashMap::new(), blocks: HashMap::new(), + globals_function: globals, + global_types: Vec::new(), + global_values: HashMap::new(), + globals_graph: Arc::new(GlobalsGraph::default()), error_selector_counter: 0, }; + + translator.translate_globals(std::mem::take(&mut parsed_ssa.globals))?; + + translator.globals_graph = + Arc::new(GlobalsGraph::from_dfg(translator.globals_function.dfg.clone())); + translator.translate_function_body(main_function)?; Ok(translator) @@ -101,6 +129,8 @@ impl Translator { } fn translate_function_body(&mut self, function: ParsedFunction) -> Result<(), SsaError> { + self.builder.set_globals(self.globals_graph.clone()); + // First define all blocks so that they are known (a block might jump to a block that comes next) for (index, block) in function.blocks.iter().enumerate() { // The first block is the entry block and it was automatically created by the builder @@ -135,14 +165,14 @@ impl Translator { match block.terminator { ParsedTerminator::Jmp { destination, arguments } => { - let block_id = self.lookup_block(destination)?; + let block_id = self.lookup_block(&destination)?; let arguments = self.translate_values(arguments)?; self.builder.terminate_with_jmp(block_id, arguments); } ParsedTerminator::Jmpif { condition, then_block, else_block } => { let condition = self.translate_value(condition)?; - let then_destination = self.lookup_block(then_block)?; - let else_destination = self.lookup_block(else_block)?; + let then_destination = self.lookup_block(&then_block)?; + let else_destination = self.lookup_block(&else_block)?; self.builder.terminate_with_jmpif(condition, then_destination, else_destination); } ParsedTerminator::Return(values) => { @@ -187,8 +217,13 @@ impl Translator { let function_id = if let Some(id) = self.builder.import_intrinsic(&function.name) { id } else { - let function_id = self.lookup_function(function)?; - self.builder.import_function(function_id) + let maybe_func = + self.lookup_function(&function).map(|f| self.builder.import_function(f)); + + maybe_func.or_else(|e| { + // e.g. `v2 = call v0(v1) -> u32`, a lambda passed as a parameter + self.lookup_variable(&function).map_err(|_| e) + })? }; let arguments = self.translate_values(arguments)?; @@ -290,13 +325,59 @@ impl Translator { fn translate_value(&mut self, value: ParsedValue) -> Result { match value { - ParsedValue::NumericConstant { constant, typ } => { - Ok(self.builder.numeric_constant(constant, typ.unwrap_numeric())) + ParsedValue::NumericConstant(constant) => { + Ok(self.builder.numeric_constant(constant.value, constant.typ.unwrap_numeric())) } - ParsedValue::Variable(identifier) => self.lookup_variable(identifier), + ParsedValue::Variable(identifier) => self.lookup_variable(&identifier).or_else(|e| { + self.lookup_function(&identifier) + .map(|f| { + // e.g. `v3 = call f1(f2, v0) -> u32` + self.builder.import_function(f) + }) + .map_err(|_| e) + }), } } + fn translate_globals(&mut self, globals: Vec) -> Result<(), SsaError> { + for global in globals { + self.translate_global(global)?; + } + Ok(()) + } + + fn translate_global(&mut self, global: ParsedGlobal) -> Result<(), SsaError> { + let value_id = match global.value { + ParsedGlobalValue::NumericConstant(constant) => self + .globals_function + .dfg + .make_constant(constant.value, constant.typ.unwrap_numeric()), + ParsedGlobalValue::MakeArray(make_array) => { + let mut elements = im::Vector::new(); + for element in make_array.elements { + let element_id = match element { + ParsedValue::NumericConstant(constant) => self + .globals_function + .dfg + .make_constant(constant.value, constant.typ.unwrap_numeric()), + ParsedValue::Variable(identifier) => self.lookup_global(identifier)?, + }; + elements.push_back(element_id); + } + + let instruction = Instruction::MakeArray { elements, typ: make_array.typ.clone() }; + let block = self.globals_function.entry_block(); + let call_stack = CallStackId::root(); + self.globals_function + .dfg + .insert_instruction_and_results(instruction, block, None, call_stack) + .first() + } + }; + + self.define_global(global.name, value_id) + } + fn define_variable( &mut self, identifier: Identifier, @@ -314,37 +395,66 @@ impl Translator { Ok(()) } - fn lookup_variable(&mut self, identifier: Identifier) -> Result { - if let Some(value_id) = self.variables[&self.current_function_id()].get(&identifier.name) { + fn lookup_variable(&mut self, identifier: &Identifier) -> Result { + if let Some(value_id) = self + .variables + .get(&self.current_function_id()) + .and_then(|hash| hash.get(&identifier.name)) + { + Ok(*value_id) + } else if let Some(value_id) = self.global_values.get(&identifier.name) { + Ok(*value_id) + } else { + Err(SsaError::UnknownVariable(identifier.clone())) + } + } + + fn define_global(&mut self, identifier: Identifier, value_id: ValueId) -> Result<(), SsaError> { + if self.global_values.contains_key(&identifier.name) { + return Err(SsaError::GlobalAlreadyDefined(identifier)); + } + + self.global_values.insert(identifier.name, value_id); + + let typ = self.globals_function.dfg.type_of_value(value_id); + self.global_types.push(typ); + + Ok(()) + } + + fn lookup_global(&mut self, identifier: Identifier) -> Result { + if let Some(value_id) = self.global_values.get(&identifier.name) { Ok(*value_id) } else { - Err(SsaError::UnknownVariable(identifier)) + Err(SsaError::UnknownGlobal(identifier)) } } - fn lookup_block(&mut self, identifier: Identifier) -> Result { + fn lookup_block(&mut self, identifier: &Identifier) -> Result { if let Some(block_id) = self.blocks[&self.current_function_id()].get(&identifier.name) { Ok(*block_id) } else { - Err(SsaError::UnknownBlock(identifier)) + Err(SsaError::UnknownBlock(identifier.clone())) } } - fn lookup_function(&mut self, identifier: Identifier) -> Result { + fn lookup_function(&mut self, identifier: &Identifier) -> Result { if let Some(function_id) = self.functions.get(&identifier.name) { Ok(*function_id) } else { - Err(SsaError::UnknownFunction(identifier)) + Err(SsaError::UnknownFunction(identifier.clone())) } } fn finish(self) -> Ssa { let mut ssa = self.builder.finish(); + // Normalize the IDs so we have a better chance of matching the SSA we parsed // after the step-by-step reconstruction done during translation. This assumes // that the SSA we parsed was printed by the `SsaBuilder`, which normalizes // before each print. ssa.normalize_ids(); + ssa } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs index 143ba511879..cc660355bbd 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/mod.rs @@ -13,8 +13,9 @@ use super::{ use acvm::{AcirField, FieldElement}; use ast::{ - AssertMessage, Identifier, ParsedBlock, ParsedFunction, ParsedInstruction, ParsedParameter, - ParsedSsa, ParsedValue, + AssertMessage, Identifier, ParsedBlock, ParsedFunction, ParsedGlobal, ParsedGlobalValue, + ParsedInstruction, ParsedMakeArray, ParsedNumericConstant, ParsedParameter, ParsedSsa, + ParsedValue, }; use lexer::{Lexer, LexerError}; use noirc_errors::Span; @@ -99,6 +100,8 @@ pub(crate) enum SsaError { ParserError(ParserError), #[error("Unknown variable '{0}'")] UnknownVariable(Identifier), + #[error("Unknown global '{0}'")] + UnknownGlobal(Identifier), #[error("Unknown block '{0}'")] UnknownBlock(Identifier), #[error("Unknown function '{0}'")] @@ -107,6 +110,8 @@ pub(crate) enum SsaError { MismatchedReturnValues { returns: Vec, expected: usize }, #[error("Variable '{0}' already defined")] VariableAlreadyDefined(Identifier), + #[error("Global '{0}' already defined")] + GlobalAlreadyDefined(Identifier), } impl SsaError { @@ -114,8 +119,10 @@ impl SsaError { match self { SsaError::ParserError(parser_error) => parser_error.span(), SsaError::UnknownVariable(identifier) + | SsaError::UnknownGlobal(identifier) | SsaError::UnknownBlock(identifier) | SsaError::VariableAlreadyDefined(identifier) + | SsaError::GlobalAlreadyDefined(identifier) | SsaError::UnknownFunction(identifier) => identifier.span, SsaError::MismatchedReturnValues { returns, expected: _ } => returns[0].span, } @@ -138,12 +145,39 @@ impl<'a> Parser<'a> { } pub(crate) fn parse_ssa(&mut self) -> ParseResult { + let globals = self.parse_globals()?; + let mut functions = Vec::new(); while !self.at(Token::Eof) { let function = self.parse_function()?; functions.push(function); } - Ok(ParsedSsa { functions }) + Ok(ParsedSsa { globals, functions }) + } + + fn parse_globals(&mut self) -> ParseResult> { + let mut globals = Vec::new(); + + while let Some(name) = self.eat_identifier()? { + self.eat_or_error(Token::Assign)?; + + let value = self.parse_global_value()?; + globals.push(ParsedGlobal { name, value }); + } + + Ok(globals) + } + + fn parse_global_value(&mut self) -> ParseResult { + if let Some(constant) = self.parse_numeric_constant()? { + return Ok(ParsedGlobalValue::NumericConstant(constant)); + } + + if let Some(make_array) = self.parse_make_array()? { + return Ok(ParsedGlobalValue::MakeArray(make_array)); + } + + self.expected_global_value() } fn parse_function(&mut self) -> ParseResult { @@ -461,40 +495,12 @@ impl<'a> Parser<'a> { return Ok(ParsedInstruction::Load { target, value, typ }); } - if self.eat_keyword(Keyword::MakeArray)? { - if self.eat(Token::Ampersand)? { - let Some(string) = self.eat_byte_str()? else { - return self.expected_byte_string(); - }; - let u8 = Type::Numeric(NumericType::Unsigned { bit_size: 8 }); - let typ = Type::Slice(Arc::new(vec![u8.clone()])); - let elements = string - .bytes() - .map(|byte| ParsedValue::NumericConstant { - constant: FieldElement::from(byte as u128), - typ: u8.clone(), - }) - .collect(); - return Ok(ParsedInstruction::MakeArray { target, elements, typ }); - } else if let Some(string) = self.eat_byte_str()? { - let u8 = Type::Numeric(NumericType::Unsigned { bit_size: 8 }); - let typ = Type::Array(Arc::new(vec![u8.clone()]), string.len() as u32); - let elements = string - .bytes() - .map(|byte| ParsedValue::NumericConstant { - constant: FieldElement::from(byte as u128), - typ: u8.clone(), - }) - .collect(); - return Ok(ParsedInstruction::MakeArray { target, elements, typ }); - } else { - self.eat_or_error(Token::LeftBracket)?; - let elements = self.parse_comma_separated_values()?; - self.eat_or_error(Token::RightBracket)?; - self.eat_or_error(Token::Colon)?; - let typ = self.parse_type()?; - return Ok(ParsedInstruction::MakeArray { target, elements, typ }); - } + if let Some(make_array) = self.parse_make_array()? { + return Ok(ParsedInstruction::MakeArray { + target, + elements: make_array.elements, + typ: make_array.typ, + }); } if self.eat_keyword(Keyword::Not)? { @@ -524,6 +530,52 @@ impl<'a> Parser<'a> { self.expected_instruction_or_terminator() } + fn parse_make_array(&mut self) -> ParseResult> { + if !self.eat_keyword(Keyword::MakeArray)? { + return Ok(None); + } + + let make_array = if self.eat(Token::Ampersand)? { + let Some(string) = self.eat_byte_str()? else { + return self.expected_byte_string(); + }; + let u8 = Type::Numeric(NumericType::Unsigned { bit_size: 8 }); + let typ = Type::Slice(Arc::new(vec![u8.clone()])); + let elements = string + .bytes() + .map(|byte| { + ParsedValue::NumericConstant(ParsedNumericConstant { + value: FieldElement::from(byte as u128), + typ: u8.clone(), + }) + }) + .collect(); + ParsedMakeArray { elements, typ } + } else if let Some(string) = self.eat_byte_str()? { + let u8 = Type::Numeric(NumericType::Unsigned { bit_size: 8 }); + let typ = Type::Array(Arc::new(vec![u8.clone()]), string.len() as u32); + let elements = string + .bytes() + .map(|byte| { + ParsedValue::NumericConstant(ParsedNumericConstant { + value: FieldElement::from(byte as u128), + typ: u8.clone(), + }) + }) + .collect(); + ParsedMakeArray { elements, typ } + } else { + self.eat_or_error(Token::LeftBracket)?; + let elements = self.parse_comma_separated_values()?; + self.eat_or_error(Token::RightBracket)?; + self.eat_or_error(Token::Colon)?; + let typ = self.parse_type()?; + ParsedMakeArray { elements, typ } + }; + + Ok(Some(make_array)) + } + fn parse_terminator(&mut self) -> ParseResult { if let Some(terminator) = self.parse_return()? { return Ok(terminator); @@ -617,12 +669,8 @@ impl<'a> Parser<'a> { } fn parse_value(&mut self) -> ParseResult> { - if let Some(value) = self.parse_field_value()? { - return Ok(Some(value)); - } - - if let Some(value) = self.parse_int_value()? { - return Ok(Some(value)); + if let Some(constant) = self.parse_numeric_constant()? { + return Ok(Some(ParsedValue::NumericConstant(constant))); } if let Some(identifier) = self.eat_identifier()? { @@ -632,23 +680,35 @@ impl<'a> Parser<'a> { Ok(None) } - fn parse_field_value(&mut self) -> ParseResult> { + fn parse_numeric_constant(&mut self) -> ParseResult> { + if let Some(constant) = self.parse_field_value()? { + return Ok(Some(constant)); + } + + if let Some(constant) = self.parse_int_value()? { + return Ok(Some(constant)); + } + + Ok(None) + } + + fn parse_field_value(&mut self) -> ParseResult> { if self.eat_keyword(Keyword::Field)? { - let constant = self.eat_int_or_error()?; - Ok(Some(ParsedValue::NumericConstant { constant, typ: Type::field() })) + let value = self.eat_int_or_error()?; + Ok(Some(ParsedNumericConstant { value, typ: Type::field() })) } else { Ok(None) } } - fn parse_int_value(&mut self) -> ParseResult> { + fn parse_int_value(&mut self) -> ParseResult> { if let Some(int_type) = self.eat_int_type()? { - let constant = self.eat_int_or_error()?; + let value = self.eat_int_or_error()?; let typ = match int_type { IntType::Unsigned(bit_size) => Type::unsigned(bit_size), IntType::Signed(bit_size) => Type::signed(bit_size), }; - Ok(Some(ParsedValue::NumericConstant { constant, typ })) + Ok(Some(ParsedNumericConstant { value, typ })) } else { Ok(None) } @@ -932,6 +992,13 @@ impl<'a> Parser<'a> { }) } + fn expected_global_value(&mut self) -> ParseResult { + Err(ParserError::ExpectedGlobalValue { + found: self.token.token().clone(), + span: self.token.to_span(), + }) + } + fn expected_token(&mut self, token: Token) -> ParseResult { Err(ParserError::ExpectedToken { token, @@ -971,6 +1038,10 @@ pub(crate) enum ParserError { ExpectedByteString { found: Token, span: Span }, #[error("Expected a value, found '{found}'")] ExpectedValue { found: Token, span: Span }, + #[error( + "Expected a global value (Field literal, integer literal or make_array), found '{found}'" + )] + ExpectedGlobalValue { found: Token, span: Span }, #[error("Multiple return values only allowed for call")] MultipleReturnValuesOnlyAllowedForCall { second_target: Identifier }, } @@ -987,7 +1058,8 @@ impl ParserError { | ParserError::ExpectedInstructionOrTerminator { span, .. } | ParserError::ExpectedStringOrData { span, .. } | ParserError::ExpectedByteString { span, .. } - | ParserError::ExpectedValue { span, .. } => *span, + | ParserError::ExpectedValue { span, .. } + | ParserError::ExpectedGlobalValue { span, .. } => *span, ParserError::MultipleReturnValuesOnlyAllowedForCall { second_target, .. } => { second_target.span } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs index 8c24b2ec458..c803e2a94fe 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/parser/tests.rs @@ -530,3 +530,19 @@ fn test_does_not_simplify() { "; assert_ssa_roundtrip(src); } + +#[test] +fn parses_globals() { + let src = " + g0 = Field 0 + g1 = u32 1 + g2 = make_array [] : [Field; 0] + g3 = make_array [g2] : [[Field; 0]; 1] + + acir(inline) fn main f0 { + b0(): + return g3 + } + "; + assert_ssa_roundtrip(src); +} diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 0b778ef9b78..a845c5654b2 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -88,7 +88,8 @@ pub(super) struct SharedContext { #[derive(Copy, Clone)] pub(super) struct Loop { pub(super) loop_entry: BasicBlockId, - pub(super) loop_index: ValueId, + /// The loop index will be `Some` for a `for` and `None` for a `loop` + pub(super) loop_index: Option, pub(super) loop_end: BasicBlockId, } @@ -1010,13 +1011,8 @@ impl<'a> FunctionContext<'a> { } } - pub(crate) fn enter_loop( - &mut self, - loop_entry: BasicBlockId, - loop_index: ValueId, - loop_end: BasicBlockId, - ) { - self.loops.push(Loop { loop_entry, loop_index, loop_end }); + pub(crate) fn enter_loop(&mut self, loop_: Loop) { + self.loops.push(loop_); } pub(crate) fn exit_loop(&mut self) { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index d73a5946b4c..c65bc9ba7cf 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -6,7 +6,7 @@ use acvm::AcirField; use noirc_frontend::token::FmtStrFragment; pub(crate) use program::Ssa; -use context::SharedContext; +use context::{Loop, SharedContext}; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; use noirc_frontend::ast::{UnaryOp, Visibility}; @@ -48,9 +48,10 @@ pub(crate) fn generate_ssa(program: Program) -> Result { let is_return_data = matches!(program.return_visibility, Visibility::ReturnData); let return_location = program.return_location; - let context = SharedContext::new(program); + let mut context = SharedContext::new(program); - let globals = GlobalsGraph::from_dfg(context.globals_context.dfg.clone()); + let globals_dfg = std::mem::take(&mut context.globals_context.dfg); + let globals = GlobalsGraph::from_dfg(globals_dfg); let main_id = Program::main_id(); let main = context.program.main(); @@ -124,8 +125,7 @@ pub(crate) fn generate_ssa(program: Program) -> Result { function_context.codegen_function_body(&function.body)?; } - let mut ssa = function_context.builder.finish(); - ssa.globals = context.globals_context; + let ssa = function_context.builder.finish(); Ok(ssa) } @@ -152,6 +152,7 @@ impl<'a> FunctionContext<'a> { Expression::Index(index) => self.codegen_index(index), Expression::Cast(cast) => self.codegen_cast(cast), Expression::For(for_expr) => self.codegen_for(for_expr), + Expression::Loop(block) => self.codegen_loop(block), Expression::If(if_expr) => self.codegen_if(if_expr), Expression::Tuple(tuple) => self.codegen_tuple(tuple), Expression::ExtractTupleField(tuple, index) => { @@ -557,7 +558,7 @@ impl<'a> FunctionContext<'a> { // Remember the blocks and variable used in case there are break/continue instructions // within the loop which need to jump to them. - self.enter_loop(loop_entry, loop_index, loop_end); + self.enter_loop(Loop { loop_entry, loop_index: Some(loop_index), loop_end }); // Set the location of the initial jmp instruction to the start range. This is the location // used to issue an error if the start range cannot be determined at compile-time. @@ -587,6 +588,38 @@ impl<'a> FunctionContext<'a> { Ok(Self::unit_value()) } + /// Codegens a loop, creating three new blocks in the process. + /// The return value of a loop is always a unit literal. + /// + /// For example, the loop `loop { body }` is codegen'd as: + /// + /// ```text + /// br loop_body() + /// loop_body(): + /// v3 = ... codegen body ... + /// br loop_body() + /// loop_end(): + /// ... This is the current insert point after codegen_for finishes ... + /// ``` + fn codegen_loop(&mut self, block: &Expression) -> Result { + let loop_body = self.builder.insert_block(); + let loop_end = self.builder.insert_block(); + + self.enter_loop(Loop { loop_entry: loop_body, loop_index: None, loop_end }); + + self.builder.terminate_with_jmp(loop_body, vec![]); + + // Compile the loop body + self.builder.switch_to_block(loop_body); + self.codegen_expression(block)?; + self.builder.terminate_with_jmp(loop_body, vec![]); + + // Finish by switching to the end of the loop + self.builder.switch_to_block(loop_end); + self.exit_loop(); + Ok(Self::unit_value()) + } + /// Codegens an if expression, handling the case of what to do if there is no 'else'. /// /// For example, the expression `if cond { a } else { b }` is codegen'd as: @@ -852,8 +885,12 @@ impl<'a> FunctionContext<'a> { let loop_ = self.current_loop(); // Must remember to increment i before jumping - let new_loop_index = self.make_offset(loop_.loop_index, 1); - self.builder.terminate_with_jmp(loop_.loop_entry, vec![new_loop_index]); + if let Some(loop_index) = loop_.loop_index { + let new_loop_index = self.make_offset(loop_index, 1); + self.builder.terminate_with_jmp(loop_.loop_entry, vec![new_loop_index]); + } else { + self.builder.terminate_with_jmp(loop_.loop_entry, vec![]); + } Self::unit_value() } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs index 305ee16446d..04986bd8db1 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use acvm::acir::circuit::ErrorSelector; +use fxhash::FxHashSet as HashSet; use iter_extended::btree_map; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -11,13 +12,15 @@ use crate::ssa::ir::{ }; use noirc_frontend::hir_def::types::Type as HirType; +use super::ValueId; + /// Contains the entire SSA representation of the program. #[serde_as] #[derive(Serialize, Deserialize)] pub(crate) struct Ssa { #[serde_as(as = "Vec<(_, _)>")] pub(crate) functions: BTreeMap, - pub(crate) globals: Function, + pub(crate) used_global_values: HashSet, pub(crate) main_id: FunctionId, #[serde(skip)] pub(crate) next_id: AtomicCounter, @@ -54,9 +57,9 @@ impl Ssa { next_id: AtomicCounter::starting_after(max_id), entry_point_to_generated_index: BTreeMap::new(), error_selector_to_type: error_types, - // This field should be set afterwards as globals are generated - // outside of the FunctionBuilder, which is where the `Ssa` is instantiated. - globals: Function::new_for_globals(), + // This field is set only after running DIE and is utilized + // for optimizing implementation of globals post-SSA. + used_global_values: HashSet::default(), } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/enumeration.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/enumeration.rs new file mode 100644 index 00000000000..eeeb823b9fc --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/enumeration.rs @@ -0,0 +1,50 @@ +use std::fmt::Display; + +use crate::ast::{Ident, UnresolvedGenerics, UnresolvedType}; +use crate::token::SecondaryAttribute; + +use iter_extended::vecmap; +use noirc_errors::Span; + +use super::{Documented, ItemVisibility}; + +/// Ast node for an enum +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct NoirEnumeration { + pub name: Ident, + pub attributes: Vec, + pub visibility: ItemVisibility, + pub generics: UnresolvedGenerics, + pub variants: Vec>, + pub span: Span, +} + +impl NoirEnumeration { + pub fn is_abi(&self) -> bool { + self.attributes.iter().any(|attr| attr.is_abi()) + } +} + +/// We only support variants of the form `Name(A, B, ...)` currently. +/// Enum variants like `Name { a: A, b: B, .. }` will be implemented later +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct EnumVariant { + pub name: Ident, + pub parameters: Vec, +} + +impl Display for NoirEnumeration { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let generics = vecmap(&self.generics, |generic| generic.to_string()); + let generics = if generics.is_empty() { "".into() } else { generics.join(", ") }; + + writeln!(f, "enum {}{} {{", self.name, generics)?; + + for variant in self.variants.iter() { + let parameters = vecmap(&variant.item.parameters, ToString::to_string).join(", "); + writeln!(f, " {}({}),", variant.item.name, parameters)?; + } + + write!(f, "}}") + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs index 9d521545e7a..1f7a37428b2 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/expression.rs @@ -8,7 +8,7 @@ use crate::ast::{ UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, Visibility, }; use crate::node_interner::{ - ExprId, InternedExpressionKind, InternedStatementKind, QuotedTypeId, StructId, + ExprId, InternedExpressionKind, InternedStatementKind, QuotedTypeId, TypeId, }; use crate::token::{Attributes, FmtStrFragment, FunctionAttribute, Token, Tokens}; use crate::{Kind, Type}; @@ -559,7 +559,7 @@ pub struct ConstructorExpression { /// This may be filled out during macro expansion /// so that we can skip re-resolving the type name since it /// would be lost at that point. - pub struct_type: Option, + pub struct_type: Option, } #[derive(Debug, PartialEq, Eq, Clone)] diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs index f8a82574bee..33f504437c0 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs @@ -5,6 +5,7 @@ //! Noir's Ast is produced by the parser and taken as input to name resolution, //! where it is converted into the Hir (defined in the hir_def module). mod docs; +mod enumeration; mod expression; mod function; mod statement; @@ -24,6 +25,7 @@ use proptest_derive::Arbitrary; use acvm::FieldElement; pub use docs::*; +pub use enumeration::*; use noirc_errors::Span; use serde::{Deserialize, Serialize}; pub use statement::*; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs index 57572e80d1e..02715e8c2d3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs @@ -12,6 +12,7 @@ use super::{ }; use crate::ast::UnresolvedTypeData; use crate::elaborator::types::SELF_TYPE_NAME; +use crate::elaborator::Turbofish; use crate::lexer::token::SpannedToken; use crate::node_interner::{ InternedExpressionKind, InternedPattern, InternedStatementKind, NodeInterner, @@ -45,7 +46,7 @@ pub enum StatementKind { Expression(Expression), Assign(AssignStatement), For(ForLoopStatement), - Loop(Expression), + Loop(Expression, Span /* loop keyword span */), Break, Continue, /// This statement should be executed at compile-time @@ -307,6 +308,7 @@ pub struct ModuleDeclaration { pub visibility: ItemVisibility, pub ident: Ident, pub outer_attributes: Vec, + pub has_semicolon: bool, } impl std::fmt::Display for ModuleDeclaration { @@ -535,6 +537,12 @@ impl PathSegment { pub fn turbofish_span(&self) -> Span { Span::from(self.ident.span().end()..self.span.end()) } + + pub fn turbofish(&self) -> Option { + self.generics + .as_ref() + .map(|generics| Turbofish { span: self.turbofish_span(), generics: generics.clone() }) + } } impl From for PathSegment { @@ -965,7 +973,7 @@ impl Display for StatementKind { StatementKind::Expression(expression) => expression.fmt(f), StatementKind::Assign(assign) => assign.fmt(f), StatementKind::For(for_loop) => for_loop.fmt(f), - StatementKind::Loop(block) => write!(f, "loop {}", block), + StatementKind::Loop(block, _) => write!(f, "loop {}", block), StatementKind::Break => write!(f, "break"), StatementKind::Continue => write!(f, "continue"), StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs index ec50a982a70..d7fe63a6a45 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/visitor.rs @@ -21,15 +21,17 @@ use crate::{ }; use super::{ - ForBounds, FunctionReturnType, GenericTypeArgs, IntegerBitSize, ItemVisibility, Pattern, - Signedness, TraitBound, TraitImplItemKind, TypePath, UnresolvedGenerics, - UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + ForBounds, FunctionReturnType, GenericTypeArgs, IntegerBitSize, ItemVisibility, + NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath, + UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, + UnresolvedTypeExpression, }; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AttributeTarget { Module, Struct, + Enum, Trait, Function, Let, @@ -142,6 +144,10 @@ pub trait Visitor { true } + fn visit_noir_enum(&mut self, _: &NoirEnumeration, _: Span) -> bool { + true + } + fn visit_noir_type_alias(&mut self, _: &NoirTypeAlias, _: Span) -> bool { true } @@ -527,6 +533,7 @@ impl Item { } ItemKind::TypeAlias(noir_type_alias) => noir_type_alias.accept(self.span, visitor), ItemKind::Struct(noir_struct) => noir_struct.accept(self.span, visitor), + ItemKind::Enum(noir_enum) => noir_enum.accept(self.span, visitor), ItemKind::ModuleDecl(module_declaration) => { module_declaration.accept(self.span, visitor); } @@ -775,6 +782,26 @@ impl NoirStruct { } } +impl NoirEnumeration { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_noir_enum(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + for attribute in &self.attributes { + attribute.accept(AttributeTarget::Enum, visitor); + } + + for variant in &self.variants { + for parameter in &variant.item.parameters { + parameter.accept(visitor); + } + } + } +} + impl NoirTypeAlias { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { if visitor.visit_noir_type_alias(self, span) { @@ -1108,7 +1135,7 @@ impl Statement { StatementKind::For(for_loop_statement) => { for_loop_statement.accept(visitor); } - StatementKind::Loop(block) => { + StatementKind::Loop(block, _) => { if visitor.visit_loop_statement(block) { block.accept(visitor); } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs index d88bb62e871..c13c74f44cb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -19,7 +19,7 @@ use crate::{ resolution::errors::ResolverError, }, hir_def::expr::{HirExpression, HirIdent}, - node_interner::{DefinitionKind, DependencyId, FuncId, NodeInterner, StructId, TraitId}, + node_interner::{DefinitionKind, DependencyId, FuncId, NodeInterner, TraitId, TypeId}, parser::{Item, ItemKind}, token::{MetaAttribute, SecondaryAttribute}, Type, TypeBindings, UnificationError, @@ -442,7 +442,21 @@ impl<'context> Elaborator<'context> { self.crate_id, &mut self.errors, ) { - generated_items.types.insert(type_id, the_struct); + generated_items.structs.insert(type_id, the_struct); + } + } + ItemKind::Enum(enum_def) => { + if let Some((type_id, the_enum)) = dc_mod::collect_enum( + self.interner, + self.def_maps.get_mut(&self.crate_id).unwrap(), + self.usage_tracker, + Documented::new(enum_def, item.doc_comments), + self.file, + self.local_module, + self.crate_id, + &mut self.errors, + ) { + generated_items.enums.insert(type_id, the_enum); } } ItemKind::Impl(r#impl) => { @@ -498,7 +512,7 @@ impl<'context> Elaborator<'context> { pub(super) fn run_attributes( &mut self, traits: &BTreeMap, - types: &BTreeMap, + types: &BTreeMap, functions: &[UnresolvedFunctions], module_attributes: &[ModuleAttribute], ) { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs new file mode 100644 index 00000000000..2ccd2b25561 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/enums.rs @@ -0,0 +1,157 @@ +use iter_extended::vecmap; +use noirc_errors::Location; + +use crate::{ + ast::{EnumVariant, FunctionKind, NoirEnumeration, UnresolvedType, Visibility}, + hir_def::{ + expr::{HirEnumConstructorExpression, HirExpression, HirIdent}, + function::{FuncMeta, FunctionBody, HirFunction, Parameters}, + stmt::HirPattern, + }, + node_interner::{DefinitionKind, FuncId, FunctionModifiers, TypeId}, + token::Attributes, + DataType, Shared, Type, +}; + +use super::Elaborator; + +impl Elaborator<'_> { + #[allow(clippy::too_many_arguments)] + pub(super) fn define_enum_variant_function( + &mut self, + enum_: &NoirEnumeration, + type_id: TypeId, + variant: &EnumVariant, + variant_arg_types: Vec, + variant_index: usize, + datatype: &Shared, + self_type: &Type, + self_type_unresolved: UnresolvedType, + ) { + let name_string = variant.name.to_string(); + let datatype_ref = datatype.borrow(); + let location = Location::new(variant.name.span(), self.file); + + let id = self.interner.push_empty_fn(); + + let modifiers = FunctionModifiers { + name: name_string.clone(), + visibility: enum_.visibility, + attributes: Attributes { function: None, secondary: Vec::new() }, + is_unconstrained: false, + generic_count: datatype_ref.generics.len(), + is_comptime: false, + name_location: location, + }; + let definition_id = + self.interner.push_function_definition(id, modifiers, type_id.module_id(), location); + + let hir_name = HirIdent::non_trait_method(definition_id, location); + let parameters = self.make_enum_variant_parameters(variant_arg_types, location); + self.push_enum_variant_function_body(id, datatype, variant_index, ¶meters, location); + + let function_type = + datatype_ref.variant_function_type_with_forall(variant_index, datatype.clone()); + self.interner.push_definition_type(definition_id, function_type.clone()); + + let meta = FuncMeta { + name: hir_name, + kind: FunctionKind::Normal, + parameters, + parameter_idents: Vec::new(), + return_type: crate::ast::FunctionReturnType::Ty(self_type_unresolved), + return_visibility: Visibility::Private, + typ: function_type, + direct_generics: datatype_ref.generics.clone(), + all_generics: datatype_ref.generics.clone(), + location, + has_body: false, + trait_constraints: Vec::new(), + type_id: Some(type_id), + trait_id: None, + trait_impl: None, + enum_variant_index: Some(variant_index), + is_entry_point: false, + has_inline_attribute: false, + function_body: FunctionBody::Resolved, + source_crate: self.crate_id, + source_module: type_id.local_module_id(), + source_file: self.file, + self_type: None, + }; + + self.interner.push_fn_meta(meta, id); + self.interner.add_method(self_type, name_string, id, None); + + let name = variant.name.clone(); + Self::get_module_mut(self.def_maps, type_id.module_id()) + .declare_function(name, enum_.visibility, id) + .ok(); + } + + // Given: + // ``` + // enum FooEnum { Foo(u32, u8), ... } + // + // fn Foo(a: u32, b: u8) -> FooEnum {} + // ``` + // Create (pseudocode): + // ``` + // fn Foo(a: u32, b: u8) -> FooEnum { + // // This can't actually be written directly in Noir + // FooEnum { + // tag: Foo_tag, + // Foo: (a, b), + // // fields from other variants are zeroed in monomorphization + // } + // } + // ``` + fn push_enum_variant_function_body( + &mut self, + id: FuncId, + self_type: &Shared, + variant_index: usize, + parameters: &Parameters, + location: Location, + ) { + // Each parameter of the enum variant function is used as a parameter of the enum + // constructor expression + let arguments = vecmap(¶meters.0, |(pattern, typ, _)| match pattern { + HirPattern::Identifier(ident) => { + let id = self.interner.push_expr(HirExpression::Ident(ident.clone(), None)); + self.interner.push_expr_type(id, typ.clone()); + self.interner.push_expr_location(id, location.span, location.file); + id + } + _ => unreachable!(), + }); + + let enum_generics = self_type.borrow().generic_types(); + let construct_variant = HirExpression::EnumConstructor(HirEnumConstructorExpression { + r#type: self_type.clone(), + enum_generics: enum_generics.clone(), + arguments, + variant_index, + }); + let body = self.interner.push_expr(construct_variant); + self.interner.update_fn(id, HirFunction::unchecked_from_expr(body)); + + let typ = Type::DataType(self_type.clone(), enum_generics); + self.interner.push_expr_type(body, typ); + self.interner.push_expr_location(body, location.span, location.file); + } + + fn make_enum_variant_parameters( + &mut self, + parameter_types: Vec, + location: Location, + ) -> Parameters { + Parameters(vecmap(parameter_types.into_iter().enumerate(), |(i, parameter_type)| { + let name = format!("${i}"); + let parameter = DefinitionKind::Local(None); + let id = self.interner.push_definition(name, false, false, parameter, location); + let pattern = HirPattern::Identifier(HirIdent::non_trait_method(id, location)); + (pattern, parameter_type, Visibility::Private) + })) + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs index 33af075aebd..68e13688b1c 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -29,7 +29,7 @@ use crate::{ }, node_interner::{DefinitionKind, ExprId, FuncId, InternedStatementKind, TraitMethodId}, token::{FmtStrFragment, Tokens}, - Kind, QuotedType, Shared, StructType, Type, + DataType, Kind, QuotedType, Shared, Type, }; use super::{Elaborator, LambdaContext, UnsafeBlockStatus}; @@ -52,7 +52,7 @@ impl<'context> Elaborator<'context> { ExpressionKind::If(if_) => self.elaborate_if(*if_), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), - ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda), + ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None), ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span), ExpressionKind::Comptime(comptime, _) => { @@ -75,7 +75,10 @@ impl<'context> Elaborator<'context> { self.push_err(ResolverError::UnquoteUsedOutsideQuote { span: expr.span }); (HirExpression::Error, Type::Error) } - ExpressionKind::AsTraitPath(_) => todo!("Implement AsTraitPath"), + ExpressionKind::AsTraitPath(_) => { + self.push_err(ResolverError::UnquoteUsedOutsideQuote { span: expr.span }); + (HirExpression::Error, Type::Error) + } ExpressionKind::TypePath(path) => return self.elaborate_type_path(path), }; let id = self.interner.push_expr(hir_expr); @@ -387,17 +390,28 @@ impl<'context> Elaborator<'context> { fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { let (func, func_type) = self.elaborate_expression(*call.func); + let func_arg_types = + if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None }; let mut arguments = Vec::with_capacity(call.arguments.len()); - let args = vecmap(call.arguments, |arg| { + let args = vecmap(call.arguments.into_iter().enumerate(), |(arg_index, arg)| { let span = arg.span; + let expected_type = func_arg_types.and_then(|args| args.get(arg_index)); let (arg, typ) = if call.is_macro_call { - self.elaborate_in_comptime_context(|this| this.elaborate_expression(arg)) + self.elaborate_in_comptime_context(|this| { + this.elaborate_expression_with_type(arg, expected_type) + }) } else { - self.elaborate_expression(arg) + self.elaborate_expression_with_type(arg, expected_type) }; + // Try to unify this argument type against the function's argument type + // so that a potential lambda following this argument can have more concrete types. + if let Some(expected_type) = expected_type { + let _ = expected_type.unify(&typ); + } + arguments.push(arg); (typ, arg, span) }); @@ -458,6 +472,32 @@ impl<'context> Elaborator<'context> { None }; + let call_span = Span::from(object_span.start()..method_name_span.end()); + let location = Location::new(call_span, self.file); + + let (function_id, function_name) = method_ref.clone().into_function_id_and_name( + object_type.clone(), + generics.clone(), + location, + self.interner, + ); + + let func_type = + self.type_check_variable(function_name.clone(), function_id, generics.clone()); + self.interner.push_expr_type(function_id, func_type.clone()); + + let func_arg_types = + if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None }; + + // Try to unify the object type with the first argument of the function. + // The reason to do this is that many methods that take a lambda will yield `self` or part of `self` + // as a parameter. By unifying `self` with the first argument we'll potentially get more + // concrete types in the arguments that are function types, which will later be passed as + // lambda parameter hints. + if let Some(first_arg_type) = func_arg_types.and_then(|args| args.first()) { + let _ = first_arg_type.unify(&object_type); + } + // These arguments will be given to the desugared function call. // Compared to the method arguments, they also contain the object. let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1); @@ -465,17 +505,22 @@ impl<'context> Elaborator<'context> { function_args.push((object_type.clone(), object, object_span)); - for arg in method_call.arguments { + for (arg_index, arg) in method_call.arguments.into_iter().enumerate() { let span = arg.span; - let (arg, typ) = self.elaborate_expression(arg); + let expected_type = func_arg_types.and_then(|args| args.get(arg_index + 1)); + let (arg, typ) = self.elaborate_expression_with_type(arg, expected_type); + + // Try to unify this argument type against the function's argument type + // so that a potential lambda following this argument can have more concrete types. + if let Some(expected_type) = expected_type { + let _ = expected_type.unify(&typ); + } + arguments.push(arg); function_args.push((typ, arg, span)); } - let call_span = Span::from(object_span.start()..method_name_span.end()); - let location = Location::new(call_span, self.file); let method = method_call.method_name; - let turbofish_generics = generics.clone(); let is_macro_call = method_call.is_macro_call; let method_call = HirMethodCallExpression { method, object, arguments, location, generics }; @@ -485,18 +530,9 @@ impl<'context> Elaborator<'context> { // Desugar the method call into a normal, resolved function call // so that the backend doesn't need to worry about methods // TODO: update object_type here? - let ((function_id, function_name), function_call) = method_call.into_function_call( - method_ref, - object_type, - is_macro_call, - location, - self.interner, - ); - - let func_type = - self.type_check_variable(function_name, function_id, turbofish_generics); - self.interner.push_expr_type(function_id, func_type.clone()); + let function_call = + method_call.into_function_call(function_id, is_macro_call, location); self.interner .add_function_reference(func_id, Location::new(method_name_span, self.file)); @@ -520,6 +556,26 @@ impl<'context> Elaborator<'context> { } } + /// Elaborates an expression knowing that it has to match a given type. + fn elaborate_expression_with_type( + &mut self, + arg: Expression, + typ: Option<&Type>, + ) -> (ExprId, Type) { + let ExpressionKind::Lambda(lambda) = arg.kind else { + return self.elaborate_expression(arg); + }; + + let span = arg.span; + let type_hint = + if let Some(Type::Function(func_args, _, _, _)) = typ { Some(func_args) } else { None }; + let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint); + let id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(id, span, self.file); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + fn check_method_call_visibility(&mut self, func_id: FuncId, object_type: &Type, name: &Ident) { if !method_call_is_visible( object_type, @@ -561,12 +617,14 @@ impl<'context> Elaborator<'context> { let is_self_type = last_segment.ident.is_self_type_name(); let (r#type, struct_generics) = if let Some(struct_id) = constructor.struct_type { - let typ = self.interner.get_struct(struct_id); + let typ = self.interner.get_type(struct_id); let generics = typ.borrow().instantiate(self.interner); (typ, generics) } else { match self.lookup_type_or_error(path) { - Some(Type::Struct(r#type, struct_generics)) => (r#type, struct_generics), + Some(Type::DataType(r#type, struct_generics)) if r#type.borrow().is_struct() => { + (r#type, struct_generics) + } Some(typ) => { self.push_err(ResolverError::NonStructUsedInConstructor { typ: typ.to_string(), @@ -593,7 +651,11 @@ impl<'context> Elaborator<'context> { let generics = struct_generics.clone(); let fields = constructor.fields; - let field_types = r#type.borrow().get_fields_with_visibility(&struct_generics); + let field_types = r#type + .borrow() + .get_fields_with_visibility(&struct_generics) + .expect("This type should already be validated to be a struct"); + let fields = self.resolve_constructor_expr_fields(struct_type.clone(), field_types, fields, span); let expr = HirExpression::Constructor(HirConstructorExpression { @@ -604,12 +666,12 @@ impl<'context> Elaborator<'context> { let struct_id = struct_type.borrow().id; let reference_location = Location::new(last_segment.ident.span(), self.file); - self.interner.add_struct_reference(struct_id, reference_location, is_self_type); + self.interner.add_type_reference(struct_id, reference_location, is_self_type); - (expr, Type::Struct(struct_type, generics)) + (expr, Type::DataType(struct_type, generics)) } - pub(super) fn mark_struct_as_constructed(&mut self, struct_type: Shared) { + pub(super) fn mark_struct_as_constructed(&mut self, struct_type: Shared) { let struct_type = struct_type.borrow(); let parent_module_id = struct_type.id.parent_module_id(self.def_maps); self.usage_tracker.mark_as_used(parent_module_id, &struct_type.name); @@ -620,14 +682,17 @@ impl<'context> Elaborator<'context> { /// are part of the struct. fn resolve_constructor_expr_fields( &mut self, - struct_type: Shared, + struct_type: Shared, field_types: Vec<(String, ItemVisibility, Type)>, fields: Vec<(Ident, Expression)>, span: Span, ) -> Vec<(Ident, ExprId)> { let mut ret = Vec::with_capacity(fields.len()); let mut seen_fields = HashSet::default(); - let mut unseen_fields = struct_type.borrow().field_names(); + let mut unseen_fields = struct_type + .borrow() + .field_names() + .expect("This type should already be validated to be a struct"); for (field_name, field) in fields { let expected_field_with_index = field_types @@ -846,19 +911,38 @@ impl<'context> Elaborator<'context> { (HirExpression::Tuple(element_ids), Type::Tuple(element_types)) } - fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) { + /// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential + /// call that has this lambda as the argument. + /// The parameter type hints will be the types of the function type corresponding to the lambda argument. + fn elaborate_lambda( + &mut self, + lambda: Lambda, + parameters_type_hints: Option<&Vec>, + ) -> (HirExpression, Type) { self.push_scope(); let scope_index = self.scopes.current_scope_index(); self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); let mut arg_types = Vec::with_capacity(lambda.parameters.len()); - let parameters = vecmap(lambda.parameters, |(pattern, typ)| { - let parameter = DefinitionKind::Local(None); - let typ = self.resolve_inferred_type(typ); - arg_types.push(typ.clone()); - (self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ) - }); + let parameters = + vecmap(lambda.parameters.into_iter().enumerate(), |(index, (pattern, typ))| { + let parameter = DefinitionKind::Local(None); + let typ = if let UnresolvedTypeData::Unspecified = typ.typ { + if let Some(parameter_type_hint) = + parameters_type_hints.and_then(|hints| hints.get(index)) + { + parameter_type_hint.clone() + } else { + self.interner.next_type_variable_with_kind(Kind::Any) + } + } else { + self.resolve_type(typ) + }; + + arg_types.push(typ.clone()); + (self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ) + }); let return_type = self.resolve_inferred_type(lambda.return_type); let body_span = lambda.body.span; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs index d3b776bea24..af80dfaa823 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/lints.rs @@ -282,6 +282,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i HirStatement::Semi(e) => check(e), // Rust doesn't seem to check the for loop body (it's bounds might mean it's never called). HirStatement::For(e) => check(e.start_range) && check(e.end_range), + HirStatement::Loop(e) => check(e), HirStatement::Constrain(_) | HirStatement::Comptime(_) | HirStatement::Break @@ -319,6 +320,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i HirExpression::Lambda(_) | HirExpression::Literal(_) | HirExpression::Constructor(_) + | HirExpression::EnumConstructor(_) | HirExpression::Quote(_) | HirExpression::Unquote(_) | HirExpression::Comptime(_) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs index d3dded22ab4..65db9f62559 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs @@ -12,10 +12,11 @@ use crate::{ graph::CrateId, hir::{ def_collector::dc_crate::{ - filter_literal_globals, CompilationError, ImplMap, UnresolvedFunctions, - UnresolvedGlobal, UnresolvedStruct, UnresolvedTraitImpl, UnresolvedTypeAlias, + filter_literal_globals, CollectedItems, CompilationError, ImplMap, UnresolvedEnum, + UnresolvedFunctions, UnresolvedGlobal, UnresolvedStruct, UnresolvedTraitImpl, + UnresolvedTypeAlias, }, - def_collector::{dc_crate::CollectedItems, errors::DefCollectorErrorKind}, + def_collector::errors::DefCollectorErrorKind, def_map::{DefMaps, ModuleData}, def_map::{LocalModuleId, ModuleId, MAIN_FUNCTION}, resolution::errors::ResolverError, @@ -23,19 +24,18 @@ use crate::{ type_check::{generics::TraitGenerics, TypeCheckError}, Context, }, - hir_def::traits::TraitImpl, hir_def::{ expr::{HirCapturedVar, HirIdent}, function::{FuncMeta, FunctionBody, HirFunction}, - traits::TraitConstraint, + traits::{TraitConstraint, TraitImpl}, types::{Generics, Kind, ResolvedGeneric}, }, node_interner::{ DefinitionKind, DependencyId, ExprId, FuncId, FunctionModifiers, GlobalId, NodeInterner, - ReferenceId, StructId, TraitId, TraitImplId, TypeAliasId, + ReferenceId, TraitId, TraitImplId, TypeAliasId, TypeId, }, token::SecondaryAttribute, - Shared, Type, TypeVariable, + EnumVariant, Shared, Type, TypeVariable, }; use crate::{ ast::{ItemVisibility, UnresolvedType}, @@ -43,10 +43,11 @@ use crate::{ hir_def::traits::ResolvedTraitBound, node_interner::GlobalValue, usage_tracker::UsageTracker, - StructField, StructType, TypeBindings, + DataType, StructField, TypeBindings, }; mod comptime; +mod enums; mod expressions; mod lints; mod path_resolution; @@ -61,6 +62,7 @@ mod unquote; use fm::FileId; use iter_extended::vecmap; use noirc_errors::{Location, Span, Spanned}; +pub use path_resolution::Turbofish; use path_resolution::{PathResolution, PathResolutionItem}; use types::bind_ordered_generics; @@ -93,6 +95,11 @@ enum UnsafeBlockStatus { InUnsafeBlockWithConstrainedCalls, } +pub struct Loop { + pub is_for: bool, + pub has_break: bool, +} + pub struct Elaborator<'context> { scopes: ScopeForest, @@ -106,7 +113,7 @@ pub struct Elaborator<'context> { pub(crate) file: FileId, unsafe_block_status: UnsafeBlockStatus, - nested_loops: usize, + current_loop: Option, /// Contains a mapping of the current struct or functions's generics to /// unique type variables if we're resolving a struct. Empty otherwise. @@ -146,7 +153,7 @@ pub struct Elaborator<'context> { /// struct Wrapped { /// } /// ``` - resolving_ids: BTreeSet, + resolving_ids: BTreeSet, /// Each constraint in the `where` clause of the function currently being resolved. trait_bounds: Vec, @@ -229,7 +236,7 @@ impl<'context> Elaborator<'context> { crate_graph, file: FileId::dummy(), unsafe_block_status: UnsafeBlockStatus::NotInUnsafeBlock, - nested_loops: 0, + current_loop: None, generics: Vec::new(), lambda_stack: Vec::new(), self_type: None, @@ -318,8 +325,9 @@ impl<'context> Elaborator<'context> { self.define_type_alias(alias_id, alias); } - // Must resolve structs before we resolve globals. - self.collect_struct_definitions(&items.types); + // Must resolve types before we resolve globals. + self.collect_struct_definitions(&items.structs); + self.collect_enum_definitions(&items.enums); self.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls); @@ -349,7 +357,7 @@ impl<'context> Elaborator<'context> { // since the generated items are checked beforehand as well. self.run_attributes( &items.traits, - &items.types, + &items.structs, &items.functions, &items.module_attributes, ); @@ -976,7 +984,7 @@ impl<'context> Elaborator<'context> { let statements = std::mem::take(&mut func.def.body.statements); let body = BlockExpression { statements }; - let struct_id = if let Some(Type::Struct(struct_type, _)) = &self.self_type { + let struct_id = if let Some(Type::DataType(struct_type, _)) = &self.self_type { Some(struct_type.borrow().id) } else { None @@ -992,9 +1000,10 @@ impl<'context> Elaborator<'context> { typ, direct_generics, all_generics: self.generics.clone(), - struct_id, + type_id: struct_id, trait_id, trait_impl: self.current_trait_impl, + enum_variant_index: None, parameters: parameters.into(), parameter_idents, return_type: func.def.return_type.clone(), @@ -1024,13 +1033,21 @@ impl<'context> Elaborator<'context> { self.mark_type_as_used(typ); } } - Type::Struct(struct_type, generics) => { - self.mark_struct_as_constructed(struct_type.clone()); + Type::DataType(datatype, generics) => { + self.mark_struct_as_constructed(datatype.clone()); for generic in generics { self.mark_type_as_used(generic); } - for (_, typ) in struct_type.borrow().get_fields(generics) { - self.mark_type_as_used(&typ); + if let Some(fields) = datatype.borrow().get_fields(generics) { + for (_, typ) in fields { + self.mark_type_as_used(&typ); + } + } else if let Some(variants) = datatype.borrow().get_variants(generics) { + for (_, variant_types) in variants { + for typ in variant_types { + self.mark_type_as_used(&typ); + } + } } } Type::Alias(alias_type, generics) => { @@ -1043,7 +1060,7 @@ impl<'context> Elaborator<'context> { Type::MutableReference(typ) => { self.mark_type_as_used(typ); } - Type::InfixExpr(left, _op, right) => { + Type::InfixExpr(left, _op, right, _) => { self.mark_type_as_used(left); self.mark_type_as_used(right); } @@ -1501,7 +1518,7 @@ impl<'context> Elaborator<'context> { let function_ids = functions.function_ids(); - if let Type::Struct(struct_type, _) = &self_type { + if let Type::DataType(struct_type, _) = &self_type { let struct_ref = struct_type.borrow(); // `impl`s are only allowed on types defined within the current crate @@ -1596,7 +1613,7 @@ impl<'context> Elaborator<'context> { } /// Find the struct in the parent module so we can know its visibility - fn find_struct_visibility(&self, struct_type: &StructType) -> Option { + fn find_struct_visibility(&self, struct_type: &DataType) -> Option { let parent_module_id = struct_type.id.parent_module_id(self.def_maps); let parent_module_data = self.get_module(parent_module_id); let per_ns = parent_module_data.find_name(&struct_type.name); @@ -1618,8 +1635,8 @@ impl<'context> Elaborator<'context> { return false; } // Public struct functions should not expose private types. - if let Some(struct_visibility) = func_meta.struct_id.and_then(|id| { - let struct_def = self.get_struct(id); + if let Some(struct_visibility) = func_meta.type_id.and_then(|id| { + let struct_def = self.get_type(id); let struct_def = struct_def.borrow(); self.find_struct_visibility(&struct_def) }) { @@ -1638,7 +1655,7 @@ impl<'context> Elaborator<'context> { span: Span, ) { match typ { - Type::Struct(struct_type, generics) => { + Type::DataType(struct_type, generics) => { let struct_type = struct_type.borrow(); let struct_module_id = struct_type.id.module_id(); @@ -1688,7 +1705,7 @@ impl<'context> Elaborator<'context> { Type::MutableReference(typ) | Type::Array(_, typ) | Type::Slice(typ) => { self.check_type_is_not_more_private_then_item(name, visibility, typ, span); } - Type::InfixExpr(left, _op, right) => { + Type::InfixExpr(left, _op, right, _) => { self.check_type_is_not_more_private_then_item(name, visibility, left, span); self.check_type_is_not_more_private_then_item(name, visibility, right, span); } @@ -1708,7 +1725,7 @@ impl<'context> Elaborator<'context> { } } - fn collect_struct_definitions(&mut self, structs: &BTreeMap) { + fn collect_struct_definitions(&mut self, structs: &BTreeMap) { // This is necessary to avoid cloning the entire struct map // when adding checks after each struct field is resolved. let struct_ids = structs.keys().copied().collect::>(); @@ -1743,29 +1760,29 @@ impl<'context> Elaborator<'context> { } } - let fields_len = fields.len(); - self.interner.update_struct(*type_id, |struct_def| { + if self.interner.is_in_lsp_mode() { + for (field_index, field) in fields.iter().enumerate() { + let location = Location::new(field.name.span(), self.file); + let reference_id = ReferenceId::StructMember(*type_id, field_index); + self.interner.add_definition_location(reference_id, location, None); + } + } + + self.interner.update_type(*type_id, |struct_def| { struct_def.set_fields(fields); }); - - for field_index in 0..fields_len { - self.interner.add_definition_location( - ReferenceId::StructMember(*type_id, field_index), - None, - ); - } } // Check whether the struct fields have nested slices // We need to check after all structs are resolved to // make sure every struct's fields is accurately set. for id in struct_ids { - let struct_type = self.interner.get_struct(id); + let struct_type = self.interner.get_type(id); // Only handle structs without generics as any generics args will be checked // after monomorphization when performing SSA codegen if struct_type.borrow().generics.is_empty() { - let fields = struct_type.borrow().get_fields(&[]); + let fields = struct_type.borrow().get_fields(&[]).unwrap(); for (_, field_type) in fields.iter() { if field_type.is_nested_slice() { let location = struct_type.borrow().location; @@ -1780,14 +1797,14 @@ impl<'context> Elaborator<'context> { pub fn resolve_struct_fields( &mut self, unresolved: &NoirStruct, - struct_id: StructId, + struct_id: TypeId, ) -> Vec { self.recover_generics(|this| { this.current_item = Some(DependencyId::Struct(struct_id)); this.resolving_ids.insert(struct_id); - let struct_def = this.interner.get_struct(struct_id); + let struct_def = this.interner.get_type(struct_id); this.add_existing_generics(&unresolved.generics, &struct_def.borrow().generics); let fields = vecmap(&unresolved.fields, |field| { @@ -1803,6 +1820,50 @@ impl<'context> Elaborator<'context> { }) } + fn collect_enum_definitions(&mut self, enums: &BTreeMap) { + for (type_id, typ) in enums { + self.file = typ.file_id; + self.local_module = typ.module_id; + self.generics.clear(); + + let datatype = self.interner.get_type(*type_id); + let generics = datatype.borrow().generic_types(); + self.add_existing_generics(&typ.enum_def.generics, &datatype.borrow().generics); + + let self_type = Type::DataType(datatype.clone(), generics); + let self_type_id = self.interner.push_quoted_type(self_type.clone()); + let unresolved = UnresolvedType { + typ: UnresolvedTypeData::Resolved(self_type_id), + span: typ.enum_def.span, + }; + + datatype.borrow_mut().init_variants(); + let module_id = ModuleId { krate: self.crate_id, local_id: typ.module_id }; + + for (i, variant) in typ.enum_def.variants.iter().enumerate() { + let types = vecmap(&variant.item.parameters, |typ| self.resolve_type(typ.clone())); + let name = variant.item.name.clone(); + datatype.borrow_mut().push_variant(EnumVariant::new(name, types.clone())); + + // Define a function for each variant to construct it + self.define_enum_variant_function( + &typ.enum_def, + *type_id, + &variant.item, + types, + i, + &datatype, + &self_type, + unresolved.clone(), + ); + + let reference_id = ReferenceId::EnumVariant(*type_id, i); + let location = Location::new(variant.item.name.span(), self.file); + self.interner.add_definition_location(reference_id, location, Some(module_id)); + } + } + } + fn elaborate_global(&mut self, global: UnresolvedGlobal) { let old_module = std::mem::replace(&mut self.local_module, global.module_id); let old_file = std::mem::replace(&mut self.file, global.file_id); @@ -1818,15 +1879,15 @@ impl<'context> Elaborator<'context> { None }; + let span = let_stmt.pattern.span(); + if !self.in_contract() && let_stmt.attributes.iter().any(|attr| matches!(attr, SecondaryAttribute::Abi(_))) { - let span = let_stmt.pattern.span(); self.push_err(ResolverError::AbiAttributeOutsideContract { span }); } if !let_stmt.comptime && matches!(let_stmt.pattern, Pattern::Mutable(..)) { - let span = let_stmt.pattern.span(); self.push_err(ResolverError::MutableGlobal { span }); } @@ -1839,7 +1900,14 @@ impl<'context> Elaborator<'context> { self.elaborate_comptime_global(global_id); if let Some(name) = name { - self.interner.register_global(global_id, name, global.visibility, self.module_id()); + let location = Location::new(span, self.file); + self.interner.register_global( + global_id, + name, + location, + global.visibility, + self.module_id(), + ); } self.local_module = old_module; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/path_resolution.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/path_resolution.rs index 0d0b153b6b6..67a99da70eb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/path_resolution.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/path_resolution.rs @@ -9,7 +9,7 @@ use crate::hir::resolution::errors::ResolverError; use crate::hir::resolution::visibility::item_in_module_is_visible; use crate::locations::ReferencesTracker; -use crate::node_interner::{FuncId, GlobalId, StructId, TraitId, TypeAliasId}; +use crate::node_interner::{FuncId, GlobalId, TraitId, TypeAliasId, TypeId}; use crate::{Shared, Type, TypeAlias}; use super::types::SELF_TYPE_NAME; @@ -27,12 +27,12 @@ pub(crate) struct PathResolution { #[derive(Debug, Clone)] pub enum PathResolutionItem { Module(ModuleId), - Struct(StructId), + Type(TypeId), TypeAlias(TypeAliasId), Trait(TraitId), Global(GlobalId), ModuleFunction(FuncId), - StructFunction(StructId, Option, FuncId), + Method(TypeId, Option, FuncId), TypeAliasFunction(TypeAliasId, Option, FuncId), TraitFunction(TraitId, Option, FuncId), } @@ -41,7 +41,7 @@ impl PathResolutionItem { pub fn function_id(&self) -> Option { match self { PathResolutionItem::ModuleFunction(func_id) - | PathResolutionItem::StructFunction(_, _, func_id) + | PathResolutionItem::Method(_, _, func_id) | PathResolutionItem::TypeAliasFunction(_, _, func_id) | PathResolutionItem::TraitFunction(_, _, func_id) => Some(*func_id), _ => None, @@ -58,12 +58,12 @@ impl PathResolutionItem { pub fn description(&self) -> &'static str { match self { PathResolutionItem::Module(..) => "module", - PathResolutionItem::Struct(..) => "type", + PathResolutionItem::Type(..) => "type", PathResolutionItem::TypeAlias(..) => "type alias", PathResolutionItem::Trait(..) => "trait", PathResolutionItem::Global(..) => "global", PathResolutionItem::ModuleFunction(..) - | PathResolutionItem::StructFunction(..) + | PathResolutionItem::Method(..) | PathResolutionItem::TypeAliasFunction(..) | PathResolutionItem::TraitFunction(..) => "function", } @@ -80,19 +80,19 @@ pub struct Turbofish { #[derive(Debug)] enum IntermediatePathResolutionItem { Module, - Struct(StructId, Option), + Type(TypeId, Option), TypeAlias(TypeAliasId, Option), Trait(TraitId, Option), } pub(crate) type PathResolutionResult = Result; -enum StructMethodLookupResult { +enum MethodLookupResult { /// The method could not be found. There might be trait methods that could be imported, /// but none of them are. NotFound(Vec), - /// Found a struct method. - FoundStructMethod(PerNs), + /// Found a method. + FoundMethod(PerNs), /// Found a trait method and it's currently in scope. FoundTraitMethod(PerNs, TraitId), /// There's only one trait method that matches, but it's not in scope @@ -124,16 +124,16 @@ impl<'context> Elaborator<'context> { let mut module_id = self.module_id(); if path.kind == PathKind::Plain && path.first_name() == Some(SELF_TYPE_NAME) { - if let Some(Type::Struct(struct_type, _)) = &self.self_type { - let struct_type = struct_type.borrow(); + if let Some(Type::DataType(datatype, _)) = &self.self_type { + let datatype = datatype.borrow(); if path.segments.len() == 1 { return Ok(PathResolution { - item: PathResolutionItem::Struct(struct_type.id), + item: PathResolutionItem::Type(datatype.id), errors: Vec::new(), }); } - module_id = struct_type.id.module_id(); + module_id = datatype.id.module_id(); path.segments.remove(0); } } @@ -211,9 +211,9 @@ impl<'context> Elaborator<'context> { last_segment.ident.is_self_type_name(), ); - let current_module_id_is_struct; + let current_module_id_is_type; - (current_module_id, current_module_id_is_struct, intermediate_item) = match typ { + (current_module_id, current_module_id_is_type, intermediate_item) = match typ { ModuleDefId::ModuleId(id) => { if last_segment_generics.is_some() { errors.push(PathResolutionError::TurbofishNotAllowedOnItem { @@ -224,46 +224,24 @@ impl<'context> Elaborator<'context> { (id, false, IntermediatePathResolutionItem::Module) } - ModuleDefId::TypeId(id) => ( - id.module_id(), - true, - IntermediatePathResolutionItem::Struct( - id, - last_segment_generics.as_ref().map(|generics| Turbofish { - generics: generics.clone(), - span: last_segment.turbofish_span(), - }), - ), - ), + ModuleDefId::TypeId(id) => { + let item = IntermediatePathResolutionItem::Type(id, last_segment.turbofish()); + (id.module_id(), true, item) + } ModuleDefId::TypeAliasId(id) => { let type_alias = self.interner.get_type_alias(id); let Some(module_id) = get_type_alias_module_def_id(&type_alias) else { return Err(PathResolutionError::Unresolved(last_ident.clone())); }; - ( - module_id, - true, - IntermediatePathResolutionItem::TypeAlias( - id, - last_segment_generics.as_ref().map(|generics| Turbofish { - generics: generics.clone(), - span: last_segment.turbofish_span(), - }), - ), - ) + let item = + IntermediatePathResolutionItem::TypeAlias(id, last_segment.turbofish()); + (module_id, true, item) + } + ModuleDefId::TraitId(id) => { + let item = IntermediatePathResolutionItem::Trait(id, last_segment.turbofish()); + (id.0, false, item) } - ModuleDefId::TraitId(id) => ( - id.0, - false, - IntermediatePathResolutionItem::Trait( - id, - last_segment_generics.as_ref().map(|generics| Turbofish { - generics: generics.clone(), - span: last_segment.turbofish_span(), - }), - ), - ), ModuleDefId::FunctionId(_) => panic!("functions cannot be in the type namespace"), ModuleDefId::GlobalId(_) => panic!("globals cannot be in the type namespace"), }; @@ -284,10 +262,9 @@ impl<'context> Elaborator<'context> { current_module = self.get_module(current_module_id); // Check if namespace - let found_ns = if current_module_id_is_struct { - match self.resolve_struct_function(importing_module, current_module, current_ident) - { - StructMethodLookupResult::NotFound(vec) => { + let found_ns = if current_module_id_is_type { + match self.resolve_method(importing_module, current_module, current_ident) { + MethodLookupResult::NotFound(vec) => { if vec.is_empty() { return Err(PathResolutionError::Unresolved(current_ident.clone())); } else { @@ -303,16 +280,13 @@ impl<'context> Elaborator<'context> { ); } } - StructMethodLookupResult::FoundStructMethod(per_ns) => per_ns, - StructMethodLookupResult::FoundTraitMethod(per_ns, trait_id) => { + MethodLookupResult::FoundMethod(per_ns) => per_ns, + MethodLookupResult::FoundTraitMethod(per_ns, trait_id) => { let trait_ = self.interner.get_trait(trait_id); self.usage_tracker.mark_as_used(importing_module, &trait_.name); per_ns } - StructMethodLookupResult::FoundOneTraitMethodButNotInScope( - per_ns, - trait_id, - ) => { + MethodLookupResult::FoundOneTraitMethodButNotInScope(per_ns, trait_id) => { let trait_ = self.interner.get_trait(trait_id); let trait_name = self.fully_qualified_trait_path(trait_); errors.push(PathResolutionError::TraitMethodNotInScope { @@ -321,7 +295,7 @@ impl<'context> Elaborator<'context> { }); per_ns } - StructMethodLookupResult::FoundMultipleTraitMethods(vec) => { + MethodLookupResult::FoundMultipleTraitMethods(vec) => { let traits = vecmap(vec, |trait_id| { let trait_ = self.interner.get_trait(trait_id); self.usage_tracker.mark_as_used(importing_module, &trait_.name); @@ -373,32 +347,29 @@ impl<'context> Elaborator<'context> { } fn self_type_module_id(&self) -> Option { - if let Some(Type::Struct(struct_type, _)) = &self.self_type { - Some(struct_type.borrow().id.module_id()) + if let Some(Type::DataType(datatype, _)) = &self.self_type { + Some(datatype.borrow().id.module_id()) } else { None } } - fn resolve_struct_function( + fn resolve_method( &self, importing_module_id: ModuleId, current_module: &ModuleData, ident: &Ident, - ) -> StructMethodLookupResult { - // If the current module is a struct, next we need to find a function for it. - // The function could be in the struct itself, or it could be defined in traits. + ) -> MethodLookupResult { + // If the current module is a type, next we need to find a function for it. + // The function could be in the type itself, or it could be defined in traits. let item_scope = current_module.scope(); let Some(values) = item_scope.values().get(ident) else { - return StructMethodLookupResult::NotFound(vec![]); + return MethodLookupResult::NotFound(vec![]); }; - // First search if the function is defined in the struct itself + // First search if the function is defined in the type itself if let Some(item) = values.get(&None) { - return StructMethodLookupResult::FoundStructMethod(PerNs { - types: None, - values: Some(*item), - }); + return MethodLookupResult::FoundMethod(PerNs { types: None, values: Some(*item) }); } // Otherwise, the function could be defined in zero, one or more traits. @@ -427,25 +398,23 @@ impl<'context> Elaborator<'context> { let (trait_id, item) = values.iter().next().expect("Expected an item"); let trait_id = trait_id.expect("The None option was already considered before"); let per_ns = PerNs { types: None, values: Some(*item) }; - return StructMethodLookupResult::FoundOneTraitMethodButNotInScope( - per_ns, trait_id, - ); + return MethodLookupResult::FoundOneTraitMethodButNotInScope(per_ns, trait_id); } else { let trait_ids = vecmap(values, |(trait_id, _)| { trait_id.expect("The none option was already considered before") }); - return StructMethodLookupResult::NotFound(trait_ids); + return MethodLookupResult::NotFound(trait_ids); } } if results.len() > 1 { let trait_ids = vecmap(results, |(trait_id, _)| trait_id); - return StructMethodLookupResult::FoundMultipleTraitMethods(trait_ids); + return MethodLookupResult::FoundMultipleTraitMethods(trait_ids); } let (trait_id, item) = results.remove(0); let per_ns = PerNs { types: None, values: Some(*item) }; - StructMethodLookupResult::FoundTraitMethod(per_ns, trait_id) + MethodLookupResult::FoundTraitMethod(per_ns, trait_id) } } @@ -455,14 +424,14 @@ fn merge_intermediate_path_resolution_item_with_module_def_id( ) -> PathResolutionItem { match module_def_id { ModuleDefId::ModuleId(module_id) => PathResolutionItem::Module(module_id), - ModuleDefId::TypeId(struct_id) => PathResolutionItem::Struct(struct_id), + ModuleDefId::TypeId(type_id) => PathResolutionItem::Type(type_id), ModuleDefId::TypeAliasId(type_alias_id) => PathResolutionItem::TypeAlias(type_alias_id), ModuleDefId::TraitId(trait_id) => PathResolutionItem::Trait(trait_id), ModuleDefId::GlobalId(global_id) => PathResolutionItem::Global(global_id), ModuleDefId::FunctionId(func_id) => match intermediate_item { IntermediatePathResolutionItem::Module => PathResolutionItem::ModuleFunction(func_id), - IntermediatePathResolutionItem::Struct(struct_id, generics) => { - PathResolutionItem::StructFunction(struct_id, generics, func_id) + IntermediatePathResolutionItem::Type(type_id, generics) => { + PathResolutionItem::Method(type_id, generics, func_id) } IntermediatePathResolutionItem::TypeAlias(alias_id, generics) => { PathResolutionItem::TypeAliasFunction(alias_id, generics, func_id) @@ -478,13 +447,13 @@ fn get_type_alias_module_def_id(type_alias: &Shared) -> Option Some(struct_id.borrow().id.module_id()), + Type::DataType(type_id, _generics) => Some(type_id.borrow().id.module_id()), Type::Alias(type_alias, _generics) => get_type_alias_module_def_id(type_alias), Type::Error => None, _ => { - // For now we only allow type aliases that point to structs. + // For now we only allow type aliases that point to data types. // The more general case is captured here: https://github.com/noir-lang/noir/issues/6398 - panic!("Type alias in path not pointing to struct not yet supported") + panic!("Type alias in path not pointing to a data type is not yet supported") } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs index 6a672866d7e..eab0b91f0f6 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -17,7 +17,7 @@ use crate::{ stmt::HirPattern, }, node_interner::{DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, TraitImplKind}, - Kind, Shared, StructType, Type, TypeAlias, TypeBindings, + DataType, Kind, Shared, Type, TypeAlias, TypeBindings, }; use super::{path_resolution::PathResolutionItem, Elaborator, ResolverMeta}; @@ -192,7 +192,11 @@ impl<'context> Elaborator<'context> { }; let (struct_type, generics) = match self.lookup_type_or_error(name) { - Some(Type::Struct(struct_type, struct_generics)) => (struct_type, struct_generics), + Some(Type::DataType(struct_type, struct_generics)) + if struct_type.borrow().is_struct() => + { + (struct_type, struct_generics) + } None => return error_identifier(self), Some(typ) => { let typ = typ.to_string(); @@ -210,7 +214,7 @@ impl<'context> Elaborator<'context> { turbofish_span, ); - let actual_type = Type::Struct(struct_type.clone(), generics); + let actual_type = Type::DataType(struct_type.clone(), generics); let location = Location::new(span, self.file); self.unify(&actual_type, &expected_type, || TypeCheckError::TypeMismatchWithSource { @@ -234,7 +238,7 @@ impl<'context> Elaborator<'context> { let struct_id = struct_type.borrow().id; let reference_location = Location::new(name_span, self.file); - self.interner.add_struct_reference(struct_id, reference_location, is_self_type); + self.interner.add_type_reference(struct_id, reference_location, is_self_type); for (field_index, field) in fields.iter().enumerate() { let reference_location = Location::new(field.0.span(), self.file); @@ -250,7 +254,7 @@ impl<'context> Elaborator<'context> { #[allow(clippy::too_many_arguments)] fn resolve_constructor_pattern_fields( &mut self, - struct_type: Shared, + struct_type: Shared, fields: Vec<(Ident, Pattern)>, span: Span, expected_type: Type, @@ -260,7 +264,10 @@ impl<'context> Elaborator<'context> { ) -> Vec<(Ident, HirPattern)> { let mut ret = Vec::with_capacity(fields.len()); let mut seen_fields = HashSet::default(); - let mut unseen_fields = struct_type.borrow().field_names(); + let mut unseen_fields = struct_type + .borrow() + .field_names() + .expect("This type should already be validated to be a struct"); for (field, pattern) in fields { let (field_type, visibility) = expected_type @@ -434,7 +441,7 @@ impl<'context> Elaborator<'context> { pub(super) fn resolve_struct_turbofish_generics( &mut self, - struct_type: &StructType, + struct_type: &DataType, generics: Vec, unresolved_turbofish: Option>, span: Span, @@ -574,8 +581,8 @@ impl<'context> Elaborator<'context> { /// solve these fn resolve_item_turbofish(&mut self, item: PathResolutionItem) -> Vec { match item { - PathResolutionItem::StructFunction(struct_id, Some(generics), _func_id) => { - let struct_type = self.interner.get_struct(struct_id); + PathResolutionItem::Method(struct_id, Some(generics), _func_id) => { + let struct_type = self.interner.get_type(struct_id); let struct_type = struct_type.borrow(); let struct_generics = struct_type.instantiate(self.interner); self.resolve_struct_turbofish_generics( @@ -886,7 +893,7 @@ impl<'context> Elaborator<'context> { fn get_type_alias_generics(type_alias: &TypeAlias, generics: &[Type]) -> Vec { let typ = type_alias.get_type(generics); match typ { - Type::Struct(_, generics) => generics, + Type::DataType(_, generics) => generics, Type::Alias(type_alias, generics) => { get_type_alias_generics(&type_alias.borrow(), &generics) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs index fe01e3cb7f3..327ae02b204 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs @@ -10,8 +10,8 @@ use crate::{ expr::{HirCapturedVar, HirIdent}, traits::Trait, }, - node_interner::{DefinitionId, StructId, TraitId}, - Shared, StructType, + node_interner::{DefinitionId, TraitId, TypeId}, + DataType, Shared, }; use crate::{Type, TypeAlias}; @@ -37,8 +37,8 @@ impl<'context> Elaborator<'context> { current_module } - pub(super) fn get_struct(&self, type_id: StructId) -> Shared { - self.interner.get_struct(type_id) + pub(super) fn get_type(&self, type_id: TypeId) -> Shared { + self.interner.get_type(type_id) } pub(super) fn get_trait_mut(&mut self, trait_id: TraitId) -> &mut Trait { @@ -160,12 +160,12 @@ impl<'context> Elaborator<'context> { } /// Lookup a given struct type by name. - pub fn lookup_struct_or_error(&mut self, path: Path) -> Option> { + pub fn lookup_datatype_or_error(&mut self, path: Path) -> Option> { let span = path.span(); match self.resolve_path_or_error(path) { Ok(item) => { - if let PathResolutionItem::Struct(struct_id) = item { - Some(self.get_struct(struct_id)) + if let PathResolutionItem::Type(struct_id) = item { + Some(self.get_type(struct_id)) } else { self.push_err(ResolverError::Expected { expected: "type", @@ -194,10 +194,10 @@ impl<'context> Elaborator<'context> { let span = path.span; match self.resolve_path_or_error(path) { - Ok(PathResolutionItem::Struct(struct_id)) => { - let struct_type = self.get_struct(struct_id); + Ok(PathResolutionItem::Type(struct_id)) => { + let struct_type = self.get_type(struct_id); let generics = struct_type.borrow().instantiate(self.interner); - Some(Type::Struct(struct_type, generics)) + Some(Type::DataType(struct_type, generics)) } Ok(PathResolutionItem::TypeAlias(alias_id)) => { let alias = self.interner.get_type_alias(alias_id); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs index 24653772f9f..a95e260b6a5 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs @@ -21,10 +21,10 @@ use crate::{ }, }, node_interner::{DefinitionId, DefinitionKind, GlobalId, StmtId}, - StructType, Type, + DataType, Type, }; -use super::{lints, Elaborator}; +use super::{lints, Elaborator, Loop}; impl<'context> Elaborator<'context> { fn elaborate_statement_value(&mut self, statement: Statement) -> (HirStatement, Type) { @@ -33,7 +33,7 @@ impl<'context> Elaborator<'context> { StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain), StatementKind::Assign(assign) => self.elaborate_assign(assign), StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), - StatementKind::Loop(block) => self.elaborate_loop(block, statement.span), + StatementKind::Loop(block, span) => self.elaborate_loop(block, span), StatementKind::Break => self.elaborate_jump(true, statement.span), StatementKind::Continue => self.elaborate_jump(false, statement.span), StatementKind::Comptime(statement) => self.elaborate_comptime_statement(*statement), @@ -227,7 +227,9 @@ impl<'context> Elaborator<'context> { let (end_range, end_range_type) = self.elaborate_expression(end); let (identifier, block) = (for_loop.identifier, for_loop.block); - self.nested_loops += 1; + let old_loop = std::mem::take(&mut self.current_loop); + + self.current_loop = Some(Loop { is_for: true, has_break: false }); self.push_scope(); // TODO: For loop variables are currently mutable by default since we haven't @@ -261,7 +263,7 @@ impl<'context> Elaborator<'context> { let (block, _block_type) = self.elaborate_expression(block); self.pop_scope(); - self.nested_loops -= 1; + self.current_loop = old_loop; let statement = HirStatement::For(HirForStatement { start_range, end_range, block, identifier }); @@ -271,11 +273,31 @@ impl<'context> Elaborator<'context> { pub(super) fn elaborate_loop( &mut self, - _block: Expression, + block: Expression, span: noirc_errors::Span, ) -> (HirStatement, Type) { - self.push_err(ResolverError::LoopNotYetSupported { span }); - (HirStatement::Error, Type::Unit) + let in_constrained_function = self.in_constrained_function(); + if in_constrained_function { + self.push_err(ResolverError::LoopInConstrainedFn { span }); + } + + let old_loop = std::mem::take(&mut self.current_loop); + self.current_loop = Some(Loop { is_for: false, has_break: false }); + self.push_scope(); + + let (block, _block_type) = self.elaborate_expression(block); + + self.pop_scope(); + + let last_loop = + std::mem::replace(&mut self.current_loop, old_loop).expect("Expected a loop"); + if !last_loop.has_break { + self.push_err(ResolverError::LoopWithoutBreak { span }); + } + + let statement = HirStatement::Loop(block); + + (statement, Type::Unit) } fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) { @@ -284,7 +306,12 @@ impl<'context> Elaborator<'context> { if in_constrained_function { self.push_err(ResolverError::JumpInConstrainedFn { is_break, span }); } - if self.nested_loops == 0 { + + if let Some(current_loop) = &mut self.current_loop { + if is_break { + current_loop.has_break = true; + } + } else { self.push_err(ResolverError::JumpOutsideLoop { is_break, span }); } @@ -464,7 +491,7 @@ impl<'context> Elaborator<'context> { let lhs_type = lhs_type.follow_bindings(); match &lhs_type { - Type::Struct(s, args) => { + Type::DataType(s, args) => { let s = s.borrow(); if let Some((field, visibility, index)) = s.get_field(field_name, args) { let reference_location = Location::new(span, self.file); @@ -528,7 +555,7 @@ impl<'context> Elaborator<'context> { pub(super) fn check_struct_field_visibility( &mut self, - struct_type: &StructType, + struct_type: &DataType, field_name: &str, visibility: ItemVisibility, span: Span, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/trait_impls.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/trait_impls.rs index 20f048bed05..aa27ac29fa6 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/trait_impls.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/trait_impls.rs @@ -217,7 +217,7 @@ impl<'context> Elaborator<'context> { self.file = trait_impl.file_id; let object_crate = match &trait_impl.resolved_object_type { - Some(Type::Struct(struct_type, _)) => struct_type.borrow().id.krate(), + Some(Type::DataType(struct_type, _)) => struct_type.borrow().id.krate(), _ => CrateId::Dummy, }; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs index 8fa0b210605..53d0860ebf1 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs @@ -154,12 +154,12 @@ impl<'context> Elaborator<'context> { let location = Location::new(named_path_span.unwrap_or(typ.span), self.file); match resolved_type { - Type::Struct(ref struct_type, _) => { + Type::DataType(ref data_type, _) => { // Record the location of the type reference self.interner.push_type_ref_location(resolved_type.clone(), location); if !is_synthetic { - self.interner.add_struct_reference( - struct_type.borrow().id, + self.interner.add_type_reference( + data_type.borrow().id, location, is_self_type_name, ); @@ -259,11 +259,11 @@ impl<'context> Elaborator<'context> { return Type::Alias(type_alias, args); } - match self.lookup_struct_or_error(path) { - Some(struct_type) => { - if self.resolving_ids.contains(&struct_type.borrow().id) { - self.push_err(ResolverError::SelfReferentialStruct { - span: struct_type.borrow().name.span(), + match self.lookup_datatype_or_error(path) { + Some(data_type) => { + if self.resolving_ids.contains(&data_type.borrow().id) { + self.push_err(ResolverError::SelfReferentialType { + span: data_type.borrow().name.span(), }); return Type::Error; @@ -272,23 +272,23 @@ impl<'context> Elaborator<'context> { if !self.in_contract() && self .interner - .struct_attributes(&struct_type.borrow().id) + .type_attributes(&data_type.borrow().id) .iter() .any(|attr| matches!(attr, SecondaryAttribute::Abi(_))) { self.push_err(ResolverError::AbiAttributeOutsideContract { - span: struct_type.borrow().name.span(), + span: data_type.borrow().name.span(), }); } - let (args, _) = self.resolve_type_args(args, struct_type.borrow(), span); + let (args, _) = self.resolve_type_args(args, data_type.borrow(), span); if let Some(current_item) = self.current_item { - let dependency_id = struct_type.borrow().id; + let dependency_id = data_type.borrow().id; self.interner.add_type_dependency(current_item, dependency_id); } - Type::Struct(struct_type, args) + Type::DataType(data_type, args) } None => Type::Error, } @@ -535,7 +535,7 @@ impl<'context> Elaborator<'context> { } } (lhs, rhs) => { - let infix = Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)); + let infix = Type::infix_expr(Box::new(lhs), op, Box::new(rhs)); Type::CheckedCast { from: Box::new(infix.clone()), to: Box::new(infix) } .canonicalize() } @@ -684,6 +684,60 @@ impl<'context> Elaborator<'context> { None } + /// This resolves a method in the form `Type::method` where `method` is a trait method + fn resolve_type_trait_method(&mut self, path: &Path) -> Option { + if path.segments.len() < 2 { + return None; + } + + let mut path = path.clone(); + let span = path.span(); + let last_segment = path.pop(); + let before_last_segment = path.last_segment(); + + let path_resolution = self.resolve_path(path).ok()?; + let PathResolutionItem::Type(type_id) = path_resolution.item else { + return None; + }; + + let datatype = self.get_type(type_id); + let generics = datatype.borrow().instantiate(self.interner); + let typ = Type::DataType(datatype, generics); + let method_name = &last_segment.ident.0.contents; + + // If we can find a method on the type, this is definitely not a trait method + if self.interner.lookup_direct_method(&typ, method_name, false).is_some() { + return None; + } + + let trait_methods = self.interner.lookup_trait_methods(&typ, method_name, false); + if trait_methods.is_empty() { + return None; + } + + let (hir_method_reference, error) = + self.get_trait_method_in_scope(&trait_methods, method_name, last_segment.span); + let hir_method_reference = hir_method_reference?; + let func_id = hir_method_reference.func_id(self.interner)?; + let HirMethodReference::TraitMethodId(trait_method_id, _, _) = hir_method_reference else { + return None; + }; + + let trait_id = trait_method_id.trait_id; + let trait_ = self.interner.get_trait(trait_id); + let mut constraint = trait_.as_constraint(span); + constraint.typ = typ; + + let method = TraitMethod { method_id: trait_method_id, constraint, assumed: false }; + let turbofish = before_last_segment.turbofish(); + let item = PathResolutionItem::TraitFunction(trait_id, turbofish, func_id); + let mut errors = path_resolution.errors; + if let Some(error) = error { + errors.push(error); + } + Some(TraitPathResolution { method, item: Some(item), errors }) + } + // Try to resolve the given trait method path. // // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not @@ -695,6 +749,7 @@ impl<'context> Elaborator<'context> { self.resolve_trait_static_method_by_self(path) .or_else(|| self.resolve_trait_static_method(path)) .or_else(|| self.resolve_trait_method_by_named_generic(path)) + .or_else(|| self.resolve_type_trait_method(path)) } pub(super) fn unify( @@ -1368,12 +1423,12 @@ impl<'context> Elaborator<'context> { self.lookup_method_in_trait_constraints(object_type, method_name, span) } // Mutable references to another type should resolve to methods of their element type. - // This may be a struct or a primitive type. + // This may be a data type or a primitive type. Type::MutableReference(element) => { self.lookup_method(&element, method_name, span, has_self_arg) } - // If we fail to resolve the object to a struct type, we have no way of type + // If we fail to resolve the object to a data type, we have no way of type // checking its arguments as we can't even resolve the name of the function Type::Error => None, @@ -1383,13 +1438,11 @@ impl<'context> Elaborator<'context> { None } - other => { - self.lookup_struct_or_primitive_method(&other, method_name, span, has_self_arg) - } + other => self.lookup_type_or_primitive_method(&other, method_name, span, has_self_arg), } } - fn lookup_struct_or_primitive_method( + fn lookup_type_or_primitive_method( &mut self, object_type: &Type, method_name: &str, @@ -1420,12 +1473,16 @@ impl<'context> Elaborator<'context> { return self.return_trait_method_in_scope(&generic_methods, method_name, span); } - if let Type::Struct(struct_type, _) = object_type { - let has_field_with_function_type = struct_type - .borrow() - .get_fields_as_written() - .into_iter() - .any(|field| field.name.0.contents == method_name && field.typ.is_function()); + if let Type::DataType(datatype, _) = object_type { + let datatype = datatype.borrow(); + let mut has_field_with_function_type = false; + + if let Some(fields) = datatype.fields_raw() { + has_field_with_function_type = fields + .iter() + .any(|field| field.name.0.contents == method_name && field.typ.is_function()); + } + if has_field_with_function_type { self.push_err(TypeCheckError::CannotInvokeStructFieldFunctionType { method_name: method_name.to_string(), @@ -1456,6 +1513,19 @@ impl<'context> Elaborator<'context> { method_name: &str, span: Span, ) -> Option { + let (method, error) = self.get_trait_method_in_scope(trait_methods, method_name, span); + if let Some(error) = error { + self.push_err(error); + } + method + } + + fn get_trait_method_in_scope( + &mut self, + trait_methods: &[(FuncId, TraitId)], + method_name: &str, + span: Span, + ) -> (Option, Option) { let module_id = self.module_id(); let module_data = self.get_module(module_id); @@ -1489,28 +1559,24 @@ impl<'context> Elaborator<'context> { let trait_id = *traits.iter().next().unwrap(); let trait_ = self.interner.get_trait(trait_id); let trait_name = self.fully_qualified_trait_path(trait_); - - self.push_err(PathResolutionError::TraitMethodNotInScope { + let method = + self.trait_hir_method_reference(trait_id, trait_methods, method_name, span); + let error = PathResolutionError::TraitMethodNotInScope { ident: Ident::new(method_name.into(), span), trait_name, - }); - - return Some(self.trait_hir_method_reference( - trait_id, - trait_methods, - method_name, - span, - )); + }; + return (Some(method), Some(error)); } else { let traits = vecmap(traits, |trait_id| { let trait_ = self.interner.get_trait(trait_id); self.fully_qualified_trait_path(trait_) }); - self.push_err(PathResolutionError::UnresolvedWithPossibleTraitsToImport { + let method = None; + let error = PathResolutionError::UnresolvedWithPossibleTraitsToImport { ident: Ident::new(method_name.into(), span), traits, - }); - return None; + }; + return (method, Some(error)); } } @@ -1519,15 +1585,18 @@ impl<'context> Elaborator<'context> { let trait_ = self.interner.get_trait(trait_id); self.fully_qualified_trait_path(trait_) }); - self.push_err(PathResolutionError::MultipleTraitsInScope { + let method = None; + let error = PathResolutionError::MultipleTraitsInScope { ident: Ident::new(method_name.into(), span), traits, - }); - return None; + }; + return (method, Some(error)); } let trait_id = traits_in_scope[0].0; - Some(self.trait_hir_method_reference(trait_id, trait_methods, method_name, span)) + let method = self.trait_hir_method_reference(trait_id, trait_methods, method_name, span); + let error = None; + (Some(method), error) } fn trait_hir_method_reference( @@ -1545,7 +1614,7 @@ impl<'context> Elaborator<'context> { // Return a TraitMethodId with unbound generics. These will later be bound by the type-checker. let trait_ = self.interner.get_trait(trait_id); - let generics = trait_.as_constraint(span).trait_bound.trait_generics; + let generics = trait_.get_trait_generics(span); let trait_method_id = trait_.find_method(method_name).unwrap(); HirMethodReference::TraitMethodId(trait_method_id, generics, false) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs index ccdfdf00e72..cbcf8b02d03 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -357,7 +357,7 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { } Value::Struct(fields, typ) => { let typename = match typ.follow_bindings() { - Type::Struct(def, _) => def.borrow().name.to_string(), + Type::DataType(def, _) => def.borrow().name.to_string(), other => other.to_string(), }; let fields = vecmap(fields, |(name, value)| { @@ -376,7 +376,7 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { } Value::Quoted(tokens) => display_quoted(tokens, 0, self.interner, f), Value::StructDefinition(id) => { - let def = self.interner.get_struct(*id); + let def = self.interner.get_type(*id); let def = def.borrow(); write!(f, "{}", def.name) } @@ -732,8 +732,8 @@ fn remove_interned_in_statement_kind( block: remove_interned_in_expression(interner, for_loop.block), ..for_loop }), - StatementKind::Loop(block) => { - StatementKind::Loop(remove_interned_in_expression(interner, block)) + StatementKind::Loop(block, span) => { + StatementKind::Loop(remove_interned_in_expression(interner, block), span) } StatementKind::Comptime(statement) => { StatementKind::Comptime(Box::new(remove_interned_in_statement(interner, *statement))) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs index 6ff918328a1..e9a615f2c59 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -246,6 +246,9 @@ pub enum InterpreterError { GlobalsDependencyCycle { location: Location, }, + LoopHaltedForUiResponsiveness { + location: Location, + }, // These cases are not errors, they are just used to prevent us from running more code // until the loop can be resumed properly. These cases will never be displayed to users. @@ -323,7 +326,8 @@ impl InterpreterError { | InterpreterError::CannotSetFunctionBody { location, .. } | InterpreterError::UnknownArrayLength { location, .. } | InterpreterError::CannotInterpretFormatStringWithErrors { location } - | InterpreterError::GlobalsDependencyCycle { location } => *location, + | InterpreterError::GlobalsDependencyCycle { location } + | InterpreterError::LoopHaltedForUiResponsiveness { location } => *location, InterpreterError::FailedToParseMacro { error, file, .. } => { Location::new(error.span(), *file) @@ -683,6 +687,13 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let secondary = String::new(); CustomDiagnostic::simple_error(msg, secondary, location.span) } + InterpreterError::LoopHaltedForUiResponsiveness { location } => { + let msg = "This loop took too much time to execute so it was halted for UI responsiveness" + .to_string(); + let secondary = + "This error doesn't happen in normal executions of `nargo`".to_string(); + CustomDiagnostic::simple_warning(msg, secondary, location.span) + } } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index 9338c0fc37f..d46484d05fa 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -5,8 +5,8 @@ use crate::ast::{ ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainKind, ConstructorExpression, ExpressionKind, ForLoopStatement, ForRange, GenericTypeArgs, Ident, IfExpression, IndexExpression, InfixExpression, LValue, Lambda, Literal, - MemberAccessExpression, MethodCallExpression, Path, PathSegment, Pattern, PrefixExpression, - UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + MemberAccessExpression, MethodCallExpression, Path, PathKind, PathSegment, Pattern, + PrefixExpression, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; use crate::ast::{ConstrainStatement, Expression, Statement, StatementKind}; use crate::hir_def::expr::{ @@ -59,6 +59,7 @@ impl HirStatement { block: for_stmt.block.to_display_ast(interner), span, }), + HirStatement::Loop(block) => StatementKind::Loop(block.to_display_ast(interner), span), HirStatement::Break => StatementKind::Break, HirStatement::Continue => StatementKind::Continue, HirStatement::Expression(expr) => { @@ -211,6 +212,19 @@ impl HirExpression { // A macro was evaluated here: return the quoted result HirExpression::Unquote(block) => ExpressionKind::Quote(block.clone()), + + // Convert this back into a function call `Enum::Foo(args)` + HirExpression::EnumConstructor(constructor) => { + let typ = constructor.r#type.borrow(); + let variant = &typ.variant_at(constructor.variant_index); + let segment1 = PathSegment { ident: typ.name.clone(), span, generics: None }; + let segment2 = PathSegment { ident: variant.name.clone(), span, generics: None }; + let path = Path { segments: vec![segment1, segment2], kind: PathKind::Plain, span }; + let func = Box::new(Expression::new(ExpressionKind::Variable(path), span)); + let arguments = vecmap(&constructor.arguments, |arg| arg.to_display_ast(interner)); + let call = CallExpression { func, arguments, is_macro_call: false }; + ExpressionKind::Call(Box::new(call)) + } }; Expression::new(kind, span) @@ -245,7 +259,7 @@ impl HirPattern { (name.clone(), pattern.to_display_ast(interner)) }); let name = match typ.follow_bindings() { - Type::Struct(struct_def, _) => { + Type::DataType(struct_def, _) => { let struct_def = struct_def.borrow(); struct_def.name.0.contents.clone() } @@ -300,7 +314,7 @@ impl Type { let fields = vecmap(fields, |field| field.to_display_ast()); UnresolvedTypeData::Tuple(fields) } - Type::Struct(def, generics) => { + Type::DataType(def, generics) => { let struct_def = def.borrow(); let ordered_args = vecmap(generics, |generic| generic.to_display_ast()); let generics = @@ -359,7 +373,7 @@ impl Type { Type::Constant(..) => panic!("Type::Constant where a type was expected: {self:?}"), Type::Quoted(quoted_type) => UnresolvedTypeData::Quoted(*quoted_type), Type::Error => UnresolvedTypeData::Error, - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { let lhs = Box::new(lhs.to_type_expression()); let rhs = Box::new(rhs.to_type_expression()); let span = Span::default(); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index ec4d33c3ca4..41c75220a3c 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -14,7 +14,7 @@ use crate::elaborator::Elaborator; use crate::graph::CrateId; use crate::hir::def_map::ModuleId; use crate::hir::type_check::TypeCheckError; -use crate::hir_def::expr::ImplKind; +use crate::hir_def::expr::{HirEnumConstructorExpression, ImplKind}; use crate::hir_def::function::FunctionBody; use crate::monomorphization::{ perform_impl_bindings, perform_instantiation_bindings, resolve_trait_method, @@ -539,6 +539,9 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { HirExpression::Quote(tokens) => self.evaluate_quote(tokens, id), HirExpression::Comptime(block) => self.evaluate_block(block), HirExpression::Unsafe(block) => self.evaluate_block(block), + HirExpression::EnumConstructor(constructor) => { + self.evaluate_enum_constructor(constructor, id) + } HirExpression::Unquote(tokens) => { // An Unquote expression being found is indicative of a macro being // expanded within another comptime fn which we don't currently support. @@ -1285,6 +1288,14 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Ok(Value::Struct(fields, typ)) } + fn evaluate_enum_constructor( + &mut self, + _constructor: HirEnumConstructorExpression, + _id: ExprId, + ) -> IResult { + todo!("Support enums in the comptime interpreter") + } + fn evaluate_access(&mut self, access: HirMemberAccess, id: ExprId) -> IResult { let (fields, struct_type) = match self.evaluate(access.lhs)? { Value::Struct(fields, typ) => (fields, typ), @@ -1550,6 +1561,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), HirStatement::Assign(assign) => self.evaluate_assign(assign), HirStatement::For(for_) => self.evaluate_for(for_), + HirStatement::Loop(expression) => self.evaluate_loop(expression), HirStatement::Break => self.evaluate_break(statement), HirStatement::Continue => self.evaluate_continue(statement), HirStatement::Expression(expression) => self.evaluate(expression), @@ -1723,22 +1735,68 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { let (end, _) = get_index(self, for_.end_range)?; let was_in_loop = std::mem::replace(&mut self.in_loop, true); + let mut result = Ok(Value::Unit); + for i in start..end { self.push_scope(); self.current_scope_mut().insert(for_.identifier.id, make_value(i)); - match self.evaluate(for_.block) { - Ok(_) => (), - Err(InterpreterError::Break) => break, - Err(InterpreterError::Continue) => continue, - Err(other) => return Err(other), + let must_break = match self.evaluate(for_.block) { + Ok(_) => false, + Err(InterpreterError::Break) => true, + Err(InterpreterError::Continue) => false, + Err(error) => { + result = Err(error); + true + } + }; + + self.pop_scope(); + + if must_break { + break; } + } + + self.in_loop = was_in_loop; + result + } + + fn evaluate_loop(&mut self, expr: ExprId) -> IResult { + let was_in_loop = std::mem::replace(&mut self.in_loop, true); + let in_lsp = self.elaborator.interner.is_in_lsp_mode(); + let mut counter = 0; + let mut result = Ok(Value::Unit); + + loop { + self.push_scope(); + + let must_break = match self.evaluate(expr) { + Ok(_) => false, + Err(InterpreterError::Break) => true, + Err(InterpreterError::Continue) => false, + Err(error) => { + result = Err(error); + true + } + }; self.pop_scope(); + + if must_break { + break; + } + + counter += 1; + if in_lsp && counter == 10_000 { + let location = self.elaborator.interner.expr_location(&expr); + result = Err(InterpreterError::LoopHaltedForUiResponsiveness { location }); + break; + } } self.in_loop = was_in_loop; - Ok(Value::Unit) + result } fn evaluate_break(&mut self, id: StmtId) -> IResult { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 3506b63919c..6503b0cf77b 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -66,7 +66,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location), "array_len" => array_len(interner, arguments, location), "array_refcount" => Ok(Value::U32(0)), - "assert_constant" => Ok(Value::Bool(true)), + "assert_constant" => Ok(Value::Unit), "as_slice" => as_slice(interner, arguments, location), "ctstring_eq" => ctstring_eq(arguments, location), "ctstring_hash" => ctstring_hash(arguments, location), @@ -175,6 +175,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "slice_push_front" => slice_push_front(interner, arguments, location), "slice_refcount" => Ok(Value::U32(0)), "slice_remove" => slice_remove(interner, arguments, location, call_stack), + "static_assert" => static_assert(interner, arguments, location, call_stack), "str_as_bytes" => str_as_bytes(interner, arguments, location), "str_as_ctstring" => str_as_ctstring(interner, arguments, location), "struct_def_add_attribute" => struct_def_add_attribute(interner, arguments, location), @@ -327,6 +328,28 @@ fn slice_push_back( Ok(Value::Slice(values, typ)) } +// static_assert(predicate: bool, message: str) +fn static_assert( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + location: Location, + call_stack: &im::Vector, +) -> IResult { + let (predicate, message) = check_two_arguments(arguments, location)?; + let predicate = get_bool(predicate)?; + let message = get_str(interner, message)?; + + if predicate { + Ok(Value::Unit) + } else { + failing_constraint( + format!("static_assert failed: {}", message).clone(), + location, + call_stack, + ) + } +} + fn str_as_bytes( interner: &NodeInterner, arguments: Vec<(Value, Location)>, @@ -370,7 +393,7 @@ fn struct_def_add_attribute( }; let struct_id = get_struct(self_argument)?; - interner.update_struct_attributes(struct_id, |attributes| { + interner.update_type_attributes(struct_id, |attributes| { attributes.push(attribute); }); @@ -403,7 +426,7 @@ fn struct_def_add_generic( }; let struct_id = get_struct(self_argument)?; - let the_struct = interner.get_struct(struct_id); + let the_struct = interner.get_type(struct_id); let mut the_struct = the_struct.borrow_mut(); let name = Rc::new(generic_name); @@ -436,7 +459,7 @@ fn struct_def_as_type( ) -> IResult { let argument = check_one_argument(arguments, location)?; let struct_id = get_struct(argument)?; - let struct_def_rc = interner.get_struct(struct_id); + let struct_def_rc = interner.get_type(struct_id); let struct_def = struct_def_rc.borrow(); let generics = vecmap(&struct_def.generics, |generic| { @@ -444,7 +467,7 @@ fn struct_def_as_type( }); drop(struct_def); - Ok(Value::Type(Type::Struct(struct_def_rc, generics))) + Ok(Value::Type(Type::DataType(struct_def_rc, generics))) } /// fn generics(self) -> [(Type, Option)] @@ -456,7 +479,7 @@ fn struct_def_generics( ) -> IResult { let argument = check_one_argument(arguments, location)?; let struct_id = get_struct(argument)?; - let struct_def = interner.get_struct(struct_id); + let struct_def = interner.get_type(struct_id); let struct_def = struct_def.borrow(); let expected = Type::Slice(Box::new(Type::Tuple(vec![ @@ -512,7 +535,7 @@ fn struct_def_has_named_attribute( let name = get_str(interner, name)?; - Ok(Value::Bool(has_named_attribute(&name, interner.struct_attributes(&struct_id)))) + Ok(Value::Bool(has_named_attribute(&name, interner.type_attributes(&struct_id)))) } /// fn fields(self, generic_args: [Type]) -> [(Quoted, Type)] @@ -526,7 +549,7 @@ fn struct_def_fields( ) -> IResult { let (typ, generic_args) = check_two_arguments(arguments, location)?; let struct_id = get_struct(typ)?; - let struct_def = interner.get_struct(struct_id); + let struct_def = interner.get_type(struct_id); let struct_def = struct_def.borrow(); let args_location = generic_args.1; @@ -546,9 +569,11 @@ fn struct_def_fields( let mut fields = im::Vector::new(); - for (field_name, field_type) in struct_def.get_fields(&generic_args) { - let name = Value::Quoted(Rc::new(vec![Token::Ident(field_name)])); - fields.push_back(Value::Tuple(vec![name, Value::Type(field_type)])); + if let Some(struct_fields) = struct_def.get_fields(&generic_args) { + for (field_name, field_type) in struct_fields { + let name = Value::Quoted(Rc::new(vec![Token::Ident(field_name)])); + fields.push_back(Value::Tuple(vec![name, Value::Type(field_type)])); + } } let typ = Type::Slice(Box::new(Type::Tuple(vec![ @@ -569,15 +594,17 @@ fn struct_def_fields_as_written( ) -> IResult { let argument = check_one_argument(arguments, location)?; let struct_id = get_struct(argument)?; - let struct_def = interner.get_struct(struct_id); + let struct_def = interner.get_type(struct_id); let struct_def = struct_def.borrow(); let mut fields = im::Vector::new(); - for field in struct_def.get_fields_as_written() { - let name = Value::Quoted(Rc::new(vec![Token::Ident(field.name.to_string())])); - let typ = Value::Type(field.typ); - fields.push_back(Value::Tuple(vec![name, typ])); + if let Some(struct_fields) = struct_def.get_fields_as_written() { + for field in struct_fields { + let name = Value::Quoted(Rc::new(vec![Token::Ident(field.name.to_string())])); + let typ = Value::Type(field.typ); + fields.push_back(Value::Tuple(vec![name, typ])); + } } let typ = Type::Slice(Box::new(Type::Tuple(vec![ @@ -607,7 +634,7 @@ fn struct_def_name( ) -> IResult { let self_argument = check_one_argument(arguments, location)?; let struct_id = get_struct(self_argument)?; - let the_struct = interner.get_struct(struct_id); + let the_struct = interner.get_type(struct_id); let name = Token::Ident(the_struct.borrow().name.to_string()); Ok(Value::Quoted(Rc::new(vec![name]))) @@ -623,7 +650,7 @@ fn struct_def_set_fields( let (the_struct, fields) = check_two_arguments(arguments, location)?; let struct_id = get_struct(the_struct)?; - let struct_def = interner.get_struct(struct_id); + let struct_def = interner.get_type(struct_id); let mut struct_def = struct_def.borrow_mut(); let field_location = fields.1; @@ -1057,7 +1084,7 @@ fn type_as_struct( location: Location, ) -> IResult { type_as(arguments, return_type, location, |typ| { - if let Type::Struct(struct_type, generics) = typ { + if let Type::DataType(struct_type, generics) = typ { Some(Value::Tuple(vec![ Value::StructDefinition(struct_type.borrow().id), Value::Slice( @@ -1432,8 +1459,9 @@ fn zeroed(return_type: Type, span: Span) -> IResult { } Type::Unit => Ok(Value::Unit), Type::Tuple(fields) => Ok(Value::Tuple(try_vecmap(fields, |field| zeroed(field, span))?)), - Type::Struct(struct_type, generics) => { - let fields = struct_type.borrow().get_fields(&generics); + Type::DataType(struct_type, generics) => { + // TODO: Handle enums + let fields = struct_type.borrow().get_fields(&generics).unwrap(); let mut values = HashMap::default(); for (field_name, field_type) in fields { @@ -1441,7 +1469,7 @@ fn zeroed(return_type: Type, span: Span) -> IResult { values.insert(Rc::new(field_name), field_value); } - let typ = Type::Struct(struct_type, generics); + let typ = Type::DataType(struct_type, generics); Ok(Value::Struct(values, typ)) } Type::Alias(alias, generics) => zeroed(alias.borrow().get_type(&generics), span), @@ -2890,7 +2918,7 @@ pub(crate) fn option(option_type: Type, value: Option, span: Span) -> IRe /// Given a type, assert that it's an Option and return the Type for T pub(crate) fn extract_option_generic_type(typ: Type) -> Type { - let Type::Struct(struct_type, mut generics) = typ else { + let Type::DataType(struct_type, mut generics) = typ else { panic!("Expected type to be a struct"); }; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index a3f84a00bfb..342f494023d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -28,11 +28,11 @@ use crate::{ function::{FuncMeta, FunctionBody}, stmt::HirPattern, }, - node_interner::{FuncId, NodeInterner, StructId, TraitId, TraitImplId}, + node_interner::{FuncId, NodeInterner, TraitId, TraitImplId, TypeId}, token::{SecondaryAttribute, Token, Tokens}, QuotedType, Type, }; -use crate::{Kind, Shared, StructType}; +use crate::{DataType, Kind, Shared}; use rustc_hash::FxHashMap as HashMap; pub(crate) fn check_argument_count( @@ -108,14 +108,13 @@ pub(crate) fn get_struct_fields( match value { Value::Struct(fields, typ) => Ok((fields, typ)), _ => { - let expected = StructType::new( - StructId::dummy_id(), + let expected = DataType::new( + TypeId::dummy_id(), Ident::new(name.to_string(), location.span), location, Vec::new(), - Vec::new(), ); - let expected = Type::Struct(Shared::new(expected), Vec::new()); + let expected = Type::DataType(Shared::new(expected), Vec::new()); type_mismatch(value, expected, location) } } @@ -327,7 +326,7 @@ pub(crate) fn get_module((value, location): (Value, Location)) -> IResult IResult { +pub(crate) fn get_struct((value, location): (Value, Location)) -> IResult { match value { Value::StructDefinition(id) => Ok(id), _ => type_mismatch(value, Type::Quoted(QuotedType::StructDefinition), location), @@ -434,7 +433,7 @@ fn gather_hir_pattern_tokens( tokens.push(Token::RightParen); } HirPattern::Struct(typ, fields, _) => { - let Type::Struct(struct_type, _) = typ.follow_bindings() else { + let Type::DataType(struct_type, _) = typ.follow_bindings() else { panic!("Expected type to be a struct"); }; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs index 77933ba9361..c5ec7d861cd 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -18,7 +18,7 @@ use crate::{ HirArrayLiteral, HirConstructorExpression, HirExpression, HirIdent, HirLambda, HirLiteral, ImplKind, }, - node_interner::{ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId, TraitImplId}, + node_interner::{ExprId, FuncId, NodeInterner, StmtId, TraitId, TraitImplId, TypeId}, parser::{Item, Parser}, token::{SpannedToken, Token, Tokens}, Kind, QuotedType, Shared, Type, TypeBindings, @@ -62,7 +62,7 @@ pub enum Value { /// tokens can cause larger spans to be before lesser spans, causing an assert. They may also /// be inserted into separate files entirely. Quoted(Rc>), - StructDefinition(StructId), + StructDefinition(TypeId), TraitConstraint(TraitId, TraitGenerics), TraitDefinition(TraitId), TraitImpl(TraitImplId), @@ -234,7 +234,7 @@ impl Value { })?; let struct_type = match typ.follow_bindings() { - Type::Struct(def, _) => Some(def.borrow().id), + Type::DataType(def, _) => Some(def.borrow().id), _ => return Err(InterpreterError::NonStructInConstructor { typ, location }), }; @@ -388,7 +388,7 @@ impl Value { })?; let (r#type, struct_generics) = match typ.follow_bindings() { - Type::Struct(def, generics) => (def, generics), + Type::DataType(def, generics) => (def, generics), _ => return Err(InterpreterError::NonStructInConstructor { typ, location }), }; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 7f6509b9f16..9aad806bb3c 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -14,10 +14,10 @@ use crate::{Generics, Type}; use crate::hir::resolution::import::{resolve_import, ImportDirective}; use crate::hir::Context; -use crate::ast::Expression; +use crate::ast::{Expression, NoirEnumeration}; use crate::node_interner::{ - FuncId, GlobalId, ModuleAttributes, NodeInterner, ReferenceId, StructId, TraitId, TraitImplId, - TypeAliasId, + FuncId, GlobalId, ModuleAttributes, NodeInterner, ReferenceId, TraitId, TraitImplId, + TypeAliasId, TypeId, }; use crate::ast::{ @@ -64,6 +64,12 @@ pub struct UnresolvedStruct { pub struct_def: NoirStruct, } +pub struct UnresolvedEnum { + pub file_id: FileId, + pub module_id: LocalModuleId, + pub enum_def: NoirEnumeration, +} + #[derive(Clone)] pub struct UnresolvedTrait { pub file_id: FileId, @@ -141,7 +147,8 @@ pub struct DefCollector { #[derive(Default)] pub struct CollectedItems { pub functions: Vec, - pub(crate) types: BTreeMap, + pub(crate) structs: BTreeMap, + pub(crate) enums: BTreeMap, pub(crate) type_aliases: BTreeMap, pub(crate) traits: BTreeMap, pub globals: Vec, @@ -153,7 +160,7 @@ pub struct CollectedItems { impl CollectedItems { pub fn is_empty(&self) -> bool { self.functions.is_empty() - && self.types.is_empty() + && self.structs.is_empty() && self.type_aliases.is_empty() && self.traits.is_empty() && self.globals.is_empty() @@ -254,7 +261,8 @@ impl DefCollector { imports: vec![], items: CollectedItems { functions: vec![], - types: BTreeMap::new(), + structs: BTreeMap::new(), + enums: BTreeMap::new(), type_aliases: BTreeMap::new(), traits: BTreeMap::new(), impls: HashMap::default(), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index ead6a801ba7..f6f31638557 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -12,11 +12,12 @@ use rustc_hash::FxHashMap as HashMap; use crate::ast::{ Documented, Expression, FunctionDefinition, Ident, ItemVisibility, LetStatement, - ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Pattern, - TraitImplItemKind, TraitItem, TypeImpl, UnresolvedType, UnresolvedTypeData, + ModuleDeclaration, NoirEnumeration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, + NoirTypeAlias, Pattern, TraitImplItemKind, TraitItem, TypeImpl, UnresolvedType, + UnresolvedTypeData, }; use crate::hir::resolution::errors::ResolverError; -use crate::node_interner::{ModuleAttributes, NodeInterner, ReferenceId, StructId}; +use crate::node_interner::{ModuleAttributes, NodeInterner, ReferenceId, TypeId}; use crate::token::SecondaryAttribute; use crate::usage_tracker::{UnusedItem, UsageTracker}; use crate::{ @@ -27,8 +28,8 @@ use crate::{ }; use crate::{Generics, Kind, ResolvedGeneric, Type, TypeVariable}; -use super::dc_crate::CollectedItems; use super::dc_crate::ModuleAttribute; +use super::dc_crate::{CollectedItems, UnresolvedEnum}; use super::{ dc_crate::{ CompilationError, DefCollector, UnresolvedFunctions, UnresolvedGlobal, UnresolvedTraitImpl, @@ -91,7 +92,9 @@ pub fn collect_defs( errors.extend(collector.collect_traits(context, ast.traits, crate_id)); - errors.extend(collector.collect_structs(context, ast.types, crate_id)); + errors.extend(collector.collect_structs(context, ast.structs, crate_id)); + + errors.extend(collector.collect_enums(context, ast.enums, crate_id)); errors.extend(collector.collect_type_aliases(context, ast.type_aliases, crate_id)); @@ -317,7 +320,34 @@ impl<'a> ModCollector<'a> { krate, &mut definition_errors, ) { - self.def_collector.items.types.insert(id, the_struct); + self.def_collector.items.structs.insert(id, the_struct); + } + } + definition_errors + } + + /// Collect any enum definitions declared within the ast. + /// Returns a vector of errors if any enums were already defined, + /// or if an enum has duplicate variants in it. + fn collect_enums( + &mut self, + context: &mut Context, + types: Vec>, + krate: CrateId, + ) -> Vec<(CompilationError, FileId)> { + let mut definition_errors = vec![]; + for enum_definition in types { + if let Some((id, the_enum)) = collect_enum( + &mut context.def_interner, + &mut self.def_collector.def_map, + &mut context.usage_tracker, + enum_definition, + self.file_id, + self.module_id, + krate, + &mut definition_errors, + ) { + self.def_collector.items.enums.insert(id, the_enum); } } definition_errors @@ -336,6 +366,7 @@ impl<'a> ModCollector<'a> { let doc_comments = type_alias.doc_comments; let type_alias = type_alias.item; let name = type_alias.name.clone(); + let location = Location::new(name.span(), self.file_id); let visibility = type_alias.visibility; // And store the TypeId -> TypeAlias mapping somewhere it is reachable @@ -389,6 +420,7 @@ impl<'a> ModCollector<'a> { context.def_interner.register_type_alias( type_alias_id, name, + location, visibility, parent_module_id, ); @@ -410,6 +442,7 @@ impl<'a> ModCollector<'a> { let doc_comments = trait_definition.doc_comments; let trait_definition = trait_definition.item; let name = trait_definition.name.clone(); + let location = Location::new(trait_definition.name.span(), self.file_id); // Create the corresponding module for the trait namespace let trait_id = match self.push_child_module( @@ -503,7 +536,10 @@ impl<'a> ModCollector<'a> { .push_function_definition(func_id, modifiers, trait_id.0, location); let referenced = ReferenceId::Function(func_id); - context.def_interner.add_definition_location(referenced, Some(trait_id.0)); + let module_id = Some(trait_id.0); + context + .def_interner + .add_definition_location(referenced, location, module_id); if !trait_item.doc_comments.is_empty() { context.def_interner.set_doc_comments( @@ -633,6 +669,7 @@ impl<'a> ModCollector<'a> { context.def_interner.register_trait( trait_id, name.to_string(), + location, visibility, parent_module_id, ); @@ -818,7 +855,7 @@ impl<'a> ModCollector<'a> { inner_attributes: Vec, add_to_parent_scope: bool, is_contract: bool, - is_struct: bool, + is_type: bool, ) -> Result { push_child_module( &mut context.def_interner, @@ -831,7 +868,7 @@ impl<'a> ModCollector<'a> { inner_attributes, add_to_parent_scope, is_contract, - is_struct, + is_type, ) } @@ -867,7 +904,7 @@ fn push_child_module( inner_attributes: Vec, add_to_parent_scope: bool, is_contract: bool, - is_struct: bool, + is_type: bool, ) -> Result { // Note: the difference between `location` and `mod_location` is: // - `mod_location` will point to either the token "foo" in `mod foo { ... }` @@ -883,7 +920,7 @@ fn push_child_module( outer_attributes, inner_attributes, is_contract, - is_struct, + is_type, ); let module_id = def_map.modules.insert(new_module); @@ -997,7 +1034,7 @@ pub fn collect_struct( module_id: LocalModuleId, krate: CrateId, definition_errors: &mut Vec<(CompilationError, FileId)>, -) -> Option<(StructId, UnresolvedStruct)> { +) -> Option<(TypeId, UnresolvedStruct)> { let doc_comments = struct_definition.doc_comments; let struct_definition = struct_definition.item; @@ -1030,7 +1067,11 @@ pub fn collect_struct( true, // is struct ) { Ok(module_id) => { - interner.new_struct(&unresolved, resolved_generics, krate, module_id.local_id, file_id) + let name = unresolved.struct_def.name.clone(); + let span = unresolved.struct_def.span; + let attributes = unresolved.struct_def.attributes.clone(); + let local_id = module_id.local_id; + interner.new_type(name, span, attributes, resolved_generics, krate, local_id, file_id) } Err(error) => { definition_errors.push((error.into(), file_id)); @@ -1038,7 +1079,7 @@ pub fn collect_struct( } }; - interner.set_doc_comments(ReferenceId::Struct(id), doc_comments); + interner.set_doc_comments(ReferenceId::Type(id), doc_comments); for (index, field) in unresolved.struct_def.fields.iter().enumerate() { if !field.doc_comments.is_empty() { @@ -1049,7 +1090,7 @@ pub fn collect_struct( // Add the struct to scope so its path can be looked up later let visibility = unresolved.struct_def.visibility; - let result = def_map.modules[module_id.0].declare_struct(name.clone(), visibility, id); + let result = def_map.modules[module_id.0].declare_type(name.clone(), visibility, id); let parent_module_id = ModuleId { krate, local_id: module_id }; @@ -1072,7 +1113,102 @@ pub fn collect_struct( } if interner.is_in_lsp_mode() { - interner.register_struct(id, name.to_string(), visibility, parent_module_id); + interner.register_type(id, name.to_string(), location, visibility, parent_module_id); + } + + Some((id, unresolved)) +} + +#[allow(clippy::too_many_arguments)] +pub fn collect_enum( + interner: &mut NodeInterner, + def_map: &mut CrateDefMap, + usage_tracker: &mut UsageTracker, + enum_def: Documented, + file_id: FileId, + module_id: LocalModuleId, + krate: CrateId, + definition_errors: &mut Vec<(CompilationError, FileId)>, +) -> Option<(TypeId, UnresolvedEnum)> { + let doc_comments = enum_def.doc_comments; + let enum_def = enum_def.item; + + check_duplicate_variant_names(&enum_def, file_id, definition_errors); + + let name = enum_def.name.clone(); + + let unresolved = UnresolvedEnum { file_id, module_id, enum_def }; + + let resolved_generics = Context::resolve_generics( + interner, + &unresolved.enum_def.generics, + definition_errors, + file_id, + ); + + // Create the corresponding module for the enum namespace + let location = Location::new(name.span(), file_id); + let id = match push_child_module( + interner, + def_map, + module_id, + &name, + ItemVisibility::Public, + location, + Vec::new(), + Vec::new(), + false, // add to parent scope + false, // is contract + true, // is type + ) { + Ok(module_id) => { + let name = unresolved.enum_def.name.clone(); + let span = unresolved.enum_def.span; + let attributes = unresolved.enum_def.attributes.clone(); + let local_id = module_id.local_id; + interner.new_type(name, span, attributes, resolved_generics, krate, local_id, file_id) + } + Err(error) => { + definition_errors.push((error.into(), file_id)); + return None; + } + }; + + interner.set_doc_comments(ReferenceId::Type(id), doc_comments); + + for (index, variant) in unresolved.enum_def.variants.iter().enumerate() { + if !variant.doc_comments.is_empty() { + let id = ReferenceId::EnumVariant(id, index); + interner.set_doc_comments(id, variant.doc_comments.clone()); + } + } + + // Add the enum to scope so its path can be looked up later + let visibility = unresolved.enum_def.visibility; + let result = def_map.modules[module_id.0].declare_type(name.clone(), visibility, id); + + let parent_module_id = ModuleId { krate, local_id: module_id }; + + if !unresolved.enum_def.is_abi() { + usage_tracker.add_unused_item( + parent_module_id, + name.clone(), + UnusedItem::Enum(id), + visibility, + ); + } + + if let Err((first_def, second_def)) = result { + let error = DefCollectorErrorKind::Duplicate { + typ: DuplicateType::TypeDefinition, + first_def, + second_def, + }; + definition_errors.push((error.into(), file_id)); + } + + if interner.is_in_lsp_mode() { + interner.register_type(id, name.to_string(), location, visibility, parent_module_id); } Some((id, unresolved)) @@ -1318,14 +1454,35 @@ fn check_duplicate_field_names( } let previous_field_name = *seen_field_names.get(field_name).unwrap(); - definition_errors.push(( - DefCollectorErrorKind::DuplicateField { - first_def: previous_field_name.clone(), - second_def: field_name.clone(), - } - .into(), - file, - )); + let error = DefCollectorErrorKind::Duplicate { + typ: DuplicateType::StructField, + first_def: previous_field_name.clone(), + second_def: field_name.clone(), + }; + definition_errors.push((error.into(), file)); + } +} + +fn check_duplicate_variant_names( + enum_def: &NoirEnumeration, + file: FileId, + definition_errors: &mut Vec<(CompilationError, FileId)>, +) { + let mut seen_variant_names = std::collections::HashSet::new(); + for variant in &enum_def.variants { + let variant_name = &variant.item.name; + + if seen_variant_names.insert(variant_name) { + continue; + } + + let previous_variant_name = *seen_variant_names.get(variant_name).unwrap(); + let error = DefCollectorErrorKind::Duplicate { + typ: DuplicateType::EnumVariant, + first_def: previous_variant_name.clone(), + second_def: variant_name.clone(), + }; + definition_errors.push((error.into(), file)); } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/errors.rs index 1582e297144..1ca62acd29b 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -21,21 +21,21 @@ pub enum DuplicateType { TraitAssociatedType, TraitAssociatedConst, TraitAssociatedFunction, + StructField, + EnumVariant, } #[derive(Error, Debug, Clone)] pub enum DefCollectorErrorKind { - #[error("duplicate {typ} found in namespace")] + #[error("Duplicate {typ}")] Duplicate { typ: DuplicateType, first_def: Ident, second_def: Ident }, - #[error("duplicate struct field {first_def}")] - DuplicateField { first_def: Ident, second_def: Ident }, - #[error("unresolved import")] + #[error("Unresolved import")] UnresolvedModuleDecl { mod_name: Ident, expected_path: String, alternative_path: String }, - #[error("overlapping imports")] + #[error("Overlapping imports")] OverlappingModuleDecls { mod_name: Ident, expected_path: String, alternative_path: String }, - #[error("path resolution error")] + #[error("Path resolution error")] PathResolutionError(PathResolutionError), - #[error("cannot re-export {item_name} because it has less visibility than this use statement")] + #[error("Cannot re-export {item_name} because it has less visibility than this use statement")] CannotReexportItemWithLessVisibility { item_name: Ident, desired_visibility: ItemVisibility }, #[error("Non-struct type used in impl")] NonStructTypeInImpl { span: Span }, @@ -120,6 +120,8 @@ impl fmt::Display for DuplicateType { DuplicateType::TraitAssociatedType => write!(f, "trait associated type"), DuplicateType::TraitAssociatedConst => write!(f, "trait associated constant"), DuplicateType::TraitAssociatedFunction => write!(f, "trait associated function"), + DuplicateType::StructField => write!(f, "struct field"), + DuplicateType::EnumVariant => write!(f, "enum variant"), } } } @@ -144,23 +146,6 @@ impl<'a> From<&'a DefCollectorErrorKind> for Diagnostic { diag } } - DefCollectorErrorKind::DuplicateField { first_def, second_def } => { - let primary_message = format!( - "Duplicate definitions of struct field with name {} found", - &first_def.0.contents - ); - { - let first_span = first_def.0.span(); - let second_span = second_def.0.span(); - let mut diag = Diagnostic::simple_error( - primary_message, - "First definition found here".to_string(), - first_span, - ); - diag.add_secondary("Second definition found here".to_string(), second_span); - diag - } - } DefCollectorErrorKind::UnresolvedModuleDecl { mod_name, expected_path, alternative_path } => { let span = mod_name.0.span(); let mod_name = &mod_name.0.contents; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs index f7fc6ca08ea..fae891a1647 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs @@ -1,7 +1,7 @@ use crate::graph::{CrateGraph, CrateId}; use crate::hir::def_collector::dc_crate::{CompilationError, DefCollector}; use crate::hir::Context; -use crate::node_interner::{FuncId, GlobalId, NodeInterner, StructId}; +use crate::node_interner::{FuncId, GlobalId, NodeInterner, TypeId}; use crate::parse_program; use crate::parser::{ParsedModule, ParserError}; use crate::token::{FunctionAttribute, SecondaryAttribute, TestScope}; @@ -241,7 +241,7 @@ impl CrateDefMap { module.type_definitions().for_each(|id| { if let ModuleDefId::TypeId(struct_id) = id { - interner.struct_attributes(&struct_id).iter().for_each(|attr| { + interner.type_attributes(&struct_id).iter().for_each(|attr| { if let SecondaryAttribute::Abi(tag) = attr { if let Some(tagged) = outputs.structs.get_mut(tag) { tagged.push(struct_id); @@ -356,7 +356,7 @@ pub struct ContractFunctionMeta { } pub struct ContractOutputs { - pub structs: HashMap>, + pub structs: HashMap>, pub globals: HashMap>, } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_data.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_data.rs index 06188f3920b..5df0e08cbb2 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_data.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_data.rs @@ -4,7 +4,7 @@ use noirc_errors::Location; use super::{ItemScope, LocalModuleId, ModuleDefId, ModuleId, PerNs}; use crate::ast::{Ident, ItemVisibility}; -use crate::node_interner::{FuncId, GlobalId, StructId, TraitId, TypeAliasId}; +use crate::node_interner::{FuncId, GlobalId, TraitId, TypeAliasId, TypeId}; use crate::token::SecondaryAttribute; /// Contains the actual contents of a module: its parent (if one exists), @@ -31,8 +31,8 @@ pub struct ModuleData { /// True if this module is a `contract Foo { ... }` module containing contract functions pub is_contract: bool, - /// True if this module is actually a struct - pub is_struct: bool, + /// True if this module is actually a type + pub is_type: bool, pub attributes: Vec, } @@ -44,7 +44,7 @@ impl ModuleData { outer_attributes: Vec, inner_attributes: Vec, is_contract: bool, - is_struct: bool, + is_type: bool, ) -> ModuleData { let mut attributes = outer_attributes; attributes.extend(inner_attributes); @@ -57,7 +57,7 @@ impl ModuleData { definitions: ItemScope::default(), location, is_contract, - is_struct, + is_type, attributes, } } @@ -120,11 +120,11 @@ impl ModuleData { self.declare(name, visibility, id.into(), None) } - pub fn declare_struct( + pub fn declare_type( &mut self, name: Ident, visibility: ItemVisibility, - id: StructId, + id: TypeId, ) -> Result<(), (Ident, Ident)> { self.declare(name, visibility, ModuleDefId::TypeId(id), None) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_def.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_def.rs index a751eacd2dd..40d57ae2e23 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_def.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/module_def.rs @@ -1,4 +1,4 @@ -use crate::node_interner::{FuncId, GlobalId, StructId, TraitId, TypeAliasId}; +use crate::node_interner::{FuncId, GlobalId, TraitId, TypeAliasId, TypeId}; use super::ModuleId; @@ -7,7 +7,7 @@ use super::ModuleId; pub enum ModuleDefId { ModuleId(ModuleId), FunctionId(FuncId), - TypeId(StructId), + TypeId(TypeId), TypeAliasId(TypeAliasId), TraitId(TraitId), GlobalId(GlobalId), @@ -21,7 +21,7 @@ impl ModuleDefId { } } - pub fn as_type(&self) -> Option { + pub fn as_type(&self) -> Option { match self { ModuleDefId::TypeId(type_id) => Some(*type_id), _ => None, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/mod.rs index b231f8c9698..fea52be88bc 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/mod.rs @@ -9,7 +9,7 @@ use crate::ast::UnresolvedGenerics; use crate::debug::DebugInstrumenter; use crate::graph::{CrateGraph, CrateId}; use crate::hir_def::function::FuncMeta; -use crate::node_interner::{FuncId, NodeInterner, StructId}; +use crate::node_interner::{FuncId, NodeInterner, TypeId}; use crate::parser::ParserError; use crate::usage_tracker::UsageTracker; use crate::{Generics, Kind, ParsedModule, ResolvedGeneric, TypeVariable}; @@ -151,7 +151,7 @@ impl Context<'_, '_> { /// /// For example, if you project contains a `main.nr` and `foo.nr` and you provide the `main_crate_id` and the /// `bar_struct_id` where the `Bar` struct is inside `foo.nr`, this function would return `foo::Bar` as a [String]. - pub fn fully_qualified_struct_path(&self, crate_id: &CrateId, id: StructId) -> String { + pub fn fully_qualified_struct_path(&self, crate_id: &CrateId, id: TypeId) -> String { fully_qualified_module_path(&self.def_maps, &self.crate_graph, crate_id, id.module_id()) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs index e0e09d53311..6298ef796b4 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -98,6 +98,10 @@ pub enum ResolverError { DependencyCycle { span: Span, item: String, cycle: String }, #[error("break/continue are only allowed in unconstrained functions")] JumpInConstrainedFn { is_break: bool, span: Span }, + #[error("`loop` is only allowed in unconstrained functions")] + LoopInConstrainedFn { span: Span }, + #[error("`loop` must have at least one `break` in it")] + LoopWithoutBreak { span: Span }, #[error("break/continue are only allowed within loops")] JumpOutsideLoop { is_break: bool, span: Span }, #[error("Only `comptime` globals can be mutable")] @@ -112,8 +116,8 @@ pub enum ResolverError { NonIntegralGlobalType { span: Span, global_value: Value }, #[error("Global value `{global_value}` is larger than its kind's maximum value")] GlobalLargerThanKind { span: Span, global_value: FieldElement, kind: Kind }, - #[error("Self-referential structs are not supported")] - SelfReferentialStruct { span: Span }, + #[error("Self-referential types are not supported")] + SelfReferentialType { span: Span }, #[error("#[no_predicates] attribute is only allowed on constrained functions")] NoPredicatesAttributeOnUnconstrained { ident: Ident }, #[error("#[fold] attribute is only allowed on constrained functions")] @@ -434,6 +438,20 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) }, + ResolverError::LoopInConstrainedFn { span } => { + Diagnostic::simple_error( + "loop is only allowed in unconstrained functions".into(), + "Constrained code must always have a known number of loop iterations".into(), + *span, + ) + }, + ResolverError::LoopWithoutBreak { span } => { + Diagnostic::simple_error( + "`loop` must have at least one `break` in it".into(), + "Infinite loops are disallowed".into(), + *span, + ) + }, ResolverError::JumpOutsideLoop { is_break, span } => { let item = if *is_break { "break" } else { "continue" }; Diagnostic::simple_error( @@ -484,9 +502,9 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) } - ResolverError::SelfReferentialStruct { span } => { + ResolverError::SelfReferentialType { span } => { Diagnostic::simple_error( - "Self-referential structs are not supported".into(), + "Self-referential types are not supported".into(), "".into(), *span, ) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/visibility.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/visibility.rs index 557f799df89..c592175ffcb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/visibility.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/visibility.rs @@ -1,5 +1,5 @@ use crate::graph::CrateId; -use crate::node_interner::{FuncId, NodeInterner, StructId, TraitId}; +use crate::node_interner::{FuncId, NodeInterner, TraitId, TypeId}; use crate::Type; use std::collections::BTreeMap; @@ -71,11 +71,11 @@ fn module_is_parent_of_struct_module( target: LocalModuleId, ) -> bool { let module_data = &def_map.modules[target.0]; - module_data.is_struct && module_data.parent == Some(current) + module_data.is_type && module_data.parent == Some(current) } pub fn struct_member_is_visible( - struct_id: StructId, + struct_id: TypeId, visibility: ItemVisibility, current_module_id: ModuleId, def_maps: &BTreeMap, @@ -158,7 +158,7 @@ pub fn method_call_is_visible( ); } - if let Some(struct_id) = func_meta.struct_id { + if let Some(struct_id) = func_meta.type_id { return struct_member_is_visible( struct_id, modifiers.visibility, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/generics.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/generics.rs index 370223f1f11..f823b495040 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/generics.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/generics.rs @@ -5,7 +5,7 @@ use iter_extended::vecmap; use crate::{ hir_def::traits::NamedType, node_interner::{FuncId, NodeInterner, TraitId, TypeAliasId}, - ResolvedGeneric, StructType, Type, + DataType, ResolvedGeneric, Type, }; /// Represents something that can be generic over type variables @@ -74,7 +74,7 @@ impl Generic for TypeAliasId { } } -impl Generic for Ref<'_, StructType> { +impl Generic for Ref<'_, DataType> { fn item_kind(&self) -> &'static str { "struct" } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs index 9b3bf4962bb..647969471ab 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs @@ -12,7 +12,7 @@ use crate::Shared; use super::stmt::HirPattern; use super::traits::{ResolvedTraitBound, TraitConstraint}; -use super::types::{StructType, Type}; +use super::types::{DataType, Type}; /// A HirExpression is the result of an Expression in the AST undergoing /// name resolution. It is almost identical to the Expression AST node, but @@ -30,6 +30,7 @@ pub enum HirExpression { Infix(HirInfixExpression), Index(HirIndexExpression), Constructor(HirConstructorExpression), + EnumConstructor(HirEnumConstructorExpression), MemberAccess(HirMemberAccess), Call(HirCallExpression), MethodCall(HirMethodCallExpression), @@ -225,24 +226,15 @@ impl HirMethodReference { } } } -} -impl HirMethodCallExpression { - /// Converts a method call into a function call - /// - /// Returns ((func_var_id, func_var), call_expr) - pub fn into_function_call( - mut self, - method: HirMethodReference, + pub fn into_function_id_and_name( + self, object_type: Type, - is_macro_call: bool, + generics: Option>, location: Location, interner: &mut NodeInterner, - ) -> ((ExprId, HirIdent), HirCallExpression) { - let mut arguments = vec![self.object]; - arguments.append(&mut self.arguments); - - let (id, impl_kind) = match method { + ) -> (ExprId, HirIdent) { + let (id, impl_kind) = match self { HirMethodReference::FuncId(func_id) => { (interner.function_definition_id(func_id), ImplKind::NotATraitMethod) } @@ -261,16 +253,28 @@ impl HirMethodCallExpression { } }; let func_var = HirIdent { location, id, impl_kind }; - let func = interner.push_expr(HirExpression::Ident(func_var.clone(), self.generics)); + let func = interner.push_expr(HirExpression::Ident(func_var.clone(), generics)); interner.push_expr_location(func, location.span, location.file); - let expr = HirCallExpression { func, arguments, location, is_macro_call }; - ((func, func_var), expr) + (func, func_var) + } +} + +impl HirMethodCallExpression { + pub fn into_function_call( + mut self, + func: ExprId, + is_macro_call: bool, + location: Location, + ) -> HirCallExpression { + let mut arguments = vec![self.object]; + arguments.append(&mut self.arguments); + HirCallExpression { func, arguments, location, is_macro_call } } } #[derive(Debug, Clone)] pub struct HirConstructorExpression { - pub r#type: Shared, + pub r#type: Shared, pub struct_generics: Vec, // NOTE: It is tempting to make this a BTreeSet to force ordering of field @@ -281,6 +285,28 @@ pub struct HirConstructorExpression { pub fields: Vec<(Ident, ExprId)>, } +/// An enum constructor is an expression such as `Option::Some(foo)` +/// to construct an enum. These are usually inserted by the compiler itself +/// since `Some` is actually a function with the body implicitly being an +/// enum constructor expression, but in the future these may be directly +/// represented when using enums with named fields. +/// +/// During monomorphization, these expressions are translated to tuples of +/// (tag, variant0_fields, variant1_fields, ..) since we cannot actually +/// make a true union in a circuit. +#[derive(Debug, Clone)] +pub struct HirEnumConstructorExpression { + pub r#type: Shared, + pub enum_generics: Vec, + pub variant_index: usize, + + /// This refers to just the arguments that are passed. E.g. just + /// `foo` in `Foo::Bar(foo)`, even if other variants have their + /// "fields" defaulted to `std::mem::zeroed`, these aren't specified + /// at this step. + pub arguments: Vec, +} + /// Indexing, as in `array[index]` #[derive(Debug, Clone)] pub struct HirIndexExpression { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs index aa04738733f..75bb4f50541 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs @@ -8,7 +8,7 @@ use super::traits::TraitConstraint; use crate::ast::{BlockExpression, FunctionKind, FunctionReturnType, Visibility}; use crate::graph::CrateId; use crate::hir::def_map::LocalModuleId; -use crate::node_interner::{ExprId, NodeInterner, StructId, TraitId, TraitImplId}; +use crate::node_interner::{ExprId, NodeInterner, TraitId, TraitImplId, TypeId}; use crate::{ResolvedGeneric, Type}; @@ -132,8 +132,8 @@ pub struct FuncMeta { pub trait_constraints: Vec, - /// The struct this function belongs to, if any - pub struct_id: Option, + /// The type this method belongs to, if any + pub type_id: Option, // The trait this function belongs to, if any pub trait_id: Option, @@ -141,6 +141,9 @@ pub struct FuncMeta { /// The trait impl this function belongs to, if any pub trait_impl: Option, + /// If this function is the one related to an enum variant, this holds its index (relative to `type_id`) + pub enum_variant_index: Option, + /// True if this function is an entry point to the program. /// For non-contracts, this means the function is `main`. pub is_entry_point: bool, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs index c42b8230290..8a580e735b1 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -16,6 +16,7 @@ pub enum HirStatement { Constrain(HirConstrainStatement), Assign(HirAssignStatement), For(HirForStatement), + Loop(ExprId), Break, Continue, Expression(ExprId), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/traits.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/traits.rs index ff0cac027b1..a80c25492a3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/traits.rs @@ -186,22 +186,22 @@ impl Trait { (ordered, named) } - /// Returns a TraitConstraint for this trait using Self as the object - /// type and the uninstantiated generics for any trait generics. - pub fn as_constraint(&self, span: Span) -> TraitConstraint { + pub fn get_trait_generics(&self, span: Span) -> TraitGenerics { let ordered = vecmap(&self.generics, |generic| generic.clone().as_named_generic()); let named = vecmap(&self.associated_types, |generic| { let name = Ident::new(generic.name.to_string(), span); NamedType { name, typ: generic.clone().as_named_generic() } }); + TraitGenerics { ordered, named } + } + /// Returns a TraitConstraint for this trait using Self as the object + /// type and the uninstantiated generics for any trait generics. + pub fn as_constraint(&self, span: Span) -> TraitConstraint { + let trait_generics = self.get_trait_generics(span); TraitConstraint { typ: Type::TypeVariable(self.self_type_typevar.clone()), - trait_bound: ResolvedTraitBound { - trait_generics: TraitGenerics { ordered, named }, - trait_id: self.id, - span, - }, + trait_bound: ResolvedTraitBound { trait_generics, trait_id: self.id, span }, } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs index c0dbf6f9500..1a9241b3b46 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs @@ -21,7 +21,7 @@ use noirc_printable_type::PrintableType; use crate::{ ast::{Ident, Signedness}, - node_interner::StructId, + node_interner::TypeId, }; use super::{ @@ -67,7 +67,7 @@ pub enum Type { /// A user-defined struct type. The `Shared` field here refers to /// the shared definition for each instance of this struct type. The `Vec` /// represents the generic arguments (if any) to this struct type. - Struct(Shared, Vec), + DataType(Shared, Vec), /// A user-defined alias to another type. Similar to a Struct, this carries a shared /// reference to the definition of the alias along with any generics that may have @@ -97,10 +97,7 @@ pub enum Type { /// A cast (to, from) that's checked at monomorphization. /// /// Simplifications on arithmetic generics are only allowed on the LHS. - CheckedCast { - from: Box, - to: Box, - }, + CheckedCast { from: Box, to: Box }, /// A functions with arguments, a return type and environment. /// the environment should be `Unit` by default, @@ -132,7 +129,13 @@ pub enum Type { /// The type of quoted code in macros. This is always a comptime-only type Quoted(QuotedType), - InfixExpr(Box, BinaryTypeOperator, Box), + /// An infix expression in the form `lhs * rhs`. + /// + /// The `inversion` bool keeps track of whether this expression came from + /// an expression like `4 = a / b` which was transformed to `a = 4 / b` + /// so that if at some point a infix expression `b * (4 / b)` is created, + /// it could be simplified back to `4`. + InfixExpr(Box, BinaryTypeOperator, Box, bool /* inversion */), /// The result of some type error. Remembering type errors as their own type variant lets /// us avoid issuing repeat type errors for the same item. For example, a lambda with @@ -312,6 +315,7 @@ pub enum QuotedType { Type, TypedExpr, StructDefinition, + EnumDefinition, TraitConstraint, TraitDefinition, TraitImpl, @@ -326,31 +330,52 @@ pub enum QuotedType { /// the binding to later be undone if needed. pub type TypeBindings = HashMap; -/// Represents a struct type in the type system. Each instance of this -/// rust struct will be shared across all Type::Struct variants that represent -/// the same struct type. -pub struct StructType { - /// A unique id representing this struct type. Used to check if two - /// struct types are equal. - pub id: StructId, +/// Represents a struct or enum type in the type system. Each instance of this +/// rust struct will be shared across all Type::DataType variants that represent +/// the same struct or enum type. +pub struct DataType { + /// A unique id representing this type. Used to check if two types are equal. + pub id: TypeId, pub name: Ident, - /// Fields are ordered and private, they should only - /// be accessed through get_field(), get_fields(), or instantiate() + /// A type's body is private to force struct fields or enum variants to only be + /// accessed through get_field(), get_fields(), instantiate(), or similar functions /// since these will handle applying generic arguments to fields as well. - fields: Vec, + body: TypeBody, pub generics: Generics, pub location: Location, } +enum TypeBody { + /// A type with no body is still in the process of being created + None, + Struct(Vec), + + #[allow(unused)] + Enum(Vec), +} + +#[derive(Clone)] pub struct StructField { pub visibility: ItemVisibility, pub name: Ident, pub typ: Type, } +#[derive(Clone)] +pub struct EnumVariant { + pub name: Ident, + pub params: Vec, +} + +impl EnumVariant { + pub fn new(name: Ident, params: Vec) -> EnumVariant { + Self { name, params } + } +} + /// Corresponds to generic lists such as `` in the source program. /// Used mainly for resolved types which no longer need information such /// as names or kinds @@ -384,42 +409,35 @@ enum FunctionCoercionResult { UnconstrainedMismatch(Type), } -impl std::hash::Hash for StructType { +impl std::hash::Hash for DataType { fn hash(&self, state: &mut H) { self.id.hash(state); } } -impl Eq for StructType {} +impl Eq for DataType {} -impl PartialEq for StructType { +impl PartialEq for DataType { fn eq(&self, other: &Self) -> bool { self.id == other.id } } -impl PartialOrd for StructType { +impl PartialOrd for DataType { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for StructType { +impl Ord for DataType { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.id.cmp(&other.id) } } -impl StructType { - pub fn new( - id: StructId, - name: Ident, - - location: Location, - fields: Vec, - generics: Generics, - ) -> StructType { - StructType { id, fields, name, location, generics } +impl DataType { + pub fn new(id: TypeId, name: Ident, location: Location, generics: Generics) -> DataType { + DataType { id, name, location, generics, body: TypeBody::None } } /// To account for cyclic references between structs, a struct's @@ -427,14 +445,56 @@ impl StructType { /// created. Therefore, this method is used to set the fields once they /// become known. pub fn set_fields(&mut self, fields: Vec) { - self.fields = fields; + self.body = TypeBody::Struct(fields); + } + + pub(crate) fn init_variants(&mut self) { + match &mut self.body { + TypeBody::None => { + self.body = TypeBody::Enum(vec![]); + } + _ => panic!("Called init_variants but body was None"), + } } - pub fn num_fields(&self) -> usize { - self.fields.len() + pub(crate) fn push_variant(&mut self, variant: EnumVariant) { + match &mut self.body { + TypeBody::Enum(variants) => variants.push(variant), + _ => panic!("Called push_variant on {self} but body wasn't an enum"), + } + } + + pub fn is_struct(&self) -> bool { + matches!(&self.body, TypeBody::Struct(_)) + } + + /// Retrieve the fields of this type with no modifications. + /// Returns None if this is not a struct type. + pub fn fields_raw(&self) -> Option<&[StructField]> { + match &self.body { + TypeBody::Struct(fields) => Some(fields), + _ => None, + } + } + + /// Retrieve the variants of this type with no modifications. + /// Panics if this is not an enum type. + fn variants_raw(&self) -> Option<&[EnumVariant]> { + match &self.body { + TypeBody::Enum(variants) => Some(variants), + _ => None, + } + } + + /// Return the generics on this type as a vector of types + pub fn generic_types(&self) -> Vec { + vecmap(&self.generics, |generic| { + Type::NamedGeneric(generic.type_var.clone(), generic.name.clone()) + }) } /// Returns the field matching the given field name, as well as its visibility and field index. + /// Always returns None if this is not a struct type. pub fn get_field( &self, field_name: &str, @@ -442,45 +502,52 @@ impl StructType { ) -> Option<(Type, ItemVisibility, usize)> { assert_eq!(self.generics.len(), generic_args.len()); - self.fields.iter().enumerate().find(|(_, field)| field.name.0.contents == field_name).map( - |(i, field)| { - let substitutions = self - .generics - .iter() - .zip(generic_args) - .map(|(old, new)| { - ( - old.type_var.id(), - (old.type_var.clone(), old.type_var.kind(), new.clone()), - ) - }) - .collect(); + let mut fields = self.fields_raw()?.iter().enumerate(); + fields.find(|(_, field)| field.name.0.contents == field_name).map(|(i, field)| { + let generics = self.generics.iter().zip(generic_args); + let substitutions = generics + .map(|(old, new)| { + (old.type_var.id(), (old.type_var.clone(), old.type_var.kind(), new.clone())) + }) + .collect(); - (field.typ.substitute(&substitutions), field.visibility, i) - }, - ) + (field.typ.substitute(&substitutions), field.visibility, i) + }) } /// Returns all the fields of this type, after being applied to the given generic arguments. + /// Returns None if this is not a struct type. pub fn get_fields_with_visibility( &self, generic_args: &[Type], - ) -> Vec<(String, ItemVisibility, Type)> { + ) -> Option> { let substitutions = self.get_fields_substitutions(generic_args); - vecmap(&self.fields, |field| { + Some(vecmap(self.fields_raw()?, |field| { let name = field.name.0.contents.clone(); (name, field.visibility, field.typ.substitute(&substitutions)) - }) + })) } - pub fn get_fields(&self, generic_args: &[Type]) -> Vec<(String, Type)> { + /// Retrieve the fields of this type. Returns None if this is not a field type + pub fn get_fields(&self, generic_args: &[Type]) -> Option> { let substitutions = self.get_fields_substitutions(generic_args); - vecmap(&self.fields, |field| { + Some(vecmap(self.fields_raw()?, |field| { let name = field.name.0.contents.clone(); (name, field.typ.substitute(&substitutions)) - }) + })) + } + + /// Retrieve the variants of this type. Returns None if this is not an enum type + pub fn get_variants(&self, generic_args: &[Type]) -> Option)>> { + let substitutions = self.get_fields_substitutions(generic_args); + + Some(vecmap(self.variants_raw()?, |variant| { + let name = variant.name.to_string(); + let args = vecmap(&variant.params, |param| param.substitute(&substitutions)); + (name, args) + })) } fn get_fields_substitutions( @@ -504,21 +571,36 @@ impl StructType { /// /// This method is almost never what is wanted for type checking or monomorphization, /// prefer to use `get_fields` whenever possible. - pub fn get_fields_as_written(&self) -> Vec { - vecmap(&self.fields, |field| StructField { - visibility: field.visibility, - name: field.name.clone(), - typ: field.typ.clone(), - }) + /// + /// Returns None if this is not a struct type. + pub fn get_fields_as_written(&self) -> Option> { + Some(self.fields_raw()?.to_vec()) + } + + /// Returns the name and raw parameters of each variant of this type. + /// This will not substitute any generic arguments so a generic variant like `X` + /// in `enum Foo { X(T) }` will return a `("X", Vec)` pair. + /// + /// Returns None if this is not an enum type. + pub fn get_variants_as_written(&self) -> Option> { + Some(self.variants_raw()?.to_vec()) } - /// Returns the field at the given index. Panics if no field exists at the given index. + /// Returns the field at the given index. Panics if no field exists at the given index or this + /// is not a struct type. pub fn field_at(&self, index: usize) -> &StructField { - &self.fields[index] + &self.fields_raw().unwrap()[index] } - pub fn field_names(&self) -> BTreeSet { - self.fields.iter().map(|field| field.name.clone()).collect() + /// Returns the enum variant at the given index. Panics if no field exists at the given index + /// or this is not an enum type. + pub fn variant_at(&self, index: usize) -> &EnumVariant { + &self.variants_raw().unwrap()[index] + } + + /// Returns each of this type's field names. Returns None if this is not a struct type. + pub fn field_names(&self) -> Option> { + Some(self.fields_raw()?.iter().map(|field| field.name.clone()).collect()) } /// Instantiate this struct type, returning a Vec of the new generic args (in @@ -526,9 +608,38 @@ impl StructType { pub fn instantiate(&self, interner: &mut NodeInterner) -> Vec { vecmap(&self.generics, |generic| interner.next_type_variable_with_kind(generic.kind())) } + + /// Returns the function type of the variant at the given index of this enum. + /// Requires the `Shared` handle of self to create the given function type. + /// Panics if this is not an enum. + /// + /// The function type uses the variant "as written" ie. no generic substitutions. + /// Although the returned function is technically generic, Type::Function is returned + /// instead of Type::Forall. + pub fn variant_function_type(&self, variant_index: usize, this: Shared) -> Type { + let variant = self.variant_at(variant_index); + let args = variant.params.clone(); + assert_eq!(this.borrow().id, self.id); + let generics = self.generic_types(); + let ret = Box::new(Type::DataType(this, generics)); + Type::Function(args, ret, Box::new(Type::Unit), false) + } + + /// Returns the function type of the variant at the given index of this enum. + /// Requires the `Shared` handle of self to create the given function type. + /// Panics if this is not an enum. + pub fn variant_function_type_with_forall( + &self, + variant_index: usize, + this: Shared, + ) -> Type { + let function_type = self.variant_function_type(variant_index, this); + let typevars = vecmap(&self.generics, |generic| generic.type_var.clone()); + Type::Forall(typevars, Box::new(function_type)) + } } -impl std::fmt::Display for StructType { +impl std::fmt::Display for DataType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } @@ -846,7 +957,7 @@ impl std::fmt::Display for Type { } } } - Type::Struct(s, args) => { + Type::DataType(s, args) => { let args = vecmap(args, |arg| arg.to_string()); if args.is_empty() { write!(f, "{}", s.borrow()) @@ -905,7 +1016,7 @@ impl std::fmt::Display for Type { write!(f, "&mut {element}") } Type::Quoted(quoted) => write!(f, "{}", quoted), - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { let this = self.canonicalize_checked(); // Prevent infinite recursion @@ -955,6 +1066,7 @@ impl std::fmt::Display for QuotedType { QuotedType::Type => write!(f, "Type"), QuotedType::TypedExpr => write!(f, "TypedExpr"), QuotedType::StructDefinition => write!(f, "StructDefinition"), + QuotedType::EnumDefinition => write!(f, "EnumDefinition"), QuotedType::TraitDefinition => write!(f, "TraitDefinition"), QuotedType::TraitConstraint => write!(f, "TraitConstraint"), QuotedType::TraitImpl => write!(f, "TraitImpl"), @@ -1079,7 +1191,7 @@ impl Type { alias_type.borrow().get_type(&generics).is_primitive() } Type::MutableReference(typ) => typ.is_primitive(), - Type::Struct(..) + Type::DataType(..) | Type::TypeVariable(..) | Type::TraitAsType(..) | Type::NamedGeneric(..) @@ -1140,13 +1252,16 @@ impl Type { } Type::String(length) => length.is_valid_for_program_input(), Type::Tuple(elements) => elements.iter().all(|elem| elem.is_valid_for_program_input()), - Type::Struct(definition, generics) => definition - .borrow() - .get_fields(generics) - .into_iter() - .all(|(_, field)| field.is_valid_for_program_input()), + Type::DataType(definition, generics) => { + if let Some(fields) = definition.borrow().get_fields(generics) { + fields.into_iter().all(|(_, field)| field.is_valid_for_program_input()) + } else { + // Arbitrarily disallow enums from program input, though we may support them later + false + } + } - Type::InfixExpr(lhs, _, rhs) => { + Type::InfixExpr(lhs, _, rhs, _) => { lhs.is_valid_for_program_input() && rhs.is_valid_for_program_input() } } @@ -1195,11 +1310,14 @@ impl Type { } Type::String(length) => length.is_valid_non_inlined_function_input(), Type::Tuple(elements) => elements.iter().all(|elem| elem.is_valid_non_inlined_function_input()), - Type::Struct(definition, generics) => definition - .borrow() - .get_fields(generics) - .into_iter() - .all(|(_, field)| field.is_valid_non_inlined_function_input()), + Type::DataType(definition, generics) => { + if let Some(fields) = definition.borrow().get_fields(generics) { + fields.into_iter() + .all(|(_, field)| field.is_valid_non_inlined_function_input()) + } else { + false + } + } } } @@ -1247,11 +1365,13 @@ impl Type { Type::Tuple(elements) => { elements.iter().all(|elem| elem.is_valid_for_unconstrained_boundary()) } - Type::Struct(definition, generics) => definition - .borrow() - .get_fields(generics) - .into_iter() - .all(|(_, field)| field.is_valid_for_unconstrained_boundary()), + Type::DataType(definition, generics) => { + if let Some(fields) = definition.borrow().get_fields(generics) { + fields.into_iter().all(|(_, field)| field.is_valid_for_unconstrained_boundary()) + } else { + false + } + } } } @@ -1307,7 +1427,7 @@ impl Type { TypeBinding::Bound(ref typ) => typ.kind(), TypeBinding::Unbound(_, ref type_var_kind) => type_var_kind.clone(), }, - Type::InfixExpr(lhs, _op, rhs) => lhs.infix_kind(rhs), + Type::InfixExpr(lhs, _op, rhs, _) => lhs.infix_kind(rhs), Type::Alias(def, generics) => def.borrow().get_type(generics).kind(), // This is a concrete FieldElement, not an IntegerOrField Type::FieldElement @@ -1319,7 +1439,7 @@ impl Type { | Type::FmtString(..) | Type::Unit | Type::Tuple(..) - | Type::Struct(..) + | Type::DataType(..) | Type::TraitAsType(..) | Type::Function(..) | Type::MutableReference(..) @@ -1340,6 +1460,48 @@ impl Type { } } + /// Creates an `InfixExpr`. + pub fn infix_expr(lhs: Box, op: BinaryTypeOperator, rhs: Box) -> Type { + Self::new_infix_expr(lhs, op, rhs, false) + } + + /// Creates an `InfixExpr` that results from the compiler trying to unify something like + /// `4 = a * b` into `a = 4 / b` (where `4 / b` is the "inverted" expression). + pub fn inverted_infix_expr(lhs: Box, op: BinaryTypeOperator, rhs: Box) -> Type { + Self::new_infix_expr(lhs, op, rhs, true) + } + + pub fn new_infix_expr( + lhs: Box, + op: BinaryTypeOperator, + rhs: Box, + inversion: bool, + ) -> Type { + // If an InfixExpr like this is tried to be created: + // + // a * (b / a) + // + // where `b / a` resulted from the compiler creating an inverted InfixExpr from a previous + // unification (that is, the compiler had `b = a / y` and ended up doing `y = b / a` where + // `y` is `rhs` here) then we can simplify this to just `b` because there wasn't an actual + // division in the original expression, so multiplying it back is just going back to the + // original `y` + if let Type::InfixExpr(rhs_lhs, rhs_op, rhs_rhs, true) = &*rhs { + if op.approx_inverse() == Some(*rhs_op) && lhs == *rhs_rhs { + return *rhs_lhs.clone(); + } + } + + // Same thing but on the other side. + if let Type::InfixExpr(lhs_lhs, lhs_op, lhs_rhs, true) = &*lhs { + if op.approx_inverse() == Some(*lhs_op) && rhs == *lhs_rhs { + return *lhs_lhs.clone(); + } + } + + Self::InfixExpr(lhs, op, rhs, inversion) + } + /// Returns the number of field elements required to represent the type once encoded. pub fn field_count(&self, location: &Location) -> u32 { match self { @@ -1351,10 +1513,21 @@ impl Type { let typ = typ.as_ref(); length * typ.field_count(location) } - Type::Struct(def, args) => { + Type::DataType(def, args) => { let struct_type = def.borrow(); - let fields = struct_type.get_fields(args); - fields.iter().fold(0, |acc, (_, field_type)| acc + field_type.field_count(location)) + if let Some(fields) = struct_type.get_fields(args) { + fields.iter().map(|(_, field_type)| field_type.field_count(location)).sum() + } else if let Some(variants) = struct_type.get_variants(args) { + let mut size = 1; // start with the tag size + for (_, args) in variants { + for arg in args { + size += arg.field_count(location); + } + } + size + } else { + 0 + } } Type::CheckedCast { to, .. } => to.field_count(location), Type::Alias(def, generics) => def.borrow().get_type(generics).field_count(location), @@ -1392,10 +1565,14 @@ impl Type { pub(crate) fn contains_slice(&self) -> bool { match self { Type::Slice(_) => true, - Type::Struct(struct_typ, generics) => { - let fields = struct_typ.borrow().get_fields(generics); - for field in fields.iter() { - if field.1.contains_slice() { + Type::DataType(typ, generics) => { + let typ = typ.borrow(); + if let Some(fields) = typ.get_fields(generics) { + if fields.iter().any(|(_, field)| field.contains_slice()) { + return true; + } + } else if let Some(variants) = typ.get_variants(generics) { + if variants.iter().flat_map(|(_, args)| args).any(|typ| typ.contains_slice()) { return true; } } @@ -1640,7 +1817,7 @@ impl Type { // No recursive try_unify call for struct fields. Don't want // to mutate shared type variables within struct definitions. // This isn't possible currently but will be once noir gets generic types - (Struct(id_a, args_a), Struct(id_b, args_b)) => { + (DataType(id_a, args_a), DataType(id_b, args_b)) => { if id_a == id_b && args_a.len() == args_b.len() { for (a, b) in args_a.iter().zip(args_b) { a.try_unify(b, bindings)?; @@ -1697,7 +1874,7 @@ impl Type { elem_a.try_unify(elem_b, bindings) } - (InfixExpr(lhs_a, op_a, rhs_a), InfixExpr(lhs_b, op_b, rhs_b)) => { + (InfixExpr(lhs_a, op_a, rhs_a, _), InfixExpr(lhs_b, op_b, rhs_b, _)) => { if op_a == op_b { // We need to preserve the original bindings since if syntactic equality // fails we fall back to other equality strategies. @@ -1724,14 +1901,15 @@ impl Type { } else { Err(UnificationError) } - } else if let InfixExpr(lhs, op, rhs) = other { + } else if let InfixExpr(lhs, op, rhs, _) = other { if let Some(inverse) = op.approx_inverse() { // Handle cases like `4 = a + b` by trying to solve to `a = 4 - b` - let new_type = InfixExpr( + let new_type = Type::inverted_infix_expr( Box::new(Constant(*value, kind.clone())), inverse, rhs.clone(), ); + new_type.try_unify(lhs, bindings)?; Ok(()) } else { @@ -1937,7 +2115,7 @@ impl Type { }) } } - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { let infix_kind = lhs.infix_kind(&rhs); if kind.unifies(&infix_kind) { let lhs_value = lhs.evaluate_to_field_element_helper( @@ -1989,28 +2167,6 @@ impl Type { } } - /// Iterate over the fields of this type. - /// Panics if the type is not a struct or tuple. - pub fn iter_fields(&self) -> impl Iterator { - let fields: Vec<_> = match self { - // Unfortunately the .borrow() here forces us to collect into a Vec - // only to have to call .into_iter again afterward. Trying to elide - // collecting to a Vec leads to us dropping the temporary Ref before - // the iterator is returned - Type::Struct(def, args) => vecmap(&def.borrow().fields, |field| { - let name = &field.name.0.contents; - let typ = def.borrow().get_field(name, args).unwrap().0; - (name.clone(), typ) - }), - Type::Tuple(fields) => { - let fields = fields.iter().enumerate(); - vecmap(fields, |(i, field)| (i.to_string(), field.clone())) - } - other => panic!("Tried to iterate over the fields of '{other}', which has none"), - }; - fields.into_iter() - } - /// Retrieves the type of the given field name /// Panics if the type is not a struct or tuple. pub fn get_field_type_and_visibility( @@ -2018,7 +2174,7 @@ impl Type { field_name: &str, ) -> Option<(Type, ItemVisibility)> { match self.follow_bindings() { - Type::Struct(def, args) => def + Type::DataType(def, args) => def .borrow() .get_field(field_name, &args) .map(|(typ, visibility, _)| (typ, visibility)), @@ -2218,11 +2374,11 @@ impl Type { // Do not substitute_helper fields, it can lead to infinite recursion // and we should not match fields when type checking anyway. - Type::Struct(fields, args) => { + Type::DataType(fields, args) => { let args = vecmap(args, |arg| { arg.substitute_helper(type_bindings, substitute_bound_typevars) }); - Type::Struct(fields.clone(), args) + Type::DataType(fields.clone(), args) } Type::Alias(alias, args) => { let args = vecmap(args, |arg| { @@ -2267,10 +2423,10 @@ impl Type { }); Type::TraitAsType(*s, name.clone(), TraitGenerics { ordered, named }) } - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, inversion) => { let lhs = lhs.substitute_helper(type_bindings, substitute_bound_typevars); let rhs = rhs.substitute_helper(type_bindings, substitute_bound_typevars); - Type::InfixExpr(Box::new(lhs), *op, Box::new(rhs)) + Type::InfixExpr(Box::new(lhs), *op, Box::new(rhs), *inversion) } Type::FieldElement @@ -2294,7 +2450,7 @@ impl Type { let field_occurs = fields.occurs(target_id); len_occurs || field_occurs } - Type::Struct(_, generic_args) | Type::Alias(_, generic_args) => { + Type::DataType(_, generic_args) | Type::Alias(_, generic_args) => { generic_args.iter().any(|arg| arg.occurs(target_id)) } Type::TraitAsType(_, _, args) => { @@ -2320,7 +2476,7 @@ impl Type { || env.occurs(target_id) } Type::MutableReference(element) => element.occurs(target_id), - Type::InfixExpr(lhs, _op, rhs) => lhs.occurs(target_id) || rhs.occurs(target_id), + Type::InfixExpr(lhs, _op, rhs, _) => lhs.occurs(target_id) || rhs.occurs(target_id), Type::FieldElement | Type::Integer(_, _) @@ -2351,9 +2507,9 @@ impl Type { let args = Box::new(args.follow_bindings()); FmtString(size, args) } - Struct(def, args) => { + DataType(def, args) => { let args = vecmap(args, |arg| arg.follow_bindings()); - Struct(def.clone(), args) + DataType(def.clone(), args) } Alias(def, args) => { // We don't need to vecmap(args, follow_bindings) since we're recursively @@ -2389,10 +2545,10 @@ impl Type { }); TraitAsType(*s, name.clone(), TraitGenerics { ordered, named }) } - InfixExpr(lhs, op, rhs) => { + InfixExpr(lhs, op, rhs, inversion) => { let lhs = lhs.follow_bindings(); let rhs = rhs.follow_bindings(); - InfixExpr(Box::new(lhs), *op, Box::new(rhs)) + InfixExpr(Box::new(lhs), *op, Box::new(rhs), *inversion) } // Expect that this function should only be called on instantiated types @@ -2451,7 +2607,7 @@ impl Type { field.replace_named_generics_with_type_variables(); } } - Type::Struct(_, generics) => { + Type::DataType(_, generics) => { for generic in generics { generic.replace_named_generics_with_type_variables(); } @@ -2502,7 +2658,7 @@ impl Type { } Type::MutableReference(elem) => elem.replace_named_generics_with_type_variables(), Type::Forall(_, typ) => typ.replace_named_generics_with_type_variables(), - Type::InfixExpr(lhs, _op, rhs) => { + Type::InfixExpr(lhs, _op, rhs, _) => { lhs.replace_named_generics_with_type_variables(); rhs.replace_named_generics_with_type_variables(); } @@ -2544,7 +2700,7 @@ impl Type { TypeBinding::Unbound(_, kind) => kind.integral_maximum_size(), }, Type::MutableReference(typ) => typ.integral_maximum_size(), - Type::InfixExpr(lhs, _op, rhs) => lhs.infix_kind(rhs).integral_maximum_size(), + Type::InfixExpr(lhs, _op, rhs, _) => lhs.infix_kind(rhs).integral_maximum_size(), Type::Constant(_, kind) => kind.integral_maximum_size(), Type::Array(..) @@ -2553,7 +2709,7 @@ impl Type { | Type::FmtString(..) | Type::Unit | Type::Tuple(..) - | Type::Struct(..) + | Type::DataType(..) | Type::TraitAsType(..) | Type::Function(..) | Type::Forall(..) @@ -2711,11 +2867,20 @@ impl From<&Type> for PrintableType { Type::Error => unreachable!(), Type::Unit => PrintableType::Unit, Type::Constant(_, _) => unreachable!(), - Type::Struct(def, ref args) => { - let struct_type = def.borrow(); - let fields = struct_type.get_fields(args); - let fields = vecmap(fields, |(name, typ)| (name, typ.into())); - PrintableType::Struct { fields, name: struct_type.name.to_string() } + Type::DataType(def, ref args) => { + let data_type = def.borrow(); + let name = data_type.name.to_string(); + + if let Some(fields) = data_type.get_fields(args) { + let fields = vecmap(fields, |(name, typ)| (name, typ.into())); + PrintableType::Struct { fields, name } + } else if let Some(variants) = data_type.get_variants(args) { + let variants = + vecmap(variants, |(name, args)| (name, vecmap(args, Into::into))); + PrintableType::Enum { name, variants } + } else { + unreachable!() + } } Type::Alias(alias, args) => alias.borrow().get_type(args).into(), Type::TraitAsType(..) => unreachable!(), @@ -2767,7 +2932,7 @@ impl std::fmt::Debug for Type { write!(f, "{}", binding.borrow()) } } - Type::Struct(s, args) => { + Type::DataType(s, args) => { let args = vecmap(args, |arg| format!("{:?}", arg)); if args.is_empty() { write!(f, "{}", s.borrow()) @@ -2827,7 +2992,7 @@ impl std::fmt::Debug for Type { write!(f, "&mut {element:?}") } Type::Quoted(quoted) => write!(f, "{}", quoted), - Type::InfixExpr(lhs, op, rhs) => write!(f, "({lhs:?} {op} {rhs:?})"), + Type::InfixExpr(lhs, op, rhs, _) => write!(f, "({lhs:?} {op} {rhs:?})"), } } } @@ -2849,7 +3014,7 @@ impl std::fmt::Debug for TypeVariable { } } -impl std::fmt::Debug for StructType { +impl std::fmt::Debug for DataType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } @@ -2886,7 +3051,7 @@ impl std::hash::Hash for Type { env.hash(state); } Type::Tuple(elems) => elems.hash(state), - Type::Struct(def, args) => { + Type::DataType(def, args) => { def.hash(state); args.hash(state); } @@ -2913,7 +3078,7 @@ impl std::hash::Hash for Type { Type::CheckedCast { to, .. } => to.hash(state), Type::Constant(value, _) => value.hash(state), Type::Quoted(typ) => typ.hash(state), - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, _) => { lhs.hash(state); op.hash(state); rhs.hash(state); @@ -2957,7 +3122,7 @@ impl PartialEq for Type { lhs_len == rhs_len && lhs_env == rhs_env } (Tuple(lhs_types), Tuple(rhs_types)) => lhs_types == rhs_types, - (Struct(lhs_struct, lhs_generics), Struct(rhs_struct, rhs_generics)) => { + (DataType(lhs_struct, lhs_generics), DataType(rhs_struct, rhs_generics)) => { lhs_struct == rhs_struct && lhs_generics == rhs_generics } (Alias(lhs_alias, lhs_generics), Alias(rhs_alias, rhs_generics)) => { @@ -2982,7 +3147,7 @@ impl PartialEq for Type { lhs == rhs && lhs_kind == rhs_kind } (Quoted(lhs), Quoted(rhs)) => lhs == rhs, - (InfixExpr(l_lhs, l_op, l_rhs), InfixExpr(r_lhs, r_op, r_rhs)) => { + (InfixExpr(l_lhs, l_op, l_rhs, _), InfixExpr(r_lhs, r_op, r_rhs, _)) => { l_lhs == r_lhs && l_op == r_op && l_rhs == r_rhs } // Special case: we consider unbound named generics and type variables to be equal to each diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs index 8cdf6f5502c..5750365c62d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -58,7 +58,7 @@ impl Type { run_simplifications: bool, ) -> Type { match self.follow_bindings() { - Type::InfixExpr(lhs, op, rhs) => { + Type::InfixExpr(lhs, op, rhs, inversion) => { let kind = lhs.infix_kind(&rhs); let dummy_span = Span::default(); // evaluate_to_field_element also calls canonicalize so if we just called @@ -76,7 +76,7 @@ impl Type { let rhs = rhs.canonicalize_helper(found_checked_cast, run_simplifications); if !run_simplifications { - return Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)); + return Type::InfixExpr(Box::new(lhs), op, Box::new(rhs), inversion); } if let Some(result) = Self::try_simplify_non_constants_in_lhs(&lhs, op, &rhs) { @@ -97,7 +97,7 @@ impl Type { return Self::sort_commutative(&lhs, op, &rhs); } - Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)) + Type::InfixExpr(Box::new(lhs), op, Box::new(rhs), inversion) } Type::CheckedCast { from, to } => { let inner_found_checked_cast = true; @@ -131,7 +131,7 @@ impl Type { // Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator. while let Some(item) = queue.pop() { match item.canonicalize_unchecked() { - Type::InfixExpr(lhs_inner, new_op, rhs_inner) if new_op == op => { + Type::InfixExpr(lhs_inner, new_op, rhs_inner, _) if new_op == op => { queue.push(*lhs_inner); queue.push(*rhs_inner); } @@ -157,18 +157,18 @@ impl Type { // - 1 since `typ` already is set to the first instance for _ in 0..first_type_count - 1 { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(first.0.clone())); + typ = Type::infix_expr(Box::new(typ), op, Box::new(first.0.clone())); } for (rhs, rhs_count) in sorted { for _ in 0..rhs_count { - typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone())); + typ = Type::infix_expr(Box::new(typ), op, Box::new(rhs.clone())); } } if constant != zero_value { let constant = Type::Constant(constant, lhs.infix_kind(rhs)); - typ = Type::InfixExpr(Box::new(typ), op, Box::new(constant)); + typ = Type::infix_expr(Box::new(typ), op, Box::new(constant)); } typ @@ -192,11 +192,11 @@ impl Type { match lhs.follow_bindings() { Type::CheckedCast { from, to } => { // Apply operation directly to `from` while attempting simplification to `to`. - let from = Type::InfixExpr(from, op, Box::new(rhs.clone())); + let from = Type::infix_expr(from, op, Box::new(rhs.clone())); let to = Self::try_simplify_non_constants_in_lhs(&to, op, rhs)?; Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) }) } - Type::InfixExpr(l_lhs, l_op, l_rhs) => { + Type::InfixExpr(l_lhs, l_op, l_rhs, _) => { // Note that this is exact, syntactic equality, not unification. // `rhs` is expected to already be in canonical form. if l_op.approx_inverse() != Some(op) @@ -229,11 +229,11 @@ impl Type { match rhs.follow_bindings() { Type::CheckedCast { from, to } => { // Apply operation directly to `from` while attempting simplification to `to`. - let from = Type::InfixExpr(Box::new(lhs.clone()), op, from); + let from = Type::infix_expr(Box::new(lhs.clone()), op, from); let to = Self::try_simplify_non_constants_in_rhs(lhs, op, &to)?; Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) }) } - Type::InfixExpr(r_lhs, r_op, r_rhs) => { + Type::InfixExpr(r_lhs, r_op, r_rhs, _) => { // `N / (M * N)` should be simplified to `1 / M`, but we only handle // simplifying to `M` in this function. if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication @@ -268,7 +268,7 @@ impl Type { let dummy_span = Span::default(); let rhs = rhs.evaluate_to_field_element(&kind, dummy_span).ok()?; - let Type::InfixExpr(l_type, l_op, l_rhs) = lhs.follow_bindings() else { + let Type::InfixExpr(l_type, l_op, l_rhs, _) = lhs.follow_bindings() else { return None; }; @@ -302,7 +302,7 @@ impl Type { let result = op.function(l_const, r_const, &lhs.infix_kind(rhs), dummy_span).ok()?; let constant = Type::Constant(result, lhs.infix_kind(rhs)); - Some(Type::InfixExpr(l_type, l_op, Box::new(constant))) + Some(Type::infix_expr(l_type, l_op, Box::new(constant))) } (Multiplication, Division) => { // We need to ensure the result divides evenly to preserve integer division semantics @@ -317,7 +317,7 @@ impl Type { let result = op.function(l_const, r_const, &lhs.infix_kind(rhs), dummy_span).ok()?; let constant = Box::new(Type::Constant(result, lhs.infix_kind(rhs))); - Some(Type::InfixExpr(l_type, l_op, constant)) + Some(Type::infix_expr(l_type, l_op, constant)) } } _ => None, @@ -331,13 +331,14 @@ impl Type { other: &Type, bindings: &mut TypeBindings, ) -> Result<(), UnificationError> { - if let Type::InfixExpr(lhs_a, op_a, rhs_a) = self { + if let Type::InfixExpr(lhs_a, op_a, rhs_a, _) = self { if let Some(inverse) = op_a.approx_inverse() { let kind = lhs_a.infix_kind(rhs_a); let dummy_span = Span::default(); if let Ok(rhs_a_value) = rhs_a.evaluate_to_field_element(&kind, dummy_span) { let rhs_a = Box::new(Type::Constant(rhs_a_value, kind)); - let new_other = Type::InfixExpr(Box::new(other.clone()), inverse, rhs_a); + let new_other = + Type::inverted_infix_expr(Box::new(other.clone()), inverse, rhs_a); let mut tmp_bindings = bindings.clone(); if lhs_a.try_unify(&new_other, &mut tmp_bindings).is_ok() { @@ -348,13 +349,14 @@ impl Type { } } - if let Type::InfixExpr(lhs_b, op_b, rhs_b) = other { + if let Type::InfixExpr(lhs_b, op_b, rhs_b, inversion) = other { if let Some(inverse) = op_b.approx_inverse() { let kind = lhs_b.infix_kind(rhs_b); let dummy_span = Span::default(); if let Ok(rhs_b_value) = rhs_b.evaluate_to_field_element(&kind, dummy_span) { let rhs_b = Box::new(Type::Constant(rhs_b_value, kind)); - let new_self = Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b); + let new_self = + Type::InfixExpr(Box::new(self.clone()), inverse, rhs_b, !inversion); let mut tmp_bindings = bindings.clone(); if new_self.try_unify(lhs_b, &mut tmp_bindings).is_ok() { @@ -384,7 +386,7 @@ mod tests { TypeVariable::unbound(TypeVariableId(0), Kind::u32()), std::rc::Rc::new("N".to_owned()), ); - let n_minus_one = Type::InfixExpr( + let n_minus_one = Type::infix_expr( Box::new(n.clone()), BinaryTypeOperator::Subtraction, Box::new(Type::Constant(FieldElement::one(), Kind::u32())), @@ -392,7 +394,7 @@ mod tests { let checked_cast_n_minus_one = Type::CheckedCast { from: Box::new(n_minus_one.clone()), to: Box::new(n_minus_one) }; - let n_minus_one_plus_one = Type::InfixExpr( + let n_minus_one_plus_one = Type::infix_expr( Box::new(checked_cast_n_minus_one.clone()), BinaryTypeOperator::Addition, Box::new(Type::Constant(FieldElement::one(), Kind::u32())), @@ -405,7 +407,7 @@ mod tests { // We also want to check that if the `CheckedCast` is on the RHS then we'll still be able to canonicalize // the expression `1 + (N - 1)` to `N`. - let one_plus_n_minus_one = Type::InfixExpr( + let one_plus_n_minus_one = Type::infix_expr( Box::new(Type::Constant(FieldElement::one(), Kind::u32())), BinaryTypeOperator::Addition, Box::new(checked_cast_n_minus_one), @@ -423,13 +425,13 @@ mod tests { let x_type = Type::TypeVariable(x_var.clone()); let one = Type::Constant(FieldElement::one(), field_element_kind.clone()); - let lhs = Type::InfixExpr( + let lhs = Type::infix_expr( Box::new(x_type.clone()), BinaryTypeOperator::Addition, Box::new(one.clone()), ); let rhs = - Type::InfixExpr(Box::new(one), BinaryTypeOperator::Addition, Box::new(x_type.clone())); + Type::infix_expr(Box::new(one), BinaryTypeOperator::Addition, Box::new(x_type.clone())); // canonicalize let lhs = lhs.canonicalize(); @@ -546,7 +548,7 @@ mod proptests { 10, // We put up to 10 items per collection |inner| { (inner.clone(), any::(), inner) - .prop_map(|(lhs, op, rhs)| Type::InfixExpr(Box::new(lhs), op, Box::new(rhs))) + .prop_map(|(lhs, op, rhs)| Type::infix_expr(Box::new(lhs), op, Box::new(rhs))) }, ) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs b/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs index 8c136f5e45d..7d11b97ca16 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/lexer/token.rs @@ -1020,6 +1020,7 @@ pub enum Keyword { Dep, Else, Enum, + EnumDefinition, Expr, Field, Fn, @@ -1080,6 +1081,7 @@ impl fmt::Display for Keyword { Keyword::Dep => write!(f, "dep"), Keyword::Else => write!(f, "else"), Keyword::Enum => write!(f, "enum"), + Keyword::EnumDefinition => write!(f, "EnumDefinition"), Keyword::Expr => write!(f, "Expr"), Keyword::Field => write!(f, "Field"), Keyword::Fn => write!(f, "fn"), @@ -1143,6 +1145,7 @@ impl Keyword { "dep" => Keyword::Dep, "else" => Keyword::Else, "enum" => Keyword::Enum, + "EnumDefinition" => Keyword::EnumDefinition, "Expr" => Keyword::Expr, "Field" => Keyword::Field, "fn" => Keyword::Fn, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/locations.rs b/noir/noir-repo/compiler/noirc_frontend/src/locations.rs index ecae5b19a95..08100a3c351 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/locations.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/locations.rs @@ -7,7 +7,7 @@ use crate::{ ast::{FunctionDefinition, ItemVisibility}, hir::def_map::{ModuleDefId, ModuleId}, node_interner::{ - DefinitionId, FuncId, GlobalId, NodeInterner, ReferenceId, StructId, TraitId, TypeAliasId, + DefinitionId, FuncId, GlobalId, NodeInterner, ReferenceId, TraitId, TypeAliasId, TypeId, }, }; use petgraph::prelude::NodeIndex as PetGraphIndex; @@ -60,18 +60,22 @@ impl NodeInterner { match reference { ReferenceId::Module(id) => self.module_attributes(&id).location, ReferenceId::Function(id) => self.function_modifiers(&id).name_location, - ReferenceId::Struct(id) => { - let struct_type = self.get_struct(id); - let struct_type = struct_type.borrow(); - Location::new(struct_type.name.span(), struct_type.location.file) + ReferenceId::Type(id) => { + let typ = self.get_type(id); + let typ = typ.borrow(); + Location::new(typ.name.span(), typ.location.file) } ReferenceId::StructMember(id, field_index) => { - let struct_type = self.get_struct(id); + let struct_type = self.get_type(id); let struct_type = struct_type.borrow(); - Location::new( - struct_type.field_at(field_index).name.span(), - struct_type.location.file, - ) + let file = struct_type.location.file; + Location::new(struct_type.field_at(field_index).name.span(), file) + } + ReferenceId::EnumVariant(id, variant_index) => { + let typ = self.get_type(id); + let typ = typ.borrow(); + let file = typ.location.file; + Location::new(typ.variant_at(variant_index).name.span(), file) } ReferenceId::Trait(id) => { let trait_type = self.get_trait(id); @@ -105,8 +109,8 @@ impl NodeInterner { ModuleDefId::FunctionId(func_id) => { self.add_function_reference(func_id, location); } - ModuleDefId::TypeId(struct_id) => { - self.add_struct_reference(struct_id, location, is_self_type); + ModuleDefId::TypeId(type_id) => { + self.add_type_reference(type_id, location, is_self_type); } ModuleDefId::TraitId(trait_id) => { self.add_trait_reference(trait_id, location, is_self_type); @@ -124,18 +128,18 @@ impl NodeInterner { self.add_reference(ReferenceId::Module(id), location, false); } - pub(crate) fn add_struct_reference( + pub(crate) fn add_type_reference( &mut self, - id: StructId, + id: TypeId, location: Location, is_self_type: bool, ) { - self.add_reference(ReferenceId::Struct(id), location, is_self_type); + self.add_reference(ReferenceId::Type(id), location, is_self_type); } pub(crate) fn add_struct_member_reference( &mut self, - id: StructId, + id: TypeId, member_index: usize, location: Location, ) { @@ -190,6 +194,7 @@ impl NodeInterner { pub(crate) fn add_definition_location( &mut self, referenced: ReferenceId, + referenced_location: Location, module_id: Option, ) { if !self.lsp_mode { @@ -197,7 +202,6 @@ impl NodeInterner { } let referenced_index = self.get_or_insert_reference(referenced); - let referenced_location = self.reference_location(referenced); self.location_indices.add_location(referenced_location, referenced_index); if let Some(module_id) = module_id { self.reference_modules.insert(referenced, module_id); @@ -315,21 +319,23 @@ impl NodeInterner { &mut self, id: GlobalId, name: String, + location: Location, visibility: ItemVisibility, parent_module_id: ModuleId, ) { - self.add_definition_location(ReferenceId::Global(id), Some(parent_module_id)); + self.add_definition_location(ReferenceId::Global(id), location, Some(parent_module_id)); self.register_name_for_auto_import(name, ModuleDefId::GlobalId(id), visibility, None); } - pub(crate) fn register_struct( + pub(crate) fn register_type( &mut self, - id: StructId, + id: TypeId, name: String, + location: Location, visibility: ItemVisibility, parent_module_id: ModuleId, ) { - self.add_definition_location(ReferenceId::Struct(id), Some(parent_module_id)); + self.add_definition_location(ReferenceId::Type(id), location, Some(parent_module_id)); self.register_name_for_auto_import(name, ModuleDefId::TypeId(id), visibility, None); } @@ -337,10 +343,11 @@ impl NodeInterner { &mut self, id: TraitId, name: String, + location: Location, visibility: ItemVisibility, parent_module_id: ModuleId, ) { - self.add_definition_location(ReferenceId::Trait(id), Some(parent_module_id)); + self.add_definition_location(ReferenceId::Trait(id), location, Some(parent_module_id)); self.register_name_for_auto_import(name, ModuleDefId::TraitId(id), visibility, None); } @@ -348,10 +355,11 @@ impl NodeInterner { &mut self, id: TypeAliasId, name: String, + location: Location, visibility: ItemVisibility, parent_module_id: ModuleId, ) { - self.add_definition_location(ReferenceId::Alias(id), Some(parent_module_id)); + self.add_definition_location(ReferenceId::Alias(id), location, Some(parent_module_id)); self.register_name_for_auto_import(name, ModuleDefId::TypeAliasId(id), visibility, None); } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs index d219e8f7c2d..621eb30e4f8 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -36,6 +36,7 @@ pub enum Expression { Index(Index), Cast(Cast), For(For), + Loop(Box), If(If), Tuple(Vec), ExtractTupleField(Box, usize), @@ -227,7 +228,9 @@ pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>; /// Represents how an Acir function should be inlined. /// This type is only relevant for ACIR functions as we do not inline any Brillig functions -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive( + Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord, +)] pub enum InlineType { /// The most basic entry point can expect all its functions to be inlined. /// All function calls are expected to be inlined into a single ACIR. diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs index b0c8744ea8f..de8e4f6f864 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -594,6 +594,9 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Comptime(_) => { unreachable!("comptime expression remaining in runtime code") } + HirExpression::EnumConstructor(constructor) => { + self.enum_constructor(constructor, expr)? + } }; Ok(expr) @@ -695,6 +698,10 @@ impl<'interner> Monomorphizer<'interner> { block, })) } + HirStatement::Loop(block) => { + let block = Box::new(self.expr(block)?); + Ok(ast::Expression::Loop(block)) + } HirStatement::Expression(expr) => self.expr(expr), HirStatement::Semi(expr) => { self.expr(expr).map(|expr| ast::Expression::Semi(Box::new(expr))) @@ -769,6 +776,48 @@ impl<'interner> Monomorphizer<'interner> { Ok(ast::Expression::Block(new_exprs)) } + /// For an enum like: + /// ``` + /// enum Foo { + /// A(i32, u32), + /// B(Field), + /// C + /// } + /// ``` + /// this will translate the call `Foo::A(1, 2)` into `(0, (1, 2), (0,), ())` where + /// the first field `0` is the tag value, the second is `A`, third is `B`, and fourth is `C`. + /// Each variant that isn't the desired variant has zeroed values filled in for its data. + fn enum_constructor( + &mut self, + constructor: HirEnumConstructorExpression, + id: node_interner::ExprId, + ) -> Result { + let location = self.interner.expr_location(&id); + let typ = self.interner.id_type(id); + let variants = unwrap_enum_type(&typ, location)?; + + // Fill in each field of the translated enum tuple. + // For most fields this will be simply `std::mem::zeroed::()`, + // but for the given variant we just pack all the arguments into a tuple for that field. + let mut fields = try_vecmap(variants.into_iter().enumerate(), |(i, (_, arg_types))| { + let fields = if i == constructor.variant_index { + try_vecmap(&constructor.arguments, |arg| self.expr(*arg)) + } else { + try_vecmap(arg_types, |typ| { + let typ = Self::convert_type(&typ, location)?; + Ok(self.zeroed_value_of_type(&typ, location)) + }) + }?; + Ok(ast::Expression::Tuple(fields)) + })?; + + let tag_value = FieldElement::from(constructor.variant_index); + let tag = ast::Literal::Integer(tag_value, false, ast::Type::Field, location); + fields.insert(0, ast::Expression::Literal(tag)); + + Ok(ast::Expression::Tuple(fields)) + } + fn block( &mut self, statement_ids: Vec, @@ -878,6 +927,7 @@ impl<'interner> Monomorphizer<'interner> { fn local_ident( &mut self, ident: &HirIdent, + typ: &Type, ) -> Result, MonomorphizationError> { let definition = self.interner.definition(ident.id); let name = definition.name.clone(); @@ -887,7 +937,7 @@ impl<'interner> Monomorphizer<'interner> { return Ok(None); }; - let typ = Self::convert_type(&self.interner.definition_type(ident.id), ident.location)?; + let typ = Self::convert_type(typ, ident.location)?; Ok(Some(ast::Ident { location: Some(ident.location), mutable, definition, name, typ })) } @@ -952,7 +1002,7 @@ impl<'interner> Monomorphizer<'interner> { DefinitionKind::Local(_) => match self.lookup_captured_expr(ident.id) { Some(expr) => expr, None => { - let Some(ident) = self.local_ident(&ident)? else { + let Some(ident) = self.local_ident(&ident, &typ)? else { let location = self.interner.id_location(expr_id); let message = "ICE: Variable not found during monomorphization"; return Err(MonomorphizationError::InternalError { location, message }); @@ -1155,16 +1205,30 @@ impl<'interner> Monomorphizer<'interner> { monomorphized_default } - HirType::Struct(def, args) => { - // Not all generic arguments may be used in a struct's fields so we have to check + HirType::DataType(def, args) => { + // Not all generic arguments may be used in a datatype's fields so we have to check // the arguments as well as the fields in case any need to be defaulted or are unbound. for arg in args { Self::check_type(arg, location)?; } - let fields = def.borrow().get_fields(args); - let fields = try_vecmap(fields, |(_, field)| Self::convert_type(&field, location))?; - ast::Type::Tuple(fields) + let def = def.borrow(); + if let Some(fields) = def.get_fields(args) { + let fields = + try_vecmap(fields, |(_, field)| Self::convert_type(&field, location))?; + ast::Type::Tuple(fields) + } else if let Some(variants) = def.get_variants(args) { + // Enums are represented as (tag, variant1, variant2, .., variantN) + let mut fields = vec![ast::Type::Field]; + for (_, variant_fields) in variants { + let variant_fields = + try_vecmap(variant_fields, |typ| Self::convert_type(&typ, location))?; + fields.push(ast::Type::Tuple(variant_fields)); + } + ast::Type::Tuple(fields) + } else { + unreachable!("Data type has no body") + } } HirType::Alias(def, args) => { @@ -1275,7 +1339,7 @@ impl<'interner> Monomorphizer<'interner> { Self::check_type(&default, location) } - HirType::Struct(_def, args) => { + HirType::DataType(_def, args) => { for arg in args { Self::check_type(arg, location)?; } @@ -1309,7 +1373,7 @@ impl<'interner> Monomorphizer<'interner> { } HirType::MutableReference(element) => Self::check_type(element, location), - HirType::InfixExpr(lhs, _, rhs) => { + HirType::InfixExpr(lhs, _, rhs, _) => { Self::check_type(lhs, location)?; Self::check_type(rhs, location) } @@ -1692,9 +1756,9 @@ impl<'interner> Monomorphizer<'interner> { fn lvalue(&mut self, lvalue: HirLValue) -> Result { let value = match lvalue { - HirLValue::Ident(ident, _) => match self.lookup_captured_lvalue(ident.id) { + HirLValue::Ident(ident, typ) => match self.lookup_captured_lvalue(ident.id) { Some(value) => value, - None => ast::LValue::Ident(self.local_ident(&ident)?.unwrap()), + None => ast::LValue::Ident(self.local_ident(&ident, &typ)?.unwrap()), }, HirLValue::MemberAccess { object, field_index, .. } => { let field_index = field_index.unwrap(); @@ -1827,7 +1891,8 @@ impl<'interner> Monomorphizer<'interner> { Ok(ast::Expression::ExtractTupleField(ident, field_index)) } None => { - let ident = self.local_ident(&capture.ident)?.unwrap(); + let typ = self.interner.definition_type(capture.ident.id); + let ident = self.local_ident(&capture.ident, &typ)?.unwrap(); Ok(ast::Expression::Ident(ident)) } } @@ -2129,18 +2194,35 @@ fn unwrap_struct_type( location: Location, ) -> Result, MonomorphizationError> { match typ.follow_bindings() { - HirType::Struct(def, args) => { + HirType::DataType(def, args) => { // Some of args might not be mentioned in fields, so we need to check that they aren't unbound. for arg in &args { Monomorphizer::check_type(arg, location)?; } - Ok(def.borrow().get_fields(&args)) + Ok(def.borrow().get_fields(&args).unwrap()) } other => unreachable!("unwrap_struct_type: expected struct, found {:?}", other), } } +fn unwrap_enum_type( + typ: &HirType, + location: Location, +) -> Result)>, MonomorphizationError> { + match typ.follow_bindings() { + HirType::DataType(def, args) => { + // Some of args might not be mentioned in fields, so we need to check that they aren't unbound. + for arg in &args { + Monomorphizer::check_type(arg, location)?; + } + + Ok(def.borrow().get_variants(&args).unwrap()) + } + other => unreachable!("unwrap_enum_type: expected enum, found {:?}", other), + } +} + pub fn perform_instantiation_bindings(bindings: &TypeBindings) { for (var, _kind, binding) in bindings.values() { var.force_bind(binding.clone()); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs index 25ac1336075..665f4dcd371 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/monomorphization/printer.rs @@ -49,6 +49,7 @@ impl AstPrinter { write!(f, " as {})", cast.r#type) } Expression::For(for_expr) => self.print_for(for_expr, f), + Expression::Loop(block) => self.print_loop(block, f), Expression::If(if_expr) => self.print_if(if_expr, f), Expression::Tuple(tuple) => self.print_tuple(tuple, f), Expression::ExtractTupleField(expr, index) => { @@ -209,6 +210,15 @@ impl AstPrinter { write!(f, "}}") } + fn print_loop(&mut self, block: &Expression, f: &mut Formatter) -> Result<(), std::fmt::Error> { + write!(f, "loop {{")?; + self.indent_level += 1; + self.print_expr_expect_block(block, f)?; + self.indent_level -= 1; + self.next_line(f)?; + write!(f, "}}") + } + fn print_if( &mut self, if_expr: &super::ast::If, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs index ae2cf224cbd..1ebcb6aff96 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs @@ -18,7 +18,7 @@ use crate::ast::{ use crate::graph::CrateId; use crate::hir::comptime; use crate::hir::def_collector::dc_crate::CompilationError; -use crate::hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait, UnresolvedTypeAlias}; +use crate::hir::def_collector::dc_crate::{UnresolvedTrait, UnresolvedTypeAlias}; use crate::hir::def_map::DefMaps; use crate::hir::def_map::{LocalModuleId, ModuleDefId, ModuleId}; use crate::hir::type_check::generics::TraitGenerics; @@ -32,7 +32,7 @@ use crate::hir_def::expr::HirIdent; use crate::hir_def::stmt::HirLetStatement; use crate::hir_def::traits::TraitImpl; use crate::hir_def::traits::{Trait, TraitConstraint}; -use crate::hir_def::types::{Kind, StructType, Type}; +use crate::hir_def::types::{DataType, Kind, Type}; use crate::hir_def::{ expr::HirExpression, function::{FuncMeta, HirFunction}, @@ -56,7 +56,7 @@ pub struct ModuleAttributes { pub visibility: ItemVisibility, } -type StructAttributes = Vec; +type TypeAttributes = Vec; /// The node interner is the central storage location of all nodes in Noir's Hir (the /// various node types can be found in hir_def). The interner is also used to collect @@ -106,14 +106,14 @@ pub struct NodeInterner { // Similar to `id_to_type` but maps definitions to their type definition_to_type: HashMap, - // Struct map. + // Struct and Enum map. // - // Each struct definition is possibly shared across multiple type nodes. + // Each type definition is possibly shared across multiple type nodes. // It is also mutated through the RefCell during name resolution to append // methods from impls to the type. - structs: HashMap>, + data_types: HashMap>, - struct_attributes: HashMap, + type_attributes: HashMap, // Maps TypeAliasId -> Shared // @@ -286,7 +286,7 @@ pub struct NodeInterner { /// ``` #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DependencyId { - Struct(StructId), + Struct(TypeId), Global(GlobalId), Function(FuncId), Alias(TypeAliasId), @@ -299,8 +299,9 @@ pub enum DependencyId { #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum ReferenceId { Module(ModuleId), - Struct(StructId), - StructMember(StructId, usize), + Type(TypeId), + StructMember(TypeId, usize), + EnumVariant(TypeId, usize), Trait(TraitId), Global(GlobalId), Function(FuncId), @@ -465,14 +466,14 @@ impl fmt::Display for FuncId { } #[derive(Debug, Eq, PartialEq, Hash, Copy, Clone, PartialOrd, Ord)] -pub struct StructId(ModuleId); +pub struct TypeId(ModuleId); -impl StructId { +impl TypeId { //dummy id for error reporting // This can be anything, as the program will ultimately fail // after resolution - pub fn dummy_id() -> StructId { - StructId(ModuleId { krate: CrateId::dummy_id(), local_id: LocalModuleId::dummy_id() }) + pub fn dummy_id() -> TypeId { + TypeId(ModuleId { krate: CrateId::dummy_id(), local_id: LocalModuleId::dummy_id() }) } pub fn module_id(self) -> ModuleId { @@ -652,8 +653,8 @@ impl Default for NodeInterner { definitions: vec![], id_to_type: HashMap::default(), definition_to_type: HashMap::default(), - structs: HashMap::default(), - struct_attributes: HashMap::default(), + data_types: HashMap::default(), + type_attributes: HashMap::default(), type_aliases: Vec::new(), traits: HashMap::default(), trait_implementations: HashMap::default(), @@ -747,25 +748,25 @@ impl NodeInterner { self.traits.insert(type_id, new_trait); } - pub fn new_struct( + /// Creates a new struct or enum type with no fields or variants. + #[allow(clippy::too_many_arguments)] + pub fn new_type( &mut self, - typ: &UnresolvedStruct, + name: Ident, + span: Span, + attributes: Vec, generics: Generics, krate: CrateId, local_id: LocalModuleId, file_id: FileId, - ) -> StructId { - let struct_id = StructId(ModuleId { krate, local_id }); - let name = typ.struct_def.name.clone(); - - // Fields will be filled in later - let no_fields = Vec::new(); + ) -> TypeId { + let type_id = TypeId(ModuleId { krate, local_id }); - let location = Location::new(typ.struct_def.span, file_id); - let new_struct = StructType::new(struct_id, name, location, no_fields, generics); - self.structs.insert(struct_id, Shared::new(new_struct)); - self.struct_attributes.insert(struct_id, typ.struct_def.attributes.clone()); - struct_id + let location = Location::new(span, file_id); + let new_type = DataType::new(type_id, name, location, generics); + self.data_types.insert(type_id, Shared::new(new_type)); + self.type_attributes.insert(type_id, attributes); + type_id } pub fn push_type_alias( @@ -791,8 +792,9 @@ impl NodeInterner { pub fn add_type_alias_ref(&mut self, type_id: TypeAliasId, location: Location) { self.type_alias_ref.push((type_id, location)); } - pub fn update_struct(&mut self, type_id: StructId, f: impl FnOnce(&mut StructType)) { - let mut value = self.structs.get_mut(&type_id).unwrap().borrow_mut(); + + pub fn update_type(&mut self, type_id: TypeId, f: impl FnOnce(&mut DataType)) { + let mut value = self.data_types.get_mut(&type_id).unwrap().borrow_mut(); f(&mut value); } @@ -801,12 +803,8 @@ impl NodeInterner { f(value); } - pub fn update_struct_attributes( - &mut self, - type_id: StructId, - f: impl FnOnce(&mut StructAttributes), - ) { - let value = self.struct_attributes.get_mut(&type_id).unwrap(); + pub fn update_type_attributes(&mut self, type_id: TypeId, f: impl FnOnce(&mut TypeAttributes)) { + let value = self.type_attributes.get_mut(&type_id).unwrap(); f(value); } @@ -956,7 +954,7 @@ impl NodeInterner { self.definitions.push(DefinitionInfo { name, mutable, comptime, kind, location }); if is_local { - self.add_definition_location(ReferenceId::Local(id), None); + self.add_definition_location(ReferenceId::Local(id), location, None); } id @@ -981,6 +979,7 @@ impl NodeInterner { module: ModuleId, location: Location, ) -> DefinitionId { + let name_location = Location::new(function.name.span(), location.file); let modifiers = FunctionModifiers { name: function.name.0.contents.clone(), visibility: function.visibility, @@ -988,14 +987,10 @@ impl NodeInterner { is_unconstrained: function.is_unconstrained, generic_count: function.generics.len(), is_comptime: function.is_comptime, - name_location: Location::new(function.name.span(), location.file), + name_location, }; let definition_id = self.push_function_definition(id, modifiers, module, location); - - // This needs to be done after pushing the definition since it will reference the - // location that was stored - self.add_definition_location(ReferenceId::Function(id), Some(module)); - + self.add_definition_location(ReferenceId::Function(id), name_location, Some(module)); definition_id } @@ -1096,8 +1091,8 @@ impl NodeInterner { &self.function_modifiers[func_id].attributes } - pub fn struct_attributes(&self, struct_id: &StructId) -> &StructAttributes { - &self.struct_attributes[struct_id] + pub fn type_attributes(&self, struct_id: &TypeId) -> &TypeAttributes { + &self.type_attributes[struct_id] } pub fn add_module_attributes(&mut self, module_id: ModuleId, attributes: ModuleAttributes) { @@ -1213,8 +1208,8 @@ impl NodeInterner { self.id_to_location.insert(id.into(), Location::new(span, file)); } - pub fn get_struct(&self, id: StructId) -> Shared { - self.structs[&id].clone() + pub fn get_type(&self, id: TypeId) -> Shared { + self.data_types[&id].clone() } pub fn get_type_methods(&self, typ: &Type) -> Option<&HashMap> { @@ -1387,7 +1382,7 @@ impl NodeInterner { unreachable!("Cannot add a method to the unsupported type '{}'", self_type) }); - if trait_id.is_none() && matches!(self_type, Type::Struct(..)) { + if trait_id.is_none() && matches!(self_type, Type::DataType(..)) { if let Some(existing) = self.lookup_direct_method(self_type, &method_name, true) { return Some(existing); @@ -1722,8 +1717,18 @@ impl NodeInterner { let instantiated_object_type = object_type.substitute(&substitutions); let trait_generics = &trait_impl.borrow().trait_generics; + + // Replace any associated types with fresh type variables so that we match + // any existing impl regardless of associated types if one already exists. + // E.g. if we already have an `impl Foo for Baz`, we should + // reject `impl Foo for Baz` if it were to be added. let associated_types = self.get_associated_types_for_impl(impl_id); + let associated_types = vecmap(associated_types, |named| { + let typ = self.next_type_variable(); + NamedType { name: named.name.clone(), typ } + }); + // Ignoring overlapping `TraitImplKind::Assumed` impls here is perfectly fine. // It should never happen since impls are defined at global scope, but even // if they were, we should never prevent defining a new impl because a 'where' @@ -1732,7 +1737,7 @@ impl NodeInterner { &instantiated_object_type, trait_id, trait_generics, - associated_types, + &associated_types, ) { let existing_impl = self.get_trait_implementation(existing); let existing_impl = existing_impl.borrow(); @@ -1970,7 +1975,7 @@ impl NodeInterner { /// Register that `dependent` depends on `dependency`. /// This is usually because `dependent` refers to `dependency` in one of its struct fields. - pub fn add_type_dependency(&mut self, dependent: DependencyId, dependency: StructId) { + pub fn add_type_dependency(&mut self, dependent: DependencyId, dependency: TypeId) { self.add_dependency(dependent, DependencyId::Struct(dependency)); } @@ -2023,7 +2028,7 @@ impl NodeInterner { for (i, index) in scc.iter().enumerate() { match self.dependency_graph[*index] { DependencyId::Struct(struct_id) => { - let struct_type = self.get_struct(struct_id); + let struct_type = self.get_type(struct_id); let struct_type = struct_type.borrow(); push_error(struct_type.name.to_string(), &scc, i, struct_type.location); break; @@ -2070,7 +2075,7 @@ impl NodeInterner { /// element at the given start index. fn get_cycle_error_string(&self, scc: &[PetGraphIndex], start_index: usize) -> String { let index_to_string = |index: PetGraphIndex| match self.dependency_graph[index] { - DependencyId::Struct(id) => Cow::Owned(self.get_struct(id).borrow().name.to_string()), + DependencyId::Struct(id) => Cow::Owned(self.get_type(id).borrow().name.to_string()), DependencyId::Function(id) => Cow::Borrowed(self.function_name(&id)), DependencyId::Alias(id) => { Cow::Owned(self.get_type_alias(id).borrow().name.to_string()) @@ -2412,7 +2417,7 @@ enum TypeMethodKey { Function, Generic, Quoted(QuotedType), - Struct(StructId), + Struct(TypeId), } fn get_type_method_key(typ: &Type) -> Option { @@ -2440,7 +2445,7 @@ fn get_type_method_key(typ: &Type) -> Option { Type::Quoted(quoted) => Some(Quoted(*quoted)), Type::MutableReference(element) => get_type_method_key(element), Type::Alias(alias, _) => get_type_method_key(&alias.borrow().typ), - Type::Struct(struct_type, _) => Some(Struct(struct_type.borrow().id)), + Type::DataType(struct_type, _) => Some(Struct(struct_type.borrow().id)), // We do not support adding methods to these types Type::Forall(_, _) diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs index f44f109e1ce..508ed33857e 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/errors.rs @@ -83,8 +83,8 @@ pub enum ParserErrorReason { "Multiple primary attributes found. Only one function attribute is allowed per function" )] MultipleFunctionAttributesFound, - #[error("A function attribute cannot be placed on a struct")] - NoFunctionAttributesAllowedOnStruct, + #[error("A function attribute cannot be placed on a struct or enum")] + NoFunctionAttributesAllowedOnType, #[error("Assert statements can only accept string literals")] AssertMessageNotString, #[error("Integer bit size {0} isn't supported")] diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/mod.rs index 17c156476a7..c433adbfdfb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/mod.rs @@ -13,7 +13,8 @@ mod parser; use crate::ast::{ Documented, Ident, ImportStatement, ItemVisibility, LetStatement, ModuleDeclaration, - NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, TypeImpl, UseTree, + NoirEnumeration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, TypeImpl, + UseTree, }; use crate::token::SecondaryAttribute; @@ -26,7 +27,8 @@ pub use parser::{parse_program, Parser, StatementOrExpressionOrLValue}; pub struct SortedModule { pub imports: Vec, pub functions: Vec>, - pub types: Vec>, + pub structs: Vec>, + pub enums: Vec>, pub traits: Vec>, pub trait_impls: Vec, pub impls: Vec, @@ -57,7 +59,7 @@ impl std::fmt::Display for SortedModule { write!(f, "{global_const}")?; } - for type_ in &self.types { + for type_ in &self.structs { write!(f, "{type_}")?; } @@ -96,7 +98,8 @@ impl ParsedModule { match item.kind { ItemKind::Import(import, visibility) => module.push_import(import, visibility), ItemKind::Function(func) => module.push_function(func, item.doc_comments), - ItemKind::Struct(typ) => module.push_type(typ, item.doc_comments), + ItemKind::Struct(typ) => module.push_struct(typ, item.doc_comments), + ItemKind::Enum(typ) => module.push_enum(typ, item.doc_comments), ItemKind::Trait(noir_trait) => module.push_trait(noir_trait, item.doc_comments), ItemKind::TraitImpl(trait_impl) => module.push_trait_impl(trait_impl), ItemKind::Impl(r#impl) => module.push_impl(r#impl), @@ -134,6 +137,7 @@ pub enum ItemKind { Import(UseTree, ItemVisibility), Function(NoirFunction), Struct(NoirStruct), + Enum(NoirEnumeration), Trait(NoirTrait), TraitImpl(NoirTraitImpl), Impl(TypeImpl), @@ -147,6 +151,7 @@ pub enum ItemKind { impl std::fmt::Display for ItemKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + ItemKind::Enum(e) => e.fmt(f), ItemKind::Function(fun) => fun.fmt(f), ItemKind::ModuleDecl(m) => m.fmt(f), ItemKind::Import(tree, visibility) => { @@ -222,8 +227,12 @@ impl SortedModule { self.functions.push(Documented::new(func, doc_comments)); } - fn push_type(&mut self, typ: NoirStruct, doc_comments: Vec) { - self.types.push(Documented::new(typ, doc_comments)); + fn push_struct(&mut self, typ: NoirStruct, doc_comments: Vec) { + self.structs.push(Documented::new(typ, doc_comments)); + } + + fn push_enum(&mut self, typ: NoirEnumeration, doc_comments: Vec) { + self.enums.push(Documented::new(typ, doc_comments)); } fn push_trait(&mut self, noir_trait: NoirTrait, doc_comments: Vec) { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs index 05f8ae3c2bb..e554248fb03 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs @@ -13,6 +13,7 @@ use super::{labels::ParsingRuleLabel, ParsedModule, ParserError, ParserErrorReas mod arguments; mod attributes; mod doc_comments; +mod enums; mod expression; mod function; mod generics; @@ -191,14 +192,10 @@ impl<'a> Parser<'a> { fn read_token_internal(&mut self) -> SpannedToken { loop { - let token = self.tokens.next(); - if let Some(token) = token { - match token { - Ok(token) => return token, - Err(lexer_error) => self.errors.push(lexer_error.into()), - } - } else { - return eof_spanned_token(); + match self.tokens.next() { + Some(Ok(token)) => return token, + Some(Err(lexer_error)) => self.errors.push(lexer_error.into()), + None => return eof_spanned_token(), } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/attributes.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/attributes.rs index 12cb37edb4b..e32e7d3cb23 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/attributes.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/attributes.rs @@ -92,7 +92,7 @@ impl<'a> Parser<'a> { .into_iter() .filter_map(|(attribute, span)| match attribute { Attribute::Function(..) => { - self.push_error(ParserErrorReason::NoFunctionAttributesAllowedOnStruct, span); + self.push_error(ParserErrorReason::NoFunctionAttributesAllowedOnType, span); None } Attribute::Secondary(attr) => Some(attr), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/enums.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/enums.rs new file mode 100644 index 00000000000..f95c0f8f72b --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/enums.rs @@ -0,0 +1,265 @@ +use noirc_errors::Span; + +use crate::{ + ast::{Documented, EnumVariant, Ident, ItemVisibility, NoirEnumeration, UnresolvedGenerics}, + parser::ParserErrorReason, + token::{Attribute, SecondaryAttribute, Token}, +}; + +use super::{ + parse_many::{separated_by_comma_until_right_brace, separated_by_comma_until_right_paren}, + Parser, +}; + +impl<'a> Parser<'a> { + /// Enum = 'enum' identifier Generics '{' EnumVariant* '}' + /// + /// EnumField = OuterDocComments identifier ':' Type + pub(crate) fn parse_enum( + &mut self, + attributes: Vec<(Attribute, Span)>, + visibility: ItemVisibility, + start_span: Span, + ) -> NoirEnumeration { + let attributes = self.validate_secondary_attributes(attributes); + + self.push_error(ParserErrorReason::ExperimentalFeature("Enums"), start_span); + + let Some(name) = self.eat_ident() else { + self.expected_identifier(); + return self.empty_enum( + Ident::default(), + attributes, + visibility, + Vec::new(), + start_span, + ); + }; + + let generics = self.parse_generics(); + + if !self.eat_left_brace() { + self.expected_token(Token::LeftBrace); + return self.empty_enum(name, attributes, visibility, generics, start_span); + } + + let comma_separated = separated_by_comma_until_right_brace(); + let variants = self.parse_many("enum variants", comma_separated, Self::parse_enum_variant); + + NoirEnumeration { + name, + attributes, + visibility, + generics, + variants, + span: self.span_since(start_span), + } + } + + fn parse_enum_variant(&mut self) -> Option> { + let mut doc_comments; + let name; + + // Loop until we find an identifier, skipping anything that's not one + loop { + let doc_comments_start_span = self.current_token_span; + doc_comments = self.parse_outer_doc_comments(); + + if let Some(ident) = self.eat_ident() { + name = ident; + break; + } + + if !doc_comments.is_empty() { + self.push_error( + ParserErrorReason::DocCommentDoesNotDocumentAnything, + self.span_since(doc_comments_start_span), + ); + } + + // Though we do have to stop at EOF + if self.at_eof() { + self.expected_token(Token::RightBrace); + return None; + } + + // Or if we find a right brace + if self.at(Token::RightBrace) { + return None; + } + + self.expected_identifier(); + self.bump(); + } + + let mut parameters = Vec::new(); + + if self.eat_left_paren() { + let comma_separated = separated_by_comma_until_right_paren(); + parameters = self.parse_many("variant parameters", comma_separated, Self::parse_type); + } + + Some(Documented::new(EnumVariant { name, parameters }, doc_comments)) + } + + fn empty_enum( + &self, + name: Ident, + attributes: Vec, + visibility: ItemVisibility, + generics: UnresolvedGenerics, + start_span: Span, + ) -> NoirEnumeration { + NoirEnumeration { + name, + attributes, + visibility, + generics, + variants: Vec::new(), + span: self.span_since(start_span), + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + ast::{IntegerBitSize, NoirEnumeration, Signedness, UnresolvedGeneric, UnresolvedTypeData}, + parser::{ + parser::{ + parse_program, + tests::{expect_no_errors, get_source_with_error_span}, + }, + ItemKind, ParserErrorReason, + }, + }; + + fn parse_enum_no_errors(src: &str) -> NoirEnumeration { + let (mut module, errors) = parse_program(src); + expect_no_errors(&errors); + assert_eq!(module.items.len(), 1); + let item = module.items.remove(0); + let ItemKind::Enum(noir_enum) = item.kind else { + panic!("Expected enum"); + }; + noir_enum + } + + #[test] + fn parse_empty_enum() { + let src = "enum Foo {}"; + let noir_enum = parse_enum_no_errors(src); + assert_eq!("Foo", noir_enum.name.to_string()); + assert!(noir_enum.variants.is_empty()); + assert!(noir_enum.generics.is_empty()); + } + + #[test] + fn parse_empty_enum_with_generics() { + let src = "enum Foo {}"; + let mut noir_enum = parse_enum_no_errors(src); + assert_eq!("Foo", noir_enum.name.to_string()); + assert!(noir_enum.variants.is_empty()); + assert_eq!(noir_enum.generics.len(), 2); + + let generic = noir_enum.generics.remove(0); + let UnresolvedGeneric::Variable(ident) = generic else { + panic!("Expected generic variable"); + }; + assert_eq!("A", ident.to_string()); + + let generic = noir_enum.generics.remove(0); + let UnresolvedGeneric::Numeric { ident, typ } = generic else { + panic!("Expected generic numeric"); + }; + assert_eq!("B", ident.to_string()); + assert_eq!( + typ.typ, + UnresolvedTypeData::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo) + ); + } + + #[test] + fn parse_enum_with_variants() { + let src = "enum Foo { X(i32), y(Field, u32), Z }"; + let mut noir_enum = parse_enum_no_errors(src); + assert_eq!("Foo", noir_enum.name.to_string()); + assert_eq!(noir_enum.variants.len(), 3); + + let variant = noir_enum.variants.remove(0).item; + assert_eq!("X", variant.name.to_string()); + assert!(matches!( + variant.parameters[0].typ, + UnresolvedTypeData::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo) + )); + + let variant = noir_enum.variants.remove(0).item; + assert_eq!("y", variant.name.to_string()); + assert!(matches!(variant.parameters[0].typ, UnresolvedTypeData::FieldElement)); + assert!(matches!(variant.parameters[1].typ, UnresolvedTypeData::Integer(..))); + + let variant = noir_enum.variants.remove(0).item; + assert_eq!("Z", variant.name.to_string()); + assert_eq!(variant.parameters.len(), 0); + } + + #[test] + fn parse_empty_enum_with_doc_comments() { + let src = "/// Hello\nenum Foo {}"; + let (module, errors) = parse_program(src); + expect_no_errors(&errors); + assert_eq!(module.items.len(), 1); + let item = &module.items[0]; + assert_eq!(item.doc_comments.len(), 1); + let ItemKind::Enum(noir_enum) = &item.kind else { + panic!("Expected enum"); + }; + assert_eq!("Foo", noir_enum.name.to_string()); + } + + #[test] + fn parse_unclosed_enum() { + let src = "enum Foo {"; + let (module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(module.items.len(), 1); + let item = &module.items[0]; + let ItemKind::Enum(noir_enum) = &item.kind else { + panic!("Expected enum"); + }; + assert_eq!("Foo", noir_enum.name.to_string()); + } + + #[test] + fn parse_error_no_function_attributes_allowed_on_enum() { + let src = " + #[test] enum Foo {} + ^^^^^^^ + "; + let (src, _) = get_source_with_error_span(src); + let (_, errors) = parse_program(&src); + let reason = errors[0].reason().unwrap(); + assert!(matches!(reason, ParserErrorReason::NoFunctionAttributesAllowedOnType)); + } + + #[test] + fn recovers_on_non_field() { + let src = " + enum Foo { 42 X(i32) } + ^^ + "; + let (src, _) = get_source_with_error_span(src); + let (module, errors) = parse_program(&src); + + assert_eq!(module.items.len(), 1); + let item = &module.items[0]; + let ItemKind::Enum(noir_enum) = &item.kind else { + panic!("Expected enum"); + }; + assert_eq!("Foo", noir_enum.name.to_string()); + assert_eq!(noir_enum.variants.len(), 1); + + let error = &errors[1]; + assert_eq!(error.to_string(), "Expected an identifier but found '42'"); + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs index ce712b559d8..d928d8e82d3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/item.rs @@ -107,6 +107,7 @@ impl<'a> Parser<'a> { /// ( Use /// | ModOrContract /// | Struct + /// | Enum /// | Impl /// | Trait /// | Global @@ -148,6 +149,16 @@ impl<'a> Parser<'a> { ))]; } + if self.eat_keyword(Keyword::Enum) { + self.comptime_mutable_and_unconstrained_not_applicable(modifiers); + + return vec![ItemKind::Enum(self.parse_enum( + attributes, + modifiers.visibility, + start_span, + ))]; + } + if self.eat_keyword(Keyword::Impl) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/module.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/module.rs index da733168099..1bc3d7b5beb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/module.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/module.rs @@ -25,6 +25,7 @@ impl<'a> Parser<'a> { visibility, ident: Ident::default(), outer_attributes, + has_semicolon: false, }); }; @@ -41,10 +42,16 @@ impl<'a> Parser<'a> { is_contract, }) } else { - if !self.eat_semicolons() { + let has_semicolon = self.eat_semicolons(); + if !has_semicolon { self.expected_token(Token::Semicolon); } - ItemKind::ModuleDecl(ModuleDeclaration { visibility, ident, outer_attributes }) + ItemKind::ModuleDecl(ModuleDeclaration { + visibility, + ident, + outer_attributes, + has_semicolon, + }) } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs index 465e48e3bad..005216b1deb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/statement.rs @@ -157,8 +157,8 @@ impl<'a> Parser<'a> { return Some(StatementKind::For(for_loop)); } - if let Some(block) = self.parse_loop() { - return Some(StatementKind::Loop(block)); + if let Some((block, span)) = self.parse_loop() { + return Some(StatementKind::Loop(block, span)); } if let Some(kind) = self.parse_if_expr() { @@ -293,11 +293,14 @@ impl<'a> Parser<'a> { } /// LoopStatement = 'loop' Block - fn parse_loop(&mut self) -> Option { + fn parse_loop(&mut self) -> Option<(Expression, Span)> { + let start_span = self.current_token_span; if !self.eat_keyword(Keyword::Loop) { return None; } + self.push_error(ParserErrorReason::ExperimentalFeature("loops"), start_span); + let block_start_span = self.current_token_span; let block = if let Some(block) = self.parse_block() { Expression { @@ -309,7 +312,7 @@ impl<'a> Parser<'a> { Expression { kind: ExpressionKind::Error, span: self.span_since(block_start_span) } }; - Some(block) + Some((block, start_span)) } /// ForRange @@ -819,21 +822,25 @@ mod tests { #[test] fn parses_empty_loop() { let src = "loop { }"; - let statement = parse_statement_no_errors(src); - let StatementKind::Loop(block) = statement.kind else { + let mut parser = Parser::for_str(src); + let statement = parser.parse_statement_or_error(); + let StatementKind::Loop(block, span) = statement.kind else { panic!("Expected loop"); }; let ExpressionKind::Block(block) = block.kind else { panic!("Expected block"); }; assert!(block.statements.is_empty()); + assert_eq!(span.start(), 0); + assert_eq!(span.end(), 4); } #[test] fn parses_loop_with_statements() { let src = "loop { 1; 2 }"; - let statement = parse_statement_no_errors(src); - let StatementKind::Loop(block) = statement.kind else { + let mut parser = Parser::for_str(src); + let statement = parser.parse_statement_or_error(); + let StatementKind::Loop(block, _) = statement.kind else { panic!("Expected loop"); }; let ExpressionKind::Block(block) = block.kind else { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/structs.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/structs.rs index da8ac64e021..b066565e680 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/structs.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/structs.rs @@ -251,7 +251,7 @@ mod tests { let (src, span) = get_source_with_error_span(src); let (_, errors) = parse_program(&src); let reason = get_single_error_reason(&errors, span); - assert!(matches!(reason, ParserErrorReason::NoFunctionAttributesAllowedOnStruct)); + assert!(matches!(reason, ParserErrorReason::NoFunctionAttributesAllowedOnType)); } #[test] diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/tests.rs index ea8b1fc638d..7308458e948 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/tests.rs @@ -44,7 +44,7 @@ pub(super) fn get_single_error_reason( } pub(super) fn expect_no_errors(errors: &[ParserError]) { - if errors.is_empty() { + if errors.is_empty() || errors.iter().all(|error| error.is_warning()) { return; } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/types.rs index f20483f0a7b..210bfe16bef 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser/types.rs @@ -208,6 +208,9 @@ impl<'a> Parser<'a> { if self.eat_keyword(Keyword::StructDefinition) { return Some(UnresolvedTypeData::Quoted(QuotedType::StructDefinition)); } + if self.eat_keyword(Keyword::EnumDefinition) { + return Some(UnresolvedTypeData::Quoted(QuotedType::EnumDefinition)); + } if self.eat_keyword(Keyword::TraitConstraint) { return Some(UnresolvedTypeData::Quoted(QuotedType::TraitConstraint)); } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/resolve_locations.rs b/noir/noir-repo/compiler/noirc_frontend/src/resolve_locations.rs index b9e86bf0ef7..4daf088a2f1 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/resolve_locations.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/resolve_locations.rs @@ -93,7 +93,7 @@ impl NodeInterner { fn get_type_location_from_index(&self, index: impl Into) -> Option { match self.id_type(index.into()) { - Type::Struct(struct_type, _) => Some(struct_type.borrow().location), + Type::DataType(struct_type, _) => Some(struct_type.borrow().location), _ => None, } } @@ -150,12 +150,12 @@ impl NodeInterner { let expr_rhs = &expr_member_access.rhs; let lhs_self_struct = match self.id_type(expr_lhs) { - Type::Struct(struct_type, _) => struct_type, + Type::DataType(struct_type, _) => struct_type, _ => return None, }; let struct_type = lhs_self_struct.borrow(); - let field_names = struct_type.field_names(); + let field_names = struct_type.field_names()?; field_names.iter().find(|field_name| field_name.0 == expr_rhs.0).map(|found_field_name| { Location::new(found_field_name.span(), struct_type.location.file) @@ -217,7 +217,7 @@ impl NodeInterner { .iter() .find(|(_typ, type_ref_location)| type_ref_location.contains(&location)) .and_then(|(typ, _)| match typ { - Type::Struct(struct_typ, _) => Some(struct_typ.borrow().location), + Type::DataType(struct_typ, _) => Some(struct_typ.borrow().location), _ => None, }) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs index 637b15e7197..6acb3b4b59e 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs @@ -903,6 +903,7 @@ fn find_lambda_captures(stmts: &[StmtId], interner: &NodeInterner, result: &mut HirStatement::Constrain(constr_stmt) => constr_stmt.0, HirStatement::Semi(semi_expr) => semi_expr, HirStatement::For(for_loop) => for_loop.block, + HirStatement::Loop(block) => block, HirStatement::Error => panic!("Invalid HirStatement!"), HirStatement::Break => panic!("Unexpected break"), HirStatement::Continue => panic!("Unexpected continue"), @@ -2817,12 +2818,13 @@ fn duplicate_struct_field() { let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - let CompilationError::DefinitionError(DefCollectorErrorKind::DuplicateField { + let CompilationError::DefinitionError(DefCollectorErrorKind::Duplicate { + typ: _, first_def, second_def, }) = &errors[0].0 else { - panic!("Expected a duplicate field error, got {:?}", errors[0].0); + panic!("Expected a 'duplicate' error, got {:?}", errors[0].0); }; assert_eq!(first_def.to_string(), "x"); @@ -3015,13 +3017,13 @@ fn do_not_eagerly_error_on_cast_on_type_variable() { #[test] fn error_on_cast_over_type_variable() { let src = r#" - pub fn foo(x: T, f: fn(T) -> U) -> U { + pub fn foo(f: fn(T) -> U, x: T, ) -> U { f(x) } fn main() { let x = "a"; - let _: Field = foo(x, |x| x as Field); + let _: Field = foo(|x| x as Field, x); } "#; @@ -3456,6 +3458,11 @@ fn arithmetic_generics_rounding_fail_on_struct() { #[test] fn unconditional_recursion_fail() { + // These examples are self recursive top level functions, which would actually + // not be inlined in the SSA (there is nothing to inline into but self), so it + // wouldn't panic due to infinite recursion, but the errors asserted here + // come from the compilation checks, which does static analysis to catch the + // problem before it even has a chance to cause a panic. let srcs = vec![ r#" fn main() { @@ -3978,3 +3985,196 @@ fn checks_visibility_of_trait_related_to_trait_impl_on_method_call() { "#; assert_no_errors(src); } + +#[test] +fn infers_lambda_argument_from_method_call_function_type() { + let src = r#" + struct Foo { + value: Field, + } + + impl Foo { + fn foo(self) -> Field { + self.value + } + } + + struct Box { + value: T, + } + + impl Box { + fn map(self, f: fn(T) -> U) -> Box { + Box { value: f(self.value) } + } + } + + fn main() { + let box = Box { value: Foo { value: 1 } }; + let _ = box.map(|foo| foo.foo()); + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_call_function_type() { + let src = r#" + struct Foo { + value: Field, + } + + fn call(f: fn(Foo) -> Field) -> Field { + f(Foo { value: 1 }) + } + + fn main() { + let _ = call(|foo| foo.value); + } + "#; + assert_no_errors(src); +} + +#[test] +fn infers_lambda_argument_from_call_function_type_in_generic_call() { + let src = r#" + struct Foo { + value: Field, + } + + fn call(t: T, f: fn(T) -> Field) -> Field { + f(t) + } + + fn main() { + let _ = call(Foo { value: 1 }, |foo| foo.value); + } + "#; + assert_no_errors(src); +} + +#[test] +fn regression_7088() { + // A test for code that initially broke when implementing inferring + // lambda parameter types from the function type related to the call + // the lambda is in (PR #7088). + let src = r#" + struct U60Repr {} + + impl U60Repr { + fn new(_: [Field; N * NumFieldSegments]) -> Self { + U60Repr {} + } + } + + fn main() { + let input: [Field; 6] = [0; 6]; + let _: U60Repr<3, 6> = U60Repr::new(input); + } + "#; + assert_no_errors(src); +} + +#[test] +fn error_with_duplicate_enum_variant() { + let src = r#" + enum Foo { + Bar(i32), + Bar(u8), + } + + fn main() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 2); + assert!(matches!( + &errors[0].0, + CompilationError::DefinitionError(DefCollectorErrorKind::Duplicate { .. }) + )); + assert!(matches!( + &errors[1].0, + CompilationError::ResolverError(ResolverError::UnusedItem { .. }) + )); +} + +#[test] +fn errors_on_empty_loop_no_break() { + let src = r#" + fn main() { + /// Safety: test + unsafe { + foo() + } + } + + unconstrained fn foo() { + loop {} + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + &errors[0].0, + CompilationError::ResolverError(ResolverError::LoopWithoutBreak { .. }) + )); +} + +#[test] +fn errors_on_loop_without_break() { + let src = r#" + fn main() { + /// Safety: test + unsafe { + foo() + } + } + + unconstrained fn foo() { + let mut x = 1; + loop { + x += 1; + bar(x); + } + } + + fn bar(_: Field) {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + &errors[0].0, + CompilationError::ResolverError(ResolverError::LoopWithoutBreak { .. }) + )); +} + +#[test] +fn errors_on_loop_without_break_with_nested_loop() { + let src = r#" + fn main() { + /// Safety: test + unsafe { + foo() + } + } + + unconstrained fn foo() { + let mut x = 1; + loop { + x += 1; + bar(x); + loop { + x += 2; + break; + } + } + } + + fn bar(_: Field) {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + &errors[0].0, + CompilationError::ResolverError(ResolverError::LoopWithoutBreak { .. }) + )); +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests/metaprogramming.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests/metaprogramming.rs index 8256744e18f..b42342fa47d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests/metaprogramming.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests/metaprogramming.rs @@ -3,6 +3,7 @@ use noirc_errors::Spanned; use crate::{ ast::Ident, hir::{ + comptime::InterpreterError, def_collector::{ dc_crate::CompilationError, errors::{DefCollectorErrorKind, DuplicateType}, @@ -26,6 +27,26 @@ fn comptime_let() { assert_eq!(errors.len(), 0); } +#[test] +fn comptime_code_rejects_dynamic_variable() { + let src = r#"fn main(x: Field) { + comptime let my_var = (x - x) + 2; + assert_eq(my_var, 2); + }"#; + let errors = get_program_errors(src); + + assert_eq!(errors.len(), 1); + match &errors[0].0 { + CompilationError::InterpreterError(InterpreterError::NonComptimeVarReferenced { + name, + .. + }) => { + assert_eq!(name, "x"); + } + _ => panic!("expected an InterpreterError"), + } +} + #[test] fn comptime_type_in_runtime_code() { let source = "pub fn foo(_f: FunctionDefinition) {}"; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs index 5e42d8901fe..7f252b556c2 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests/traits.rs @@ -1,4 +1,5 @@ use crate::hir::def_collector::dc_crate::CompilationError; +use crate::hir::def_collector::errors::DefCollectorErrorKind; use crate::hir::resolution::errors::ResolverError; use crate::hir::resolution::import::PathResolutionError; use crate::hir::type_check::TypeCheckError; @@ -1237,6 +1238,64 @@ fn warns_if_trait_is_not_in_scope_for_generic_function_call_and_there_is_only_on assert_eq!(trait_name, "private_mod::Foo"); } +#[test] +fn error_on_duplicate_impl_with_associated_type() { + let src = r#" + trait Foo { + type Bar; + } + + impl Foo for i32 { + type Bar = u32; + } + + impl Foo for i32 { + type Bar = u8; + } + + fn main() {} + "#; + + // Expect "Impl for type `i32` overlaps with existing impl" + // and "Previous impl defined here" + let errors = get_program_errors(src); + assert_eq!(errors.len(), 2); + + use CompilationError::DefinitionError; + use DefCollectorErrorKind::*; + assert!(matches!(&errors[0].0, DefinitionError(OverlappingImpl { .. }))); + assert!(matches!(&errors[1].0, DefinitionError(OverlappingImplNote { .. }))); +} + +#[test] +fn error_on_duplicate_impl_with_associated_constant() { + let src = r#" + trait Foo { + let Bar: u32; + } + + impl Foo for i32 { + let Bar = 5; + } + + impl Foo for i32 { + let Bar = 6; + } + + fn main() {} + "#; + + // Expect "Impl for type `i32` overlaps with existing impl" + // and "Previous impl defined here" + let errors = get_program_errors(src); + assert_eq!(errors.len(), 2); + + use CompilationError::DefinitionError; + use DefCollectorErrorKind::*; + assert!(matches!(&errors[0].0, DefinitionError(OverlappingImpl { .. }))); + assert!(matches!(&errors[1].0, DefinitionError(OverlappingImplNote { .. }))); +} + // See https://github.com/noir-lang/noir/issues/6530 #[test] fn regression_6530() { @@ -1244,11 +1303,11 @@ fn regression_6530() { pub trait From { fn from(input: T) -> Self; } - + pub trait Into { fn into(self) -> T; } - + impl Into for U where T: From, @@ -1257,23 +1316,23 @@ fn regression_6530() { T::from(self) } } - + struct Foo { inner: Field, } - + impl Into for Foo { fn into(self) -> Field { self.inner } } - + fn main() { let foo = Foo { inner: 0 }; - + // This works: let _: Field = Into::::into(foo); - + // This was failing with 'No matching impl': let _: Field = foo.into(); } @@ -1282,9 +1341,7 @@ fn regression_6530() { assert_eq!(errors.len(), 0); } -// See https://github.com/noir-lang/noir/issues/7090 #[test] -#[should_panic] fn calls_trait_method_using_struct_name_when_multiple_impls_exist() { let src = r#" trait From2 { @@ -1308,3 +1365,32 @@ fn calls_trait_method_using_struct_name_when_multiple_impls_exist() { "#; assert_no_errors(src); } + +#[test] +fn calls_trait_method_using_struct_name_when_multiple_impls_exist_and_errors_turbofish() { + let src = r#" + trait From2 { + fn from2(input: T) -> Self; + } + struct U60Repr {} + impl From2<[Field; 3]> for U60Repr { + fn from2(_: [Field; 3]) -> Self { + U60Repr {} + } + } + impl From2 for U60Repr { + fn from2(_: Field) -> Self { + U60Repr {} + } + } + fn main() { + let _ = U60Repr::::from2([1, 2, 3]); + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + errors[0].0, + CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) + )); +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/usage_tracker.rs b/noir/noir-repo/compiler/noirc_frontend/src/usage_tracker.rs index 6987358ddb7..ea4919096c0 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/usage_tracker.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/usage_tracker.rs @@ -3,14 +3,15 @@ use std::collections::HashMap; use crate::{ ast::{Ident, ItemVisibility}, hir::def_map::ModuleId, - node_interner::{FuncId, GlobalId, StructId, TraitId, TypeAliasId}, + node_interner::{FuncId, GlobalId, TraitId, TypeAliasId, TypeId}, }; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum UnusedItem { Import, Function(FuncId), - Struct(StructId), + Struct(TypeId), + Enum(TypeId), Trait(TraitId), TypeAlias(TypeAliasId), Global(GlobalId), @@ -22,6 +23,7 @@ impl UnusedItem { UnusedItem::Import => "import", UnusedItem::Function(_) => "function", UnusedItem::Struct(_) => "struct", + UnusedItem::Enum(_) => "enum", UnusedItem::Trait(_) => "trait", UnusedItem::TypeAlias(_) => "type alias", UnusedItem::Global(_) => "global", diff --git a/noir/noir-repo/compiler/noirc_printable_type/src/lib.rs b/noir/noir-repo/compiler/noirc_printable_type/src/lib.rs index eb74d2470fb..1831180d0ab 100644 --- a/noir/noir-repo/compiler/noirc_printable_type/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_printable_type/src/lib.rs @@ -36,6 +36,10 @@ pub enum PrintableType { name: String, fields: Vec<(String, PrintableType)>, }, + Enum { + name: String, + variants: Vec<(String, Vec)>, + }, String { length: u32, }, diff --git a/noir/noir-repo/cspell.json b/noir/noir-repo/cspell.json index ed9f7427c6f..1174a56dd33 100644 --- a/noir/noir-repo/cspell.json +++ b/noir/noir-repo/cspell.json @@ -32,9 +32,11 @@ "boilerplates", "bridgekeeper", "brillig", + "brillig_", "bunx", "bytecount", "cachix", + "callees", "callsite", "callsites", "callstack", @@ -204,6 +206,7 @@ "Secpr", "signedness", "signorecello", + "signum", "smallvec", "smol", "splitn", diff --git a/noir/noir-repo/noir_stdlib/src/array/mod.nr b/noir/noir-repo/noir_stdlib/src/array/mod.nr index 47dc3ca7bb9..85cc0580aae 100644 --- a/noir/noir-repo/noir_stdlib/src/array/mod.nr +++ b/noir/noir-repo/noir_stdlib/src/array/mod.nr @@ -157,7 +157,7 @@ where /// } /// ``` pub fn sort(self) -> Self { - self.sort_via(|a: T, b: T| a <= b) + self.sort_via(|a, b| a <= b) } } diff --git a/noir/noir-repo/noir_stdlib/src/collections/bounded_vec.nr b/noir/noir-repo/noir_stdlib/src/collections/bounded_vec.nr index 7aed5e6a0e4..c030544e791 100644 --- a/noir/noir-repo/noir_stdlib/src/collections/bounded_vec.nr +++ b/noir/noir-repo/noir_stdlib/src/collections/bounded_vec.nr @@ -1,4 +1,4 @@ -use crate::{cmp::Eq, convert::From}; +use crate::{cmp::Eq, convert::From, runtime::is_unconstrained, static_assert}; /// A `BoundedVec` is a growable storage similar to a `Vec` except that it /// is bounded with a maximum possible length. Unlike `Vec`, `BoundedVec` is not implemented @@ -320,12 +320,18 @@ impl BoundedVec { let new_len = self.len + append_len; assert(new_len <= MaxLen, "extend_from_bounded_vec out of bounds"); - let mut exceeded_len = false; - for i in 0..Len { - exceeded_len |= i == append_len; - if !exceeded_len { + if is_unconstrained() { + for i in 0..append_len { self.storage[self.len + i] = vec.get_unchecked(i); } + } else { + let mut exceeded_len = false; + for i in 0..Len { + exceeded_len |= i == append_len; + if !exceeded_len { + self.storage[self.len + i] = vec.get_unchecked(i); + } + } } self.len = new_len; } @@ -339,7 +345,7 @@ impl BoundedVec { /// let bounded_vec: BoundedVec = BoundedVec::from_array([1, 2, 3]) /// ``` pub fn from_array(array: [T; Len]) -> Self { - assert(Len <= MaxLen, "from array out of bounds"); + static_assert(Len <= MaxLen, "from array out of bounds"); let mut vec: BoundedVec = BoundedVec::new(); vec.extend_from_array(array); vec @@ -389,12 +395,19 @@ impl BoundedVec { /// ``` pub fn any(self, predicate: fn[Env](T) -> bool) -> bool { let mut ret = false; - let mut exceeded_len = false; - for i in 0..MaxLen { - exceeded_len |= i == self.len; - if !exceeded_len { + if is_unconstrained() { + for i in 0..self.len { ret |= predicate(self.storage[i]); } + } else { + let mut ret = false; + let mut exceeded_len = false; + for i in 0..MaxLen { + exceeded_len |= i == self.len; + if !exceeded_len { + ret |= predicate(self.storage[i]); + } + } } ret } @@ -413,11 +426,19 @@ impl BoundedVec { pub fn map(self, f: fn[Env](T) -> U) -> BoundedVec { let mut ret = BoundedVec::new(); ret.len = self.len(); - for i in 0..MaxLen { - if i < self.len() { + + if is_unconstrained() { + for i in 0..self.len() { ret.storage[i] = f(self.get_unchecked(i)); } + } else { + for i in 0..MaxLen { + if i < self.len() { + ret.storage[i] = f(self.get_unchecked(i)); + } + } } + ret } @@ -437,11 +458,19 @@ impl BoundedVec { pub fn from_parts(mut array: [T; MaxLen], len: u32) -> Self { assert(len <= MaxLen); let zeroed = crate::mem::zeroed(); - for i in 0..MaxLen { - if i >= len { + + if is_unconstrained() { + for i in len..MaxLen { array[i] = zeroed; } + } else { + for i in 0..MaxLen { + if i >= len { + array[i] = zeroed; + } + } } + BoundedVec { storage: array, len } } diff --git a/noir/noir-repo/noir_stdlib/src/field/mod.nr b/noir/noir-repo/noir_stdlib/src/field/mod.nr index 7ebeb29b05b..d066ad2e9de 100644 --- a/noir/noir-repo/noir_stdlib/src/field/mod.nr +++ b/noir/noir-repo/noir_stdlib/src/field/mod.nr @@ -1,5 +1,5 @@ pub mod bn254; -use crate::runtime::is_unconstrained; +use crate::{runtime::is_unconstrained, static_assert}; use bn254::lt as bn254_lt; impl Field { @@ -10,7 +10,10 @@ impl Field { // docs:start:assert_max_bit_size pub fn assert_max_bit_size(self) { // docs:end:assert_max_bit_size - assert(BIT_SIZE < modulus_num_bits() as u32); + static_assert( + BIT_SIZE < modulus_num_bits() as u32, + "BIT_SIZE must be less than modulus_num_bits", + ); self.__assert_max_bit_size(BIT_SIZE); } @@ -29,9 +32,7 @@ impl Field { /// (e.g. 254 for the BN254 field) allow for multiple bit decompositions. This is due to how the `Field` will /// wrap around due to overflow when verifying the decomposition. #[builtin(to_le_bits)] - // docs:start:to_le_bits - pub fn to_le_bits(self: Self) -> [u1; N] {} - // docs:end:to_le_bits + fn _to_le_bits(self: Self) -> [u1; N] {} /// Decomposes `self` into its big endian bit decomposition as a `[u1; N]` array. /// This array will be zero padded should not all bits be necessary to represent `self`. @@ -45,9 +46,71 @@ impl Field { /// (e.g. 254 for the BN254 field) allow for multiple bit decompositions. This is due to how the `Field` will /// wrap around due to overflow when verifying the decomposition. #[builtin(to_be_bits)] + fn _to_be_bits(self: Self) -> [u1; N] {} + + /// Decomposes `self` into its little endian bit decomposition as a `[u1; N]` array. + /// This slice will be zero padded should not all bits be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^N` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// The bit decomposition returned is canonical and is guaranteed to not overflow the modulus. + // docs:start:to_le_bits + pub fn to_le_bits(self: Self) -> [u1; N] { + // docs:end:to_le_bits + let bits = self._to_le_bits(); + + if !is_unconstrained() { + // Ensure that the byte decomposition does not overflow the modulus + let p = modulus_le_bits(); + assert(bits.len() <= p.len()); + let mut ok = bits.len() != p.len(); + for i in 0..N { + if !ok { + if (bits[N - 1 - i] != p[N - 1 - i]) { + assert(p[N - 1 - i] == 1); + ok = true; + } + } + } + assert(ok); + } + bits + } + + /// Decomposes `self` into its big endian bit decomposition as a `[u1; N]` array. + /// This array will be zero padded should not all bits be necessary to represent `self`. + /// + /// # Failures + /// Causes a constraint failure for `Field` values exceeding `2^N` as the resulting slice will not + /// be able to represent the original `Field`. + /// + /// # Safety + /// The bit decomposition returned is canonical and is guaranteed to not overflow the modulus. // docs:start:to_be_bits - pub fn to_be_bits(self: Self) -> [u1; N] {} - // docs:end:to_be_bits + pub fn to_be_bits(self: Self) -> [u1; N] { + // docs:end:to_be_bits + let bits = self._to_be_bits(); + + if !is_unconstrained() { + // Ensure that the decomposition does not overflow the modulus + let p = modulus_be_bits(); + assert(bits.len() <= p.len()); + let mut ok = bits.len() != p.len(); + for i in 0..N { + if !ok { + if (bits[i] != p[i]) { + assert(p[i] == 1); + ok = true; + } + } + } + assert(ok); + } + bits + } /// Decomposes `self` into its little endian byte decomposition as a `[u8;N]` array /// This array will be zero padded should not all bytes be necessary to represent `self`. @@ -61,6 +124,10 @@ impl Field { // docs:start:to_le_bytes pub fn to_le_bytes(self: Self) -> [u8; N] { // docs:end:to_le_bytes + static_assert( + N <= modulus_le_bytes().len(), + "N must be less than or equal to modulus_le_bytes().len()", + ); // Compute the byte decomposition let bytes = self.to_le_radix(256); @@ -94,6 +161,10 @@ impl Field { // docs:start:to_be_bytes pub fn to_be_bytes(self: Self) -> [u8; N] { // docs:end:to_be_bytes + static_assert( + N <= modulus_le_bytes().len(), + "N must be less than or equal to modulus_le_bytes().len()", + ); // Compute the byte decomposition let bytes = self.to_be_radix(256); @@ -119,7 +190,9 @@ impl Field { pub fn to_le_radix(self: Self, radix: u32) -> [u8; N] { // Brillig does not need an immediate radix if !crate::runtime::is_unconstrained() { - crate::assert_constant(radix); + static_assert(1 < radix, "radix must be greater than 1"); + static_assert(radix <= 256, "radix must be less than or equal to 256"); + static_assert(radix & (radix - 1) == 0, "radix must be a power of 2"); } self.__to_le_radix(radix) } @@ -139,6 +212,7 @@ impl Field { #[builtin(to_le_radix)] fn __to_le_radix(self, radix: u32) -> [u8; N] {} + // `_radix` must be less than 256 #[builtin(to_be_radix)] fn __to_be_radix(self, radix: u32) -> [u8; N] {} @@ -172,6 +246,10 @@ impl Field { /// Convert a little endian byte array to a field element. /// If the provided byte array overflows the field modulus then the Field will silently wrap around. pub fn from_le_bytes(bytes: [u8; N]) -> Field { + static_assert( + N <= modulus_le_bytes().len(), + "N must be less than or equal to modulus_le_bytes().len()", + ); let mut v = 1; let mut result = 0; @@ -263,6 +341,7 @@ fn lt_fallback(x: Field, y: Field) -> bool { } mod tests { + use crate::{panic::panic, runtime}; use super::field_less_than; #[test] @@ -323,6 +402,77 @@ mod tests { } // docs:end:to_le_radix_example + #[test(should_fail_with = "radix must be greater than 1")] + fn test_to_le_radix_1() { + // this test should only fail in constrained mode + if !runtime::is_unconstrained() { + let field = 2; + let _: [u8; 8] = field.to_le_radix(1); + } else { + panic(f"radix must be greater than 1"); + } + } + + // TODO: Update this test to account for the Brillig restriction that the radix must be greater than 2 + // #[test] + // fn test_to_le_radix_brillig_1() { + // // this test should only fail in constrained mode + // if runtime::is_unconstrained() { + // let field = 1; + // let out: [u8; 8] = field.to_le_radix(1); + // crate::println(out); + // let expected = [0; 8]; + // assert(out == expected, "unexpected result"); + // } + // } + + #[test(should_fail_with = "radix must be a power of 2")] + fn test_to_le_radix_3() { + // this test should only fail in constrained mode + if !runtime::is_unconstrained() { + let field = 2; + let _: [u8; 8] = field.to_le_radix(3); + } else { + panic(f"radix must be a power of 2"); + } + } + + #[test] + fn test_to_le_radix_brillig_3() { + // this test should only fail in constrained mode + if runtime::is_unconstrained() { + let field = 1; + let out: [u8; 8] = field.to_le_radix(3); + let mut expected = [0; 8]; + expected[0] = 1; + assert(out == expected, "unexpected result"); + } + } + + #[test(should_fail_with = "radix must be less than or equal to 256")] + fn test_to_le_radix_512() { + // this test should only fail in constrained mode + if !runtime::is_unconstrained() { + let field = 2; + let _: [u8; 8] = field.to_le_radix(512); + } else { + panic(f"radix must be less than or equal to 256") + } + } + + // TODO: Update this test to account for the Brillig restriction that the radix must be less than 512 + // #[test] + // fn test_to_le_radix_brillig_512() { + // // this test should only fail in constrained mode + // if runtime::is_unconstrained() { + // let field = 1; + // let out: [u8; 8] = field.to_le_radix(512); + // let mut expected = [0; 8]; + // expected[0] = 1; + // assert(out == expected, "unexpected result"); + // } + // } + #[test] unconstrained fn test_field_less_than() { assert(field_less_than(0, 1)); diff --git a/noir/noir-repo/noir_stdlib/src/meta/ctstring.nr b/noir/noir-repo/noir_stdlib/src/meta/ctstring.nr index e23567ece7d..00b4f1fdb6f 100644 --- a/noir/noir-repo/noir_stdlib/src/meta/ctstring.nr +++ b/noir/noir-repo/noir_stdlib/src/meta/ctstring.nr @@ -7,7 +7,8 @@ impl CtString { "".as_ctstring() } - // Bug: using &mut self as the object results in this method not being found + // TODO(https://github.com/noir-lang/noir/issues/6980): Bug: using &mut self + // as the object results in this method not being found // docs:start:append_str pub comptime fn append_str(self, s: str) -> Self { // docs:end:append_str diff --git a/noir/noir-repo/noir_stdlib/src/meta/expr.nr b/noir/noir-repo/noir_stdlib/src/meta/expr.nr index 7538b26dc44..a1663135c20 100644 --- a/noir/noir-repo/noir_stdlib/src/meta/expr.nr +++ b/noir/noir-repo/noir_stdlib/src/meta/expr.nr @@ -285,33 +285,31 @@ impl Expr { } comptime fn modify_array(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_array().map(|exprs: [Expr]| { + expr.as_array().map(|exprs| { let exprs = modify_expressions(exprs, f); new_array(exprs) }) } comptime fn modify_assert(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_assert().map(|expr: (Expr, Option)| { - let (predicate, msg) = expr; + expr.as_assert().map(|(predicate, msg)| { let predicate = predicate.modify(f); - let msg = msg.map(|msg: Expr| msg.modify(f)); + let msg = msg.map(|msg| msg.modify(f)); new_assert(predicate, msg) }) } comptime fn modify_assert_eq(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_assert_eq().map(|expr: (Expr, Expr, Option)| { - let (lhs, rhs, msg) = expr; + expr.as_assert_eq().map(|(lhs, rhs, msg)| { let lhs = lhs.modify(f); let rhs = rhs.modify(f); - let msg = msg.map(|msg: Expr| msg.modify(f)); + let msg = msg.map(|msg| msg.modify(f)); new_assert_eq(lhs, rhs, msg) }) } comptime fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_assign().map(|expr: (Expr, Expr)| { + expr.as_assign().map(|expr| { let (lhs, rhs) = expr; let lhs = lhs.modify(f); let rhs = rhs.modify(f); @@ -320,8 +318,7 @@ comptime fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> } comptime fn modify_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_binary_op().map(|expr: (Expr, BinaryOp, Expr)| { - let (lhs, op, rhs) = expr; + expr.as_binary_op().map(|(lhs, op, rhs)| { let lhs = lhs.modify(f); let rhs = rhs.modify(f); new_binary_op(lhs, op, rhs) @@ -329,34 +326,29 @@ comptime fn modify_binary_op(expr: Expr, f: fn[Env](Expr) -> Option) } comptime fn modify_block(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_block().map(|exprs: [Expr]| { + expr.as_block().map(|exprs| { let exprs = modify_expressions(exprs, f); new_block(exprs) }) } comptime fn modify_cast(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_cast().map(|expr: (Expr, UnresolvedType)| { - let (expr, typ) = expr; + expr.as_cast().map(|(expr, typ)| { let expr = expr.modify(f); new_cast(expr, typ) }) } comptime fn modify_comptime(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_comptime().map(|exprs: [Expr]| { - let exprs = exprs.map(|expr: Expr| expr.modify(f)); + expr.as_comptime().map(|exprs| { + let exprs = exprs.map(|expr| expr.modify(f)); new_comptime(exprs) }) } comptime fn modify_constructor(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_constructor().map(|expr: (UnresolvedType, [(Quoted, Expr)])| { - let (typ, fields) = expr; - let fields = fields.map(|field: (Quoted, Expr)| { - let (name, value) = field; - (name, value.modify(f)) - }); + expr.as_constructor().map(|(typ, fields)| { + let fields = fields.map(|(name, value)| (name, value.modify(f))); new_constructor(typ, fields) }) } @@ -365,27 +357,24 @@ comptime fn modify_function_call( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_function_call().map(|expr: (Expr, [Expr])| { - let (function, arguments) = expr; + expr.as_function_call().map(|(function, arguments)| { let function = function.modify(f); - let arguments = arguments.map(|arg: Expr| arg.modify(f)); + let arguments = arguments.map(|arg| arg.modify(f)); new_function_call(function, arguments) }) } comptime fn modify_if(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_if().map(|expr: (Expr, Expr, Option)| { - let (condition, consequence, alternative) = expr; + expr.as_if().map(|(condition, consequence, alternative)| { let condition = condition.modify(f); let consequence = consequence.modify(f); - let alternative = alternative.map(|alternative: Expr| alternative.modify(f)); + let alternative = alternative.map(|alternative| alternative.modify(f)); new_if(condition, consequence, alternative) }) } comptime fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_index().map(|expr: (Expr, Expr)| { - let (object, index) = expr; + expr.as_index().map(|(object, index)| { let object = object.modify(f); let index = index.modify(f); new_index(object, index) @@ -393,8 +382,7 @@ comptime fn modify_index(expr: Expr, f: fn[Env](Expr) -> Option) -> O } comptime fn modify_for(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_for().map(|expr: (Quoted, Expr, Expr)| { - let (identifier, array, body) = expr; + expr.as_for().map(|(identifier, array, body)| { let array = array.modify(f); let body = body.modify(f); new_for(identifier, array, body) @@ -402,8 +390,7 @@ comptime fn modify_for(expr: Expr, f: fn[Env](Expr) -> Option) -> Opt } comptime fn modify_for_range(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_for_range().map(|expr: (Quoted, Expr, Expr, Expr)| { - let (identifier, from, to, body) = expr; + expr.as_for_range().map(|(identifier, from, to, body)| { let from = from.modify(f); let to = to.modify(f); let body = body.modify(f); @@ -412,18 +399,15 @@ comptime fn modify_for_range(expr: Expr, f: fn[Env](Expr) -> Option) } comptime fn modify_lambda(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_lambda().map(|expr: ([(Expr, Option)], Option, Expr)| { - let (params, return_type, body) = expr; - let params = - params.map(|param: (Expr, Option)| (param.0.modify(f), param.1)); + expr.as_lambda().map(|(params, return_type, body)| { + let params = params.map(|(name, typ)| (name.modify(f), typ)); let body = body.modify(f); new_lambda(params, return_type, body) }) } comptime fn modify_let(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_let().map(|expr: (Expr, Option, Expr)| { - let (pattern, typ, expr) = expr; + expr.as_let().map(|(pattern, typ, expr)| { let pattern = pattern.modify(f); let expr = expr.modify(f); new_let(pattern, typ, expr) @@ -434,18 +418,16 @@ comptime fn modify_member_access( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_member_access().map(|expr: (Expr, Quoted)| { - let (object, name) = expr; + expr.as_member_access().map(|(object, name)| { let object = object.modify(f); new_member_access(object, name) }) } comptime fn modify_method_call(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_method_call().map(|expr: (Expr, Quoted, [UnresolvedType], [Expr])| { - let (object, name, generics, arguments) = expr; + expr.as_method_call().map(|(object, name, generics, arguments)| { let object = object.modify(f); - let arguments = arguments.map(|arg: Expr| arg.modify(f)); + let arguments = arguments.map(|arg| arg.modify(f)); new_method_call(object, name, generics, arguments) }) } @@ -454,8 +436,7 @@ comptime fn modify_repeated_element_array( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_repeated_element_array().map(|expr: (Expr, Expr)| { - let (expr, length) = expr; + expr.as_repeated_element_array().map(|(expr, length)| { let expr = expr.modify(f); let length = length.modify(f); new_repeated_element_array(expr, length) @@ -466,8 +447,7 @@ comptime fn modify_repeated_element_slice( expr: Expr, f: fn[Env](Expr) -> Option, ) -> Option { - expr.as_repeated_element_slice().map(|expr: (Expr, Expr)| { - let (expr, length) = expr; + expr.as_repeated_element_slice().map(|(expr, length)| { let expr = expr.modify(f); let length = length.modify(f); new_repeated_element_slice(expr, length) @@ -475,36 +455,35 @@ comptime fn modify_repeated_element_slice( } comptime fn modify_slice(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_slice().map(|exprs: [Expr]| { + expr.as_slice().map(|exprs| { let exprs = modify_expressions(exprs, f); new_slice(exprs) }) } comptime fn modify_tuple(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_tuple().map(|exprs: [Expr]| { + expr.as_tuple().map(|exprs| { let exprs = modify_expressions(exprs, f); new_tuple(exprs) }) } comptime fn modify_unary_op(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_unary_op().map(|expr: (UnaryOp, Expr)| { - let (op, rhs) = expr; + expr.as_unary_op().map(|(op, rhs)| { let rhs = rhs.modify(f); new_unary_op(op, rhs) }) } comptime fn modify_unsafe(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { - expr.as_unsafe().map(|exprs: [Expr]| { - let exprs = exprs.map(|expr: Expr| expr.modify(f)); + expr.as_unsafe().map(|exprs| { + let exprs = exprs.map(|expr| expr.modify(f)); new_unsafe(exprs) }) } comptime fn modify_expressions(exprs: [Expr], f: fn[Env](Expr) -> Option) -> [Expr] { - exprs.map(|expr: Expr| expr.modify(f)) + exprs.map(|expr| expr.modify(f)) } comptime fn new_array(exprs: [Expr]) -> Expr { @@ -554,12 +533,7 @@ comptime fn new_comptime(exprs: [Expr]) -> Expr { } comptime fn new_constructor(typ: UnresolvedType, fields: [(Quoted, Expr)]) -> Expr { - let fields = fields - .map(|field: (Quoted, Expr)| { - let (name, value) = field; - quote { $name: $value } - }) - .join(quote { , }); + let fields = fields.map(|(name, value)| quote { $name: $value }).join(quote { , }); quote { $typ { $fields }}.as_expr().unwrap() } @@ -590,8 +564,7 @@ comptime fn new_lambda( body: Expr, ) -> Expr { let params = params - .map(|param: (Expr, Option)| { - let (name, typ) = param; + .map(|(name, typ)| { if typ.is_some() { let typ = typ.unwrap(); quote { $name: $typ } @@ -678,5 +651,5 @@ comptime fn new_unsafe(exprs: [Expr]) -> Expr { } comptime fn join_expressions(exprs: [Expr], separator: Quoted) -> Quoted { - exprs.map(|expr: Expr| expr.quoted()).join(separator) + exprs.map(|expr| expr.quoted()).join(separator) } diff --git a/noir/noir-repo/noir_stdlib/src/meta/mod.nr b/noir/noir-repo/noir_stdlib/src/meta/mod.nr index 7644d5e1dd1..35ba05ba74d 100644 --- a/noir/noir-repo/noir_stdlib/src/meta/mod.nr +++ b/noir/noir-repo/noir_stdlib/src/meta/mod.nr @@ -112,10 +112,7 @@ pub comptime fn make_trait_impl( let where_clause = where_clause.join(quote {, }); // `for_each_field(field1) $join_fields_with for_each_field(field2) $join_fields_with ...` - let fields = s.fields_as_written().map(|f: (Quoted, Type)| { - let name = f.0; - for_each_field(name) - }); + let fields = s.fields_as_written().map(|(name, _)| for_each_field(name)); let body = body(fields.join(join_fields_with)); quote { diff --git a/noir/noir-repo/noir_stdlib/src/uint128.nr b/noir/noir-repo/noir_stdlib/src/uint128.nr index bcb0746832e..6c9b802f5b3 100644 --- a/noir/noir-repo/noir_stdlib/src/uint128.nr +++ b/noir/noir-repo/noir_stdlib/src/uint128.nr @@ -1,5 +1,6 @@ use crate::cmp::{Eq, Ord, Ordering}; use crate::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub}; +use crate::static_assert; use super::{convert::AsPrimitive, default::Default}; global pow64: Field = 18446744073709551616; //2^64; @@ -67,11 +68,10 @@ impl U128 { } pub fn from_hex(hex: str) -> U128 { - let N = N as u32; let bytes = hex.as_bytes(); // string must starts with "0x" assert((bytes[0] == 48) & (bytes[1] == 120), "Invalid hexadecimal string"); - assert(N < 35, "Input does not fit into a U128"); + static_assert(N < 35, "Input does not fit into a U128"); let mut lo = 0; let mut hi = 0; diff --git a/noir/noir-repo/test_programs/compile_failure/comptime_static_assert_failure/Nargo.toml b/noir/noir-repo/test_programs/compile_failure/comptime_static_assert_failure/Nargo.toml new file mode 100644 index 00000000000..006fd9f7ffe --- /dev/null +++ b/noir/noir-repo/test_programs/compile_failure/comptime_static_assert_failure/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "comptime_static_assert_failure" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/compile_failure/comptime_static_assert_failure/src/main.nr b/noir/noir-repo/test_programs/compile_failure/comptime_static_assert_failure/src/main.nr new file mode 100644 index 00000000000..fcd757f4c94 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_failure/comptime_static_assert_failure/src/main.nr @@ -0,0 +1,13 @@ +use std::static_assert; + +comptime fn foo(x: Field) -> bool { + static_assert(x == 4, "x != 4"); + x == 4 +} + +fn main() { + comptime { + static_assert(foo(3), "expected message"); + } +} + diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_static_assert/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/comptime_static_assert/Nargo.toml new file mode 100644 index 00000000000..4c969fe7a79 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/comptime_static_assert/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "comptime_static_assert" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/compile_success_empty/comptime_static_assert/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/comptime_static_assert/src/main.nr new file mode 100644 index 00000000000..2ddbba7b0de --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/comptime_static_assert/src/main.nr @@ -0,0 +1,19 @@ +use std::static_assert; + +comptime fn foo(x: Field) -> bool { + static_assert(x == 4, "x != 4"); + x == 4 +} + +global C: bool = { + let out = foo(2 + 2); + static_assert(out, "foo did not pass in C"); + out +}; + +fn main() { + comptime { + static_assert(foo(4), "foo did not pass in main"); + static_assert(C, "C did not pass") + } +} diff --git a/noir/noir-repo/test_programs/compile_success_empty/enums/Nargo.toml b/noir/noir-repo/test_programs/compile_success_empty/enums/Nargo.toml new file mode 100644 index 00000000000..3f8b42c8a49 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/enums/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "enums" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr new file mode 100644 index 00000000000..31619bca596 --- /dev/null +++ b/noir/noir-repo/test_programs/compile_success_empty/enums/src/main.nr @@ -0,0 +1,19 @@ +fn main() { + let _a = Foo::A::(1, 2); + let _b: Foo = Foo::B(3); + let _c = Foo::C(4); + + // (#7172): Single variant enums must be called as functions currently + let _d: fn() -> Foo<(i32, i32)> = Foo::D; + let _d: Foo<(i32, i32)> = Foo::D(); + + // Enum variants are functions and can be passed around as such + let _many_cs = [1, 2, 3].map(Foo::C); +} + +enum Foo { + A(Field, Field), + B(u32), + C(T), + D, +} diff --git a/noir/noir-repo/test_programs/compile_success_empty/inject_context_attribute/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/inject_context_attribute/src/main.nr index 963d4cea969..e682ea34b23 100644 --- a/noir/noir-repo/test_programs/compile_success_empty/inject_context_attribute/src/main.nr +++ b/noir/noir-repo/test_programs/compile_success_empty/inject_context_attribute/src/main.nr @@ -40,19 +40,16 @@ comptime fn inject_context(f: FunctionDefinition) { } comptime fn mapping_function(expr: Expr, f: FunctionDefinition) -> Option { - expr.as_function_call().and_then(|func_call: (Expr, [Expr])| { - let (name, arguments) = func_call; - name.resolve(Option::some(f)).as_function_definition().and_then( - |function_definition: FunctionDefinition| { - if function_definition.has_named_attribute("inject_context") { - let arguments = arguments.push_front(quote { _context }.as_expr().unwrap()); - let arguments = arguments.map(|arg: Expr| arg.quoted()).join(quote { , }); - Option::some(quote { $name($arguments) }.as_expr().unwrap()) - } else { - Option::none() - } - }, - ) + expr.as_function_call().and_then(|(name, arguments)| { + name.resolve(Option::some(f)).as_function_definition().and_then(|function_definition| { + if function_definition.has_named_attribute("inject_context") { + let arguments = arguments.push_front(quote { _context }.as_expr().unwrap()); + let arguments = arguments.map(|arg| arg.quoted()).join(quote { , }); + Option::some(quote { $name($arguments) }.as_expr().unwrap()) + } else { + Option::none() + } + }) }) } diff --git a/noir/noir-repo/test_programs/compile_success_empty/trait_generics/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/trait_generics/src/main.nr index 08302ded68c..e8b57b6fe6f 100644 --- a/noir/noir-repo/test_programs/compile_success_empty/trait_generics/src/main.nr +++ b/noir/noir-repo/test_programs/compile_success_empty/trait_generics/src/main.nr @@ -24,7 +24,7 @@ where T: MyInto, { fn into(self) -> [U; N] { - self.map(|x: T| x.into()) + self.map(|x| x.into()) } } diff --git a/noir/noir-repo/test_programs/compile_success_empty/unquote_struct/src/main.nr b/noir/noir-repo/test_programs/compile_success_empty/unquote_struct/src/main.nr index d4ab275858c..12c683a94a8 100644 --- a/noir/noir-repo/test_programs/compile_success_empty/unquote_struct/src/main.nr +++ b/noir/noir-repo/test_programs/compile_success_empty/unquote_struct/src/main.nr @@ -10,14 +10,7 @@ fn foo(x: Field, y: u32) -> u32 { // Given a function, wrap its parameters in a struct definition comptime fn output_struct(f: FunctionDefinition) -> Quoted { - let fields = f - .parameters() - .map(|param: (Quoted, Type)| { - let name = param.0; - let typ = param.1; - quote { $name: $typ, } - }) - .join(quote {}); + let fields = f.parameters().map(|(name, typ)| quote { $name: $typ, }).join(quote {}); quote { struct Foo { $fields } diff --git a/noir/noir-repo/test_programs/compile_success_no_bug/check_unconstrained_regression/src/main.nr b/noir/noir-repo/test_programs/compile_success_no_bug/check_unconstrained_regression/src/main.nr index 174b68fd162..e4cb15f099d 100644 --- a/noir/noir-repo/test_programs/compile_success_no_bug/check_unconstrained_regression/src/main.nr +++ b/noir/noir-repo/test_programs/compile_success_no_bug/check_unconstrained_regression/src/main.nr @@ -23,7 +23,9 @@ impl Trigger { let result = unsafe { convert(self) }; assert(result.a == self.x + 1); assert(result.b == self.y - 1 + self.z[2]); + assert(result.c[0] == self.z[0]); assert(result.c[1] == 0); + assert(result.c[2] == self.z[1]); result } } diff --git a/noir/noir-repo/test_programs/execution_failure/regression_7128/Nargo.toml b/noir/noir-repo/test_programs/execution_failure/regression_7128/Nargo.toml new file mode 100644 index 00000000000..4d7b621526a --- /dev/null +++ b/noir/noir-repo/test_programs/execution_failure/regression_7128/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "regression_7128" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_failure/regression_7128/Prover.toml b/noir/noir-repo/test_programs/execution_failure/regression_7128/Prover.toml new file mode 100644 index 00000000000..dd9b68d125e --- /dev/null +++ b/noir/noir-repo/test_programs/execution_failure/regression_7128/Prover.toml @@ -0,0 +1 @@ +in0 = "1" diff --git a/noir/noir-repo/test_programs/execution_failure/regression_7128/src/main.nr b/noir/noir-repo/test_programs/execution_failure/regression_7128/src/main.nr new file mode 100644 index 00000000000..46759fe90a2 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_failure/regression_7128/src/main.nr @@ -0,0 +1,26 @@ +fn main(in0: Field) -> pub Field { + let mut out0: Field = 0; + let tmp1: Field = in0; + + if (out0 == out0) // <== changing out0 to in0 or removing + { + // the comparison changes the result + let in0_as_bytes: [u8; 32] = in0.to_be_bytes(); + let mut result: [u8; 32] = [0; 32]; + for i in 0..32 { + result[i] = in0_as_bytes[i]; + } + } + + let mut tmp2: Field = 0; // <== moving this to the top of main, + if (0.lt(in0)) // changes the result + { + tmp2 = 1; + } + + out0 = (tmp2 - tmp1); + + assert(out0 != 0, "soundness violation"); + + out0 +} diff --git a/noir/noir-repo/test_programs/execution_success/loop_keyword/Nargo.toml b/noir/noir-repo/test_programs/execution_success/loop_keyword/Nargo.toml new file mode 100644 index 00000000000..8189b407cd9 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/loop_keyword/Nargo.toml @@ -0,0 +1,5 @@ +[package] +name = "loop_keyword" +type = "bin" +authors = [""] +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/loop_keyword/src/main.nr b/noir/noir-repo/test_programs/execution_success/loop_keyword/src/main.nr new file mode 100644 index 00000000000..b038ae22343 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/loop_keyword/src/main.nr @@ -0,0 +1,52 @@ +fn main() { + /// Safety: test code + unsafe { + check_loop(); + } + + check_comptime_loop(); +} + +unconstrained fn check_loop() { + let mut i = 0; + let mut sum = 0; + + loop { + if i == 4 { + break; + } + + if i == 2 { + i += 1; + continue; + } + + sum += i; + i += 1; + } + + assert_eq(sum, 1 + 3); +} + +fn check_comptime_loop() { + comptime { + let mut i = 0; + let mut sum = 0; + + loop { + if i == 4 { + break; + } + + if i == 2 { + i += 1; + continue; + } + + sum += i; + i += 1; + } + + assert_eq(sum, 1 + 3); + } +} diff --git a/noir/noir-repo/test_programs/execution_success/regression_11294/Nargo.toml b/noir/noir-repo/test_programs/execution_success/regression_11294/Nargo.toml new file mode 100644 index 00000000000..42fcd7432ff --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_11294/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "regression_11294" +version = "0.1.0" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/regression_11294/Prover.toml b/noir/noir-repo/test_programs/execution_success/regression_11294/Prover.toml new file mode 100644 index 00000000000..c0bc12aeed9 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_11294/Prover.toml @@ -0,0 +1,47 @@ +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0c78b411fc893c51d446c08daa5741b9ba6103126c9e450bed90fcde8793168a" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000002" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000007" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" + +[[previous_kernel_public_inputs.end.private_call_stack]] +args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000" +start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" +end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000" diff --git a/noir/noir-repo/test_programs/execution_success/regression_11294/src/main.nr b/noir/noir-repo/test_programs/execution_success/regression_11294/src/main.nr new file mode 100644 index 00000000000..9440a8d1482 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_11294/src/main.nr @@ -0,0 +1,186 @@ +// Capture the "attempt to subtract with overflow" from https://github.com/AztecProtocol/aztec-packages/pull/11294 + +pub global MAX_PRIVATE_CALL_STACK_LENGTH_PER_TX: u32 = 8; + +unconstrained fn main( + previous_kernel_public_inputs: PrivateKernelCircuitPublicInputs, +) -> pub PrivateKernelCircuitPublicInputs { + let private_inputs = PrivateKernelInnerCircuitPrivateInputs::new(previous_kernel_public_inputs); + private_inputs.execute() +} + +pub struct PrivateKernelCircuitPublicInputs { + pub end: PrivateAccumulatedData, +} + +pub struct PrivateKernelData { + pub public_inputs: PrivateKernelCircuitPublicInputs, +} + +pub struct PrivateAccumulatedData { + pub private_call_stack: [PrivateCallRequest; MAX_PRIVATE_CALL_STACK_LENGTH_PER_TX], +} + +pub struct PrivateCallRequest { + pub args_hash: Field, + pub returns_hash: Field, + pub start_side_effect_counter: u32, + pub end_side_effect_counter: u32, +} + +pub struct PrivateKernelCircuitPublicInputsComposer { + pub public_inputs: PrivateKernelCircuitPublicInputsBuilder, +} + +impl PrivateKernelCircuitPublicInputsComposer { + pub unconstrained fn new_from_previous_kernel( + previous_kernel_public_inputs: PrivateKernelCircuitPublicInputs, + ) -> Self { + let mut public_inputs = PrivateKernelCircuitPublicInputsBuilder { + end: PrivateAccumulatedDataBuilder { private_call_stack: BoundedVec::new() }, + }; + + let start = previous_kernel_public_inputs.end; + public_inputs.end.private_call_stack = array_to_bounded_vec(start.private_call_stack); + + PrivateKernelCircuitPublicInputsComposer { public_inputs } + } + + pub fn pop_top_call_request(&mut self) -> Self { + // Pop the top item in the call stack, which is the caller of the current call, and shouldn't be propagated to the output. + let _call_request = self.public_inputs.end.private_call_stack.pop(); + *self + } + + pub fn finish(self) -> PrivateKernelCircuitPublicInputs { + self.public_inputs.finish() + } +} + +pub struct PrivateKernelCircuitPublicInputsBuilder { + pub end: PrivateAccumulatedDataBuilder, +} + +impl PrivateKernelCircuitPublicInputsBuilder { + pub fn finish(self) -> PrivateKernelCircuitPublicInputs { + PrivateKernelCircuitPublicInputs { end: self.end.finish() } + } +} + +pub struct PrivateAccumulatedDataBuilder { + pub private_call_stack: BoundedVec, +} + +impl PrivateAccumulatedDataBuilder { + pub fn finish(self) -> PrivateAccumulatedData { + PrivateAccumulatedData { private_call_stack: self.private_call_stack.storage() } + } +} + +pub struct PrivateKernelInnerCircuitPrivateInputs { + previous_kernel: PrivateKernelData, +} + +impl PrivateKernelInnerCircuitPrivateInputs { + pub fn new(public_inputs: PrivateKernelCircuitPublicInputs) -> Self { + Self { previous_kernel: PrivateKernelData { public_inputs } } + } + + unconstrained fn generate_output(self) -> PrivateKernelCircuitPublicInputs { + // XXX: Declaring `let mut composer = ` would make the circuit pass. + PrivateKernelCircuitPublicInputsComposer::new_from_previous_kernel( + self.previous_kernel.public_inputs, + ) + .pop_top_call_request() + .finish() + } + + pub fn execute(self) -> PrivateKernelCircuitPublicInputs { + // XXX: Running both this and the bottom assertion would make the circuit pass. + // assert(!is_empty(self.previous_kernel.public_inputs.end.private_call_stack[0]), "not empty before"); + + // Safety: This is where the program treated the input as mutable. + let output = unsafe { self.generate_output() }; + + assert( + !is_empty(self.previous_kernel.public_inputs.end.private_call_stack[0]), + "not empty after", + ); + + output + } +} + +pub trait Empty { + fn empty() -> Self; +} + +pub fn is_empty(item: T) -> bool +where + T: Empty + Eq, +{ + item.eq(T::empty()) +} + +impl Eq for PrivateCallRequest { + fn eq(self, other: PrivateCallRequest) -> bool { + (self.args_hash == other.args_hash) + & (self.returns_hash == other.returns_hash) + & (self.start_side_effect_counter == other.start_side_effect_counter) + & (self.end_side_effect_counter == other.end_side_effect_counter) + } +} + +impl Empty for PrivateCallRequest { + fn empty() -> Self { + PrivateCallRequest { + args_hash: 0, + returns_hash: 0, + start_side_effect_counter: 0, + end_side_effect_counter: 0, + } + } +} + +// Copy of https://github.com/AztecProtocol/aztec-packages/blob/f1fd2d104d01a4582d8a48a6ab003d8791010967/noir-projects/noir-protocol-circuits/crates/types/src/utils/arrays.nr#L110 +pub fn array_length(array: [T; N]) -> u32 +where + T: Empty + Eq, +{ + // We get the length by checking the index of the first empty element. + + // Safety: This is safe because we have validated the array (see function doc above) and the emptiness + // of the element and non-emptiness of the previous element is checked below. + let length = unsafe { find_index_hint(array, |elem: T| is_empty(elem)) }; + // if length != 0 { + // assert(!is_empty(array[length - 1])); + // } + // if length != N { + // assert(is_empty(array[length])); + // } + length +} + +// Helper function to find the index of the first element in an array that satisfies a given predicate. If the element +// is not found, the function returns N as the index. +pub unconstrained fn find_index_hint( + array: [T; N], + find: fn[Env](T) -> bool, +) -> u32 { + let mut index = N; + for i in 0..N { + // We check `index == N` to ensure that we only update the index if we haven't found a match yet. + if (index == N) & find(array[i]) { + index = i; + } + } + index +} + +pub unconstrained fn array_to_bounded_vec(array: [T; N]) -> BoundedVec +where + T: Empty + Eq, +{ + let len = array_length(array); + BoundedVec::from_parts_unchecked(array, len) +} diff --git a/noir/noir-repo/test_programs/execution_success/regression_7062/Nargo.toml b/noir/noir-repo/test_programs/execution_success/regression_7062/Nargo.toml new file mode 100644 index 00000000000..0e11219ad98 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7062/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "regression_7062" +type = "bin" +authors = [""] +compiler_version = ">=0.31.0" + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/regression_7062/Prover.toml b/noir/noir-repo/test_programs/execution_success/regression_7062/Prover.toml new file mode 100644 index 00000000000..08608e6b3ba --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7062/Prover.toml @@ -0,0 +1,2 @@ +index = 1 +value = 1 diff --git a/noir/noir-repo/test_programs/execution_success/regression_7062/src/main.nr b/noir/noir-repo/test_programs/execution_success/regression_7062/src/main.nr new file mode 100644 index 00000000000..47e7593c0e6 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7062/src/main.nr @@ -0,0 +1,10 @@ +fn main(value: Field, index: u32) { + let mut args = &[0, 1]; + args[index] = value; + /// Safety: n/a + unsafe { store(args) }; + // Dummy test to remove the 'underconstraint bug' + assert(args[0] + args[1] != 0); +} + +pub unconstrained fn store(_: [Field]) {} diff --git a/noir/noir-repo/test_programs/execution_success/regression_7128/Nargo.toml b/noir/noir-repo/test_programs/execution_success/regression_7128/Nargo.toml new file mode 100644 index 00000000000..4d7b621526a --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7128/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "regression_7128" +type = "bin" +authors = [""] + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_success/regression_7128/Prover.toml b/noir/noir-repo/test_programs/execution_success/regression_7128/Prover.toml new file mode 100644 index 00000000000..dd9b68d125e --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7128/Prover.toml @@ -0,0 +1 @@ +in0 = "1" diff --git a/noir/noir-repo/test_programs/execution_success/regression_7128/src/main.nr b/noir/noir-repo/test_programs/execution_success/regression_7128/src/main.nr new file mode 100644 index 00000000000..454c2220b88 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7128/src/main.nr @@ -0,0 +1,26 @@ +fn main(in0: Field) -> pub Field { + let mut out0: Field = 0; + let tmp1: Field = in0; + + if (out0 == out0) // <== changing out0 to in0 or removing + { + // the comparison changes the result + let in0_as_bytes: [u8; 32] = in0.to_be_bytes(); + let mut result: [u8; 32] = [0; 32]; + for i in 0..32 { + result[i] = in0_as_bytes[i]; + } + } + + let mut tmp2: Field = 0; // <== moving this to the top of main, + if (0.lt(in0)) // changes the result + { + tmp2 = 1; + } + + out0 = (tmp2 - tmp1); + + assert(out0 == 0, "completeness violation"); + + out0 +} diff --git a/noir/noir-repo/test_programs/execution_success/regression_7143/Nargo.toml b/noir/noir-repo/test_programs/execution_success/regression_7143/Nargo.toml new file mode 100644 index 00000000000..1f581c8b24d --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7143/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "regression_7143" +type = "bin" +authors = [""] + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/regression_7143/Prover.toml b/noir/noir-repo/test_programs/execution_success/regression_7143/Prover.toml new file mode 100644 index 00000000000..f2f801df886 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7143/Prover.toml @@ -0,0 +1,3 @@ +array = [0] +x = 0 +return = 1 diff --git a/noir/noir-repo/test_programs/execution_success/regression_7143/src/main.nr b/noir/noir-repo/test_programs/execution_success/regression_7143/src/main.nr new file mode 100644 index 00000000000..396ddf1a633 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/regression_7143/src/main.nr @@ -0,0 +1,3 @@ +fn main(x: u32, array: call_data(0) [bool; 1]) -> pub bool { + !array[x] +} diff --git a/noir/noir-repo/test_programs/noir_test_success/comptime_expr/src/main.nr b/noir/noir-repo/test_programs/noir_test_success/comptime_expr/src/main.nr index 25910685e87..6efbc212cbe 100644 --- a/noir/noir-repo/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/noir/noir-repo/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -761,8 +761,7 @@ mod tests { } comptime fn times_two(expr: Expr) -> Option { - expr.as_integer().and_then(|integer: (Field, bool)| { - let (value, _) = integer; + expr.as_integer().and_then(|(value, _)| { let value = value * 2; quote { $value }.as_expr() }) diff --git a/noir/noir-repo/tooling/inspector/Cargo.toml b/noir/noir-repo/tooling/inspector/Cargo.toml new file mode 100644 index 00000000000..2124f7e9a28 --- /dev/null +++ b/noir/noir-repo/tooling/inspector/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "noir_inspector" +description = "Inspector for noir build artifacts" +version.workspace = true +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true +repository.workspace = true + +[lints] +workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[[bin]] +name = "noir-inspector" +path = "src/main.rs" + +[dependencies] +clap.workspace = true +serde.workspace = true +serde_json.workspace = true +color-eyre.workspace = true +const_format.workspace = true +acir.workspace = true +noirc_artifacts.workspace = true +noirc_artifacts_info.workspace = true diff --git a/noir/noir-repo/tooling/inspector/src/cli/info_cmd.rs b/noir/noir-repo/tooling/inspector/src/cli/info_cmd.rs new file mode 100644 index 00000000000..6a9db2676f2 --- /dev/null +++ b/noir/noir-repo/tooling/inspector/src/cli/info_cmd.rs @@ -0,0 +1,35 @@ +use std::path::PathBuf; + +use clap::Args; +use color_eyre::eyre; +use noirc_artifacts::program::ProgramArtifact; +use noirc_artifacts_info::{count_opcodes_and_gates_in_program, show_info_report, InfoReport}; + +#[derive(Debug, Clone, Args)] +pub(crate) struct InfoCommand { + /// The artifact to inspect + artifact: PathBuf, + + /// Output a JSON formatted report. Changes to this format are not currently considered breaking. + #[clap(long, hide = true)] + json: bool, +} + +pub(crate) fn run(args: InfoCommand) -> eyre::Result<()> { + let file = std::fs::File::open(args.artifact.clone())?; + let artifact: ProgramArtifact = serde_json::from_reader(file)?; + + let package_name = args + .artifact + .with_extension("") + .file_name() + .map(|s| s.to_string_lossy().to_string()) + .unwrap_or_else(|| "artifact".to_string()); + + let program_info = count_opcodes_and_gates_in_program(artifact, package_name.to_string(), None); + + let info_report = InfoReport { programs: vec![program_info] }; + show_info_report(info_report, args.json); + + Ok(()) +} diff --git a/noir/noir-repo/tooling/inspector/src/cli/mod.rs b/noir/noir-repo/tooling/inspector/src/cli/mod.rs new file mode 100644 index 00000000000..8cce6ec3a6f --- /dev/null +++ b/noir/noir-repo/tooling/inspector/src/cli/mod.rs @@ -0,0 +1,33 @@ +use clap::{command, Parser, Subcommand}; +use color_eyre::eyre; +use const_format::formatcp; + +mod info_cmd; +mod print_acir_cmd; + +const INSPECTOR_VERSION: &str = env!("CARGO_PKG_VERSION"); + +static VERSION_STRING: &str = formatcp!("version = {}\n", INSPECTOR_VERSION,); + +#[derive(Parser, Debug)] +#[command(name="Noir inspector", author, version=VERSION_STRING, about, long_about = None)] +struct InspectorCli { + #[command(subcommand)] + command: InspectorCommand, +} + +#[non_exhaustive] +#[derive(Subcommand, Clone, Debug)] +enum InspectorCommand { + Info(info_cmd::InfoCommand), + PrintAcir(print_acir_cmd::PrintAcirCommand), +} + +pub(crate) fn start_cli() -> eyre::Result<()> { + let InspectorCli { command } = InspectorCli::parse(); + + match command { + InspectorCommand::Info(args) => info_cmd::run(args), + InspectorCommand::PrintAcir(args) => print_acir_cmd::run(args), + } +} diff --git a/noir/noir-repo/tooling/inspector/src/cli/print_acir_cmd.rs b/noir/noir-repo/tooling/inspector/src/cli/print_acir_cmd.rs new file mode 100644 index 00000000000..f3dfe528973 --- /dev/null +++ b/noir/noir-repo/tooling/inspector/src/cli/print_acir_cmd.rs @@ -0,0 +1,21 @@ +use std::path::PathBuf; + +use clap::Args; +use color_eyre::eyre; +use noirc_artifacts::program::ProgramArtifact; + +#[derive(Debug, Clone, Args)] +pub(crate) struct PrintAcirCommand { + /// The artifact to print + artifact: PathBuf, +} + +pub(crate) fn run(args: PrintAcirCommand) -> eyre::Result<()> { + let file = std::fs::File::open(args.artifact.clone())?; + let artifact: ProgramArtifact = serde_json::from_reader(file)?; + + println!("Compiled ACIR for main:"); + println!("{}", artifact.bytecode); + + Ok(()) +} diff --git a/noir/noir-repo/tooling/inspector/src/main.rs b/noir/noir-repo/tooling/inspector/src/main.rs new file mode 100644 index 00000000000..8270fedbf2c --- /dev/null +++ b/noir/noir-repo/tooling/inspector/src/main.rs @@ -0,0 +1,8 @@ +mod cli; + +fn main() { + if let Err(report) = cli::start_cli() { + eprintln!("{report:?}"); + std::process::exit(1); + } +} diff --git a/noir/noir-repo/tooling/lsp/src/modules.rs b/noir/noir-repo/tooling/lsp/src/modules.rs index b023f3886c3..758322fa4bc 100644 --- a/noir/noir-repo/tooling/lsp/src/modules.rs +++ b/noir/noir-repo/tooling/lsp/src/modules.rs @@ -16,7 +16,7 @@ pub(crate) fn module_def_id_to_reference_id(module_def_id: ModuleDefId) -> Refer match module_def_id { ModuleDefId::ModuleId(id) => ReferenceId::Module(id), ModuleDefId::FunctionId(id) => ReferenceId::Function(id), - ModuleDefId::TypeId(id) => ReferenceId::Struct(id), + ModuleDefId::TypeId(id) => ReferenceId::Type(id), ModuleDefId::TypeAliasId(id) => ReferenceId::Alias(id), ModuleDefId::TraitId(id) => ReferenceId::Trait(id), ModuleDefId::GlobalId(id) => ReferenceId::Global(id), diff --git a/noir/noir-repo/tooling/lsp/src/requests/code_action/fill_struct_fields.rs b/noir/noir-repo/tooling/lsp/src/requests/code_action/fill_struct_fields.rs index 739f0bf4a21..fc8be7c5163 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/code_action/fill_struct_fields.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/code_action/fill_struct_fields.rs @@ -20,25 +20,23 @@ impl<'a> CodeActionFinder<'a> { }; let location = Location::new(path.span, self.file); - let Some(ReferenceId::Struct(struct_id)) = self.interner.find_referenced(location) else { + let Some(ReferenceId::Type(type_id)) = self.interner.find_referenced(location) else { return; }; - let struct_type = self.interner.get_struct(struct_id); - let struct_type = struct_type.borrow(); + let typ = self.interner.get_type(type_id); + let typ = typ.borrow(); // First get all of the struct's fields - let mut fields = struct_type.get_fields_as_written(); + let Some(mut fields) = typ.get_fields_as_written() else { + return; + }; // Remove the ones that already exists in the constructor for (constructor_field, _) in &constructor.fields { fields.retain(|field| field.name.0.contents != constructor_field.0.contents); } - if fields.is_empty() { - return; - } - // Some fields are missing. Let's suggest a quick fix that adds them. let bytes = self.source.as_bytes(); let right_brace_index = span.end() as usize - 1; diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion.rs b/noir/noir-repo/tooling/lsp/src/requests/completion.rs index 0d737e29ff7..0c51772935a 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion.rs @@ -18,9 +18,9 @@ use noirc_frontend::{ AsTraitPath, AttributeTarget, BlockExpression, CallExpression, ConstructorExpression, Expression, ExpressionKind, ForLoopStatement, GenericTypeArgs, Ident, IfExpression, IntegerBitSize, ItemVisibility, LValue, Lambda, LetStatement, MemberAccessExpression, - MethodCallExpression, NoirFunction, NoirStruct, NoirTraitImpl, Path, PathKind, Pattern, - Signedness, Statement, TraitBound, TraitImplItemKind, TypeImpl, TypePath, - UnresolvedGeneric, UnresolvedGenerics, UnresolvedType, UnresolvedTypeData, + MethodCallExpression, ModuleDeclaration, NoirFunction, NoirStruct, NoirTraitImpl, Path, + PathKind, Pattern, Signedness, Statement, TraitBound, TraitImplItemKind, TypeImpl, + TypePath, UnresolvedGeneric, UnresolvedGenerics, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, UseTree, UseTreeKind, Visitor, }, graph::{CrateId, Dependency}, @@ -31,10 +31,10 @@ use noirc_frontend::{ }, }, hir_def::traits::Trait, - node_interner::{FuncId, NodeInterner, ReferenceId, StructId}, + node_interner::{FuncId, NodeInterner, ReferenceId, TypeId}, parser::{Item, ItemKind, ParsedSubModule}, token::{MetaAttribute, Token, Tokens}, - Kind, ParsedModule, StructType, Type, TypeBinding, + DataType, Kind, ParsedModule, Type, TypeBinding, }; use sort_text::underscore_sort_text; @@ -199,16 +199,19 @@ impl<'a> NodeFinder<'a> { }; let location = Location::new(span, self.file); - let Some(ReferenceId::Struct(struct_id)) = self.interner.find_referenced(location) else { + let Some(ReferenceId::Type(struct_id)) = self.interner.find_referenced(location) else { return; }; - let struct_type = self.interner.get_struct(struct_id); + let struct_type = self.interner.get_type(struct_id); let struct_type = struct_type.borrow(); // First get all of the struct's fields - let mut fields: Vec<_> = - struct_type.get_fields_as_written().into_iter().enumerate().collect(); + let Some(fields) = struct_type.get_fields_as_written() else { + return; + }; + + let mut fields = fields.into_iter().enumerate().collect::>(); // Remove the ones that already exists in the constructor for (used_name, _) in &constructor_expression.fields { @@ -318,9 +321,9 @@ impl<'a> NodeFinder<'a> { match module_def_id { ModuleDefId::ModuleId(id) => module_id = id, ModuleDefId::TypeId(struct_id) => { - let struct_type = self.interner.get_struct(struct_id); + let struct_type = self.interner.get_type(struct_id); self.complete_type_methods( - &Type::Struct(struct_type, vec![]), + &Type::DataType(struct_type, vec![]), &prefix, FunctionKind::Any, function_completion_kind, @@ -568,7 +571,7 @@ impl<'a> NodeFinder<'a> { ) { let typ = &typ; match typ { - Type::Struct(struct_type, generics) => { + Type::DataType(struct_type, generics) => { self.complete_struct_fields(&struct_type.borrow(), generics, prefix, self_prefix); } Type::MutableReference(typ) => { @@ -622,7 +625,7 @@ impl<'a> NodeFinder<'a> { | Type::Forall(_, _) | Type::Constant(..) | Type::Quoted(_) - | Type::InfixExpr(_, _, _) + | Type::InfixExpr(..) | Type::Error => (), } @@ -800,14 +803,16 @@ impl<'a> NodeFinder<'a> { fn complete_struct_fields( &mut self, - struct_type: &StructType, + struct_type: &DataType, generics: &[Type], prefix: &str, self_prefix: bool, ) { - for (field_index, (name, visibility, typ)) in - struct_type.get_fields_with_visibility(generics).iter().enumerate() - { + let Some(fields) = struct_type.get_fields_with_visibility(generics) else { + return; + }; + + for (field_index, (name, visibility, typ)) in fields.iter().enumerate() { if !struct_member_is_visible(struct_type.id, *visibility, self.module_id, self.def_maps) { continue; @@ -1111,7 +1116,55 @@ impl<'a> NodeFinder<'a> { } } - /// Determine where each segment in a `use` statement is located. + /// Try to suggest the name of a module to declare based on which + /// files exist in the filesystem, excluding modules that are already declared. + fn complete_module_delcaration(&mut self, module: &ModuleDeclaration) -> Option<()> { + let filename = self.files.get_absolute_name(self.file).ok()?.into_path_buf(); + + let is_main_lib_or_mod = filename.ends_with("main.nr") + || filename.ends_with("lib.nr") + || filename.ends_with("mod.nr"); + + let paths = if is_main_lib_or_mod { + // For a "main" file we list sibling files + std::fs::read_dir(filename.parent()?) + } else { + // For a non-main files we list directory children + std::fs::read_dir(filename.with_extension("")) + }; + let paths = paths.ok()?; + + // See which modules are already defined via `mod ...;` + let module_data = + &self.def_maps[&self.module_id.krate].modules()[self.module_id.local_id.0]; + let existing_children: HashSet = + module_data.children.keys().map(|ident| ident.to_string()).collect(); + + for path in paths { + let Ok(path) = path else { + continue; + }; + let file_name = path.file_name().to_string_lossy().to_string(); + let Some(name) = file_name.strip_suffix(".nr") else { + continue; + }; + if name == "main" || name == "mod" || name == "lib" { + continue; + } + if existing_children.contains(name) { + continue; + } + + let label = if module.has_semicolon { name.to_string() } else { format!("{};", name) }; + self.completion_items.push(simple_completion_item( + label, + CompletionItemKind::MODULE, + None, + )); + } + + Some(()) + } fn includes_span(&self, span: Span) -> bool { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize @@ -1795,11 +1848,19 @@ impl<'a> Visitor for NodeFinder<'a> { trait_bound.trait_generics.accept(self); false } + + fn visit_module_declaration(&mut self, module: &ModuleDeclaration, _: Span) { + if !self.includes_span(module.ident.span()) { + return; + } + + self.complete_module_delcaration(module); + } } fn get_field_type(typ: &Type, name: &str) -> Option { match typ { - Type::Struct(struct_type, generics) => { + Type::DataType(struct_type, generics) => { Some(struct_type.borrow().get_field(name, generics)?.0) } Type::Tuple(types) => { @@ -1839,9 +1900,9 @@ fn get_array_element_type(typ: Type) -> Option { } } -fn get_type_struct_id(typ: &Type) -> Option { +fn get_type_struct_id(typ: &Type) -> Option { match typ { - Type::Struct(struct_type, _) => Some(struct_type.borrow().id), + Type::DataType(struct_type, _) => Some(struct_type.borrow().id), Type::Alias(type_alias, generics) => { let type_alias = type_alias.borrow(); let typ = type_alias.get_type(generics); @@ -1897,11 +1958,12 @@ fn name_matches(name: &str, prefix: &str) -> bool { fn module_def_id_from_reference_id(reference_id: ReferenceId) -> Option { match reference_id { ReferenceId::Module(module_id) => Some(ModuleDefId::ModuleId(module_id)), - ReferenceId::Struct(struct_id) => Some(ModuleDefId::TypeId(struct_id)), + ReferenceId::Type(struct_id) => Some(ModuleDefId::TypeId(struct_id)), ReferenceId::Trait(trait_id) => Some(ModuleDefId::TraitId(trait_id)), ReferenceId::Function(func_id) => Some(ModuleDefId::FunctionId(func_id)), ReferenceId::Alias(type_alias_id) => Some(ModuleDefId::TypeAliasId(type_alias_id)), ReferenceId::StructMember(_, _) + | ReferenceId::EnumVariant(_, _) | ReferenceId::Global(_) | ReferenceId::Local(_) | ReferenceId::Reference(_, _) => None, diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/builtins.rs b/noir/noir-repo/tooling/lsp/src/requests/completion/builtins.rs index 90b8c6301b7..10267d4719b 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/builtins.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/builtins.rs @@ -91,6 +91,9 @@ impl<'a> NodeFinder<'a> { AttributeTarget::Struct => { self.suggest_one_argument_attributes(prefix, &["abi"]); } + AttributeTarget::Enum => { + self.suggest_one_argument_attributes(prefix, &["abi"]); + } AttributeTarget::Function => { let no_arguments_attributes = &[ "contract_library_method", @@ -156,6 +159,7 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { match keyword { Keyword::Bool => Some("bool"), Keyword::CtString => Some("CtString"), + Keyword::EnumDefinition => Some("EnumDefinition"), Keyword::Expr => Some("Expr"), Keyword::Field => Some("Field"), Keyword::FunctionDefinition => Some("FunctionDefinition"), @@ -247,6 +251,7 @@ pub(super) fn keyword_builtin_function(keyword: &Keyword) -> Option NodeFinder<'a> { match target { AttributeTarget::Module => Some(Type::Quoted(QuotedType::Module)), AttributeTarget::Struct => Some(Type::Quoted(QuotedType::StructDefinition)), + AttributeTarget::Enum => Some(Type::Quoted(QuotedType::EnumDefinition)), AttributeTarget::Trait => Some(Type::Quoted(QuotedType::TraitDefinition)), AttributeTarget::Function => Some(Type::Quoted(QuotedType::FunctionDefinition)), AttributeTarget::Let => { @@ -109,17 +110,17 @@ impl<'a> NodeFinder<'a> { self.completion_item_with_doc_comments(ReferenceId::Module(id), completion_item) } - fn struct_completion_item(&self, name: String, struct_id: StructId) -> CompletionItem { + fn struct_completion_item(&self, name: String, struct_id: TypeId) -> CompletionItem { let completion_item = simple_completion_item(name.clone(), CompletionItemKind::STRUCT, Some(name)); - self.completion_item_with_doc_comments(ReferenceId::Struct(struct_id), completion_item) + self.completion_item_with_doc_comments(ReferenceId::Type(struct_id), completion_item) } pub(super) fn struct_field_completion_item( &self, field: &str, typ: &Type, - struct_id: StructId, + struct_id: TypeId, field_index: usize, self_type: bool, ) -> CompletionItem { @@ -287,10 +288,10 @@ impl<'a> NodeFinder<'a> { } else { false }; + let description = func_meta_type_to_string(func_meta, name, func_self_type.is_some()); let name = if self_prefix { format!("self.{}", name) } else { name.clone() }; let name = if is_macro_call { format!("{}!", name) } else { name }; let name = &name; - let description = func_meta_type_to_string(func_meta, func_self_type.is_some()); let mut has_arguments = false; let completion_item = match function_completion_kind { @@ -350,7 +351,16 @@ impl<'a> NodeFinder<'a> { self.auto_import_trait_if_trait_method(func_id, trait_info, &mut completion_item); - self.completion_item_with_doc_comments(ReferenceId::Function(func_id), completion_item) + if let (Some(type_id), Some(variant_index)) = + (func_meta.type_id, func_meta.enum_variant_index) + { + self.completion_item_with_doc_comments( + ReferenceId::EnumVariant(type_id, variant_index), + completion_item, + ) + } else { + self.completion_item_with_doc_comments(ReferenceId::Function(func_id), completion_item) + } } fn auto_import_trait_if_trait_method( @@ -418,6 +428,8 @@ impl<'a> NodeFinder<'a> { function_kind: FunctionKind, skip_first_argument: bool, ) -> String { + let is_enum_variant = func_meta.enum_variant_index.is_some(); + let mut text = String::new(); text.push_str(name); text.push('('); @@ -447,7 +459,11 @@ impl<'a> NodeFinder<'a> { text.push_str("${"); text.push_str(&index.to_string()); text.push(':'); - self.hir_pattern_to_argument(pattern, &mut text); + if is_enum_variant { + text.push_str("()"); + } else { + self.hir_pattern_to_argument(pattern, &mut text); + } text.push('}'); index += 1; @@ -511,18 +527,25 @@ pub(super) fn trait_impl_method_completion_item( snippet_completion_item(label, CompletionItemKind::METHOD, insert_text, None) } -fn func_meta_type_to_string(func_meta: &FuncMeta, has_self_type: bool) -> String { +fn func_meta_type_to_string(func_meta: &FuncMeta, name: &str, has_self_type: bool) -> String { let mut typ = &func_meta.typ; if let Type::Forall(_, typ_) = typ { typ = typ_; } + let is_enum_variant = func_meta.enum_variant_index.is_some(); + if let Type::Function(args, ret, _env, unconstrained) = typ { let mut string = String::new(); - if *unconstrained { - string.push_str("unconstrained "); + if is_enum_variant { + string.push_str(name); + string.push('('); + } else { + if *unconstrained { + string.push_str("unconstrained "); + } + string.push_str("fn("); } - string.push_str("fn("); for (index, arg) in args.iter().enumerate() { if index > 0 { string.push_str(", "); @@ -535,13 +558,16 @@ fn func_meta_type_to_string(func_meta: &FuncMeta, has_self_type: bool) -> String } string.push(')'); - let ret: &Type = ret; - if let Type::Unit = ret { - // Nothing - } else { - string.push_str(" -> "); - string.push_str(&ret.to_string()); + if !is_enum_variant { + let ret: &Type = ret; + if let Type::Unit = ret { + // Nothing + } else { + string.push_str(" -> "); + string.push_str(&ret.to_string()); + } } + string } else { typ.to_string() diff --git a/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs b/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs index 8ff568e3c26..a3cd6b0d024 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/completion/tests.rs @@ -20,8 +20,8 @@ mod completion_tests { use lsp_types::{ CompletionItem, CompletionItemKind, CompletionItemLabelDetails, CompletionParams, - CompletionResponse, DidOpenTextDocumentParams, PartialResultParams, Position, - TextDocumentIdentifier, TextDocumentItem, TextDocumentPositionParams, + CompletionResponse, DidOpenTextDocumentParams, Documentation, PartialResultParams, + Position, TextDocumentIdentifier, TextDocumentItem, TextDocumentPositionParams, WorkDoneProgressParams, }; use tokio::test; @@ -3077,4 +3077,35 @@ fn main() { "#; assert_eq!(new_code, expected); } + + #[test] + async fn test_suggests_enum_variant_differently_than_a_function_call() { + let src = r#" + enum Enum { + /// Some docs + Variant(Field, i32) + } + + fn foo() { + Enum::Var>|< + } + "#; + let items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = &items[0]; + assert_eq!(item.label, "Variant(…)".to_string()); + + let details = item.label_details.as_ref().unwrap(); + assert_eq!(details.description, Some("Variant(Field, i32)".to_string())); + + assert_eq!(item.detail, Some("Variant(Field, i32)".to_string())); + + assert_eq!(item.insert_text, Some("Variant(${1:()}, ${2:()})".to_string())); + + let Documentation::MarkupContent(markdown) = item.documentation.as_ref().unwrap() else { + panic!("Expected markdown docs"); + }; + assert!(markdown.value.contains("Some docs")); + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/hover.rs b/noir/noir-repo/tooling/lsp/src/requests/hover.rs index ef1246d752d..60c2a686a62 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/hover.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/hover.rs @@ -13,10 +13,11 @@ use noirc_frontend::{ traits::Trait, }, node_interner::{ - DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId, - StructId, TraitId, TraitImplKind, TypeAliasId, + DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId, TraitId, + TraitImplKind, TypeAliasId, TypeId, }, - Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable, + DataType, EnumVariant, Generics, Shared, StructField, Type, TypeAlias, TypeBinding, + TypeVariable, }; use crate::{ @@ -73,10 +74,13 @@ pub(crate) fn on_hover_request( fn format_reference(reference: ReferenceId, args: &ProcessRequestCallbackArgs) -> Option { match reference { ReferenceId::Module(id) => format_module(id, args), - ReferenceId::Struct(id) => Some(format_struct(id, args)), + ReferenceId::Type(id) => Some(format_type(id, args)), ReferenceId::StructMember(id, field_index) => { Some(format_struct_member(id, field_index, args)) } + ReferenceId::EnumVariant(id, variant_index) => { + Some(format_enum_variant(id, variant_index, args)) + } ReferenceId::Trait(id) => Some(format_trait(id, args)), ReferenceId::Global(id) => Some(format_global(id, args)), ReferenceId::Function(id) => Some(format_function(id, args)), @@ -122,20 +126,33 @@ fn format_module(id: ModuleId, args: &ProcessRequestCallbackArgs) -> Option String { - let struct_type = args.interner.get_struct(id); - let struct_type = struct_type.borrow(); +fn format_type(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { + let typ = args.interner.get_type(id); + let typ = typ.borrow(); + if let Some(fields) = typ.get_fields_as_written() { + format_struct(&typ, fields, args) + } else if let Some(variants) = typ.get_variants_as_written() { + format_enum(&typ, variants, args) + } else { + unreachable!("Type should either be a struct or an enum") + } +} +fn format_struct( + typ: &DataType, + fields: Vec, + args: &ProcessRequestCallbackArgs, +) -> String { let mut string = String::new(); - if format_parent_module(ReferenceId::Struct(id), args, &mut string) { + if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { string.push('\n'); } string.push_str(" "); string.push_str("struct "); - string.push_str(&struct_type.name.0.contents); - format_generics(&struct_type.generics, &mut string); + string.push_str(&typ.name.0.contents); + format_generics(&typ.generics, &mut string); string.push_str(" {\n"); - for field in struct_type.get_fields_as_written() { + for field in fields { string.push_str(" "); string.push_str(&field.name.0.contents); string.push_str(": "); @@ -144,22 +161,56 @@ fn format_struct(id: StructId, args: &ProcessRequestCallbackArgs) -> String { } string.push_str(" }"); - append_doc_comments(args.interner, ReferenceId::Struct(id), &mut string); + append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); + + string +} + +fn format_enum( + typ: &DataType, + variants: Vec, + args: &ProcessRequestCallbackArgs, +) -> String { + let mut string = String::new(); + if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { + string.push('\n'); + } + string.push_str(" "); + string.push_str("enum "); + string.push_str(&typ.name.0.contents); + format_generics(&typ.generics, &mut string); + string.push_str(" {\n"); + for field in variants { + string.push_str(" "); + string.push_str(&field.name.0.contents); + + if !field.params.is_empty() { + let types = field.params.iter().map(ToString::to_string).collect::>(); + string.push('('); + string.push_str(&types.join(", ")); + string.push(')'); + } + + string.push_str(",\n"); + } + string.push_str(" }"); + + append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); string } fn format_struct_member( - id: StructId, + id: TypeId, field_index: usize, args: &ProcessRequestCallbackArgs, ) -> String { - let struct_type = args.interner.get_struct(id); + let struct_type = args.interner.get_type(id); let struct_type = struct_type.borrow(); let field = struct_type.field_at(field_index); let mut string = String::new(); - if format_parent_module(ReferenceId::Struct(id), args, &mut string) { + if format_parent_module(ReferenceId::Type(id), args, &mut string) { string.push_str("::"); } string.push_str(&struct_type.name.0.contents); @@ -175,6 +226,39 @@ fn format_struct_member( string } +fn format_enum_variant( + id: TypeId, + field_index: usize, + args: &ProcessRequestCallbackArgs, +) -> String { + let enum_type = args.interner.get_type(id); + let enum_type = enum_type.borrow(); + let variant = enum_type.variant_at(field_index); + + let mut string = String::new(); + if format_parent_module(ReferenceId::Type(id), args, &mut string) { + string.push_str("::"); + } + string.push_str(&enum_type.name.0.contents); + string.push('\n'); + string.push_str(" "); + string.push_str(&variant.name.0.contents); + if !variant.params.is_empty() { + let types = variant.params.iter().map(ToString::to_string).collect::>(); + string.push('('); + string.push_str(&types.join(", ")); + string.push(')'); + } + + for typ in variant.params.iter() { + string.push_str(&go_to_type_links(typ, args.interner, args.files)); + } + + append_doc_comments(args.interner, ReferenceId::EnumVariant(id, field_index), &mut string); + + string +} + fn format_trait(id: TraitId, args: &ProcessRequestCallbackArgs) -> String { let a_trait = args.interner.get_trait(id); @@ -307,9 +391,19 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { let func_name_definition_id = args.interner.definition(func_meta.name.id); + let enum_variant = match (func_meta.type_id, func_meta.enum_variant_index) { + (Some(type_id), Some(index)) => Some((type_id, index)), + _ => None, + }; + + let reference_id = if let Some((type_id, variant_index)) = enum_variant { + ReferenceId::EnumVariant(type_id, variant_index) + } else { + ReferenceId::Function(id) + }; + let mut string = String::new(); - let formatted_parent_module = - format_parent_module(ReferenceId::Function(id), args, &mut string); + let formatted_parent_module = format_parent_module(reference_id, args, &mut string); let formatted_parent_type = if let Some(trait_impl_id) = func_meta.trait_impl { let trait_impl = args.interner.get_trait_implementation(trait_impl_id); @@ -367,28 +461,30 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { format_generics(&trait_.generics, &mut string); true - } else if let Some(struct_id) = func_meta.struct_id { - let struct_type = args.interner.get_struct(struct_id); - let struct_type = struct_type.borrow(); + } else if let Some(type_id) = func_meta.type_id { + let data_type = args.interner.get_type(type_id); + let data_type = data_type.borrow(); if formatted_parent_module { string.push_str("::"); } - string.push_str(&struct_type.name.0.contents); - string.push('\n'); - string.push_str(" "); - string.push_str("impl"); + string.push_str(&data_type.name.0.contents); + if enum_variant.is_none() { + string.push('\n'); + string.push_str(" "); + string.push_str("impl"); - let impl_generics: Vec<_> = func_meta - .all_generics - .iter() - .take(func_meta.all_generics.len() - func_meta.direct_generics.len()) - .cloned() - .collect(); - format_generics(&impl_generics, &mut string); + let impl_generics: Vec<_> = func_meta + .all_generics + .iter() + .take(func_meta.all_generics.len() - func_meta.direct_generics.len()) + .cloned() + .collect(); + format_generics(&impl_generics, &mut string); - string.push(' '); - string.push_str(&struct_type.name.0.contents); - format_generic_names(&impl_generics, &mut string); + string.push(' '); + string.push_str(&data_type.name.0.contents); + format_generic_names(&impl_generics, &mut string); + } true } else { @@ -415,20 +511,36 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { let func_name = &func_name_definition_id.name; - string.push_str("fn "); + if enum_variant.is_none() { + string.push_str("fn "); + } string.push_str(func_name); format_generics(&func_meta.direct_generics, &mut string); string.push('('); let parameters = &func_meta.parameters; for (index, (pattern, typ, visibility)) in parameters.iter().enumerate() { - format_pattern(pattern, args.interner, &mut string); - if !pattern_is_self(pattern, args.interner) { - string.push_str(": "); - if matches!(visibility, Visibility::Public) { - string.push_str("pub "); - } + let is_self = pattern_is_self(pattern, args.interner); + + // `&mut self` is represented as a mutable reference type, not as a mutable pattern + if is_self && matches!(typ, Type::MutableReference(..)) { + string.push_str("&mut "); + } + + if enum_variant.is_some() { string.push_str(&format!("{}", typ)); + } else { + format_pattern(pattern, args.interner, &mut string); + + // Don't add type for `self` param + if !is_self { + string.push_str(": "); + if matches!(visibility, Visibility::Public) { + string.push_str("pub "); + } + string.push_str(&format!("{}", typ)); + } } + if index != parameters.len() - 1 { string.push_str(", "); } @@ -436,28 +548,34 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { string.push(')'); - let return_type = func_meta.return_type(); - match return_type { - Type::Unit => (), - _ => { - string.push_str(" -> "); - string.push_str(&format!("{}", return_type)); + if enum_variant.is_none() { + let return_type = func_meta.return_type(); + match return_type { + Type::Unit => (), + _ => { + string.push_str(" -> "); + string.push_str(&format!("{}", return_type)); + } } - } - string.push_str(&go_to_type_links(return_type, args.interner, args.files)); + string.push_str(&go_to_type_links(return_type, args.interner, args.files)); + } - let had_doc_comments = - append_doc_comments(args.interner, ReferenceId::Function(id), &mut string); - if !had_doc_comments { - // If this function doesn't have doc comments, but it's a trait impl method, - // use the trait method doc comments. - if let Some(trait_impl_id) = func_meta.trait_impl { - let trait_impl = args.interner.get_trait_implementation(trait_impl_id); - let trait_impl = trait_impl.borrow(); - let trait_ = args.interner.get_trait(trait_impl.trait_id); - if let Some(func_id) = trait_.method_ids.get(func_name) { - append_doc_comments(args.interner, ReferenceId::Function(*func_id), &mut string); + if enum_variant.is_some() { + append_doc_comments(args.interner, reference_id, &mut string); + } else { + let had_doc_comments = append_doc_comments(args.interner, reference_id, &mut string); + if !had_doc_comments { + // If this function doesn't have doc comments, but it's a trait impl method, + // use the trait method doc comments. + if let Some(trait_impl_id) = func_meta.trait_impl { + let trait_impl = args.interner.get_trait_implementation(trait_impl_id); + let trait_impl = trait_impl.borrow(); + let trait_ = args.interner.get_trait(trait_impl.trait_id); + if let Some(func_id) = trait_.method_ids.get(func_name) { + let reference_id = ReferenceId::Function(*func_id); + append_doc_comments(args.interner, reference_id, &mut string); + } } } } @@ -685,8 +803,8 @@ impl<'a> TypeLinksGatherer<'a> { self.gather_type_links(typ); } } - Type::Struct(struct_type, generics) => { - self.gather_struct_type_links(struct_type); + Type::DataType(data_type, generics) => { + self.gather_struct_type_links(data_type); for generic in generics { self.gather_type_links(generic); } @@ -721,7 +839,7 @@ impl<'a> TypeLinksGatherer<'a> { self.gather_type_links(env); } Type::MutableReference(typ) => self.gather_type_links(typ), - Type::InfixExpr(lhs, _, rhs) => { + Type::InfixExpr(lhs, _, rhs, _) => { self.gather_type_links(lhs); self.gather_type_links(rhs); } @@ -739,7 +857,7 @@ impl<'a> TypeLinksGatherer<'a> { } } - fn gather_struct_type_links(&mut self, struct_type: &Shared) { + fn gather_struct_type_links(&mut self, struct_type: &Shared) { let struct_type = struct_type.borrow(); if let Some(lsp_location) = to_lsp_location(self.files, struct_type.location.file, struct_type.name.span()) @@ -1168,4 +1286,75 @@ mod hover_tests { .await; assert!(hover_text.contains("Some docs")); } + + #[test] + async fn hover_on_function_with_mut_self() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 96, character: 10 }) + .await; + assert!(hover_text.contains("fn mut_self(&mut self)")); + } + + #[test] + async fn hover_on_empty_enum_type() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 100, character: 8 }) + .await; + assert!(hover_text.contains( + " two + enum EmptyColor { + } + +--- + + Red, blue, etc." + )); + } + + #[test] + async fn hover_on_non_empty_enum_type() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 103, character: 8 }) + .await; + assert!(hover_text.contains( + " two + enum Color { + Red(Field), + } + +--- + + Red, blue, etc." + )); + } + + #[test] + async fn hover_on_enum_variant() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 105, character: 6 }) + .await; + assert!(hover_text.contains( + " two::Color + Red(Field) + +--- + + Like a tomato" + )); + } + + #[test] + async fn hover_on_enum_variant_in_call() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 109, character: 12 }) + .await; + assert!(hover_text.contains( + " two::Color + Red(Field) + +--- + + Like a tomato" + )); + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs b/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs index c6415acb545..cbf4ed26ef9 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/inlay_hint.rs @@ -83,25 +83,33 @@ impl<'a> InlayHintCollector<'a> { let location = Location::new(ident.span(), self.file_id); if let Some(lsp_location) = to_lsp_location(self.files, self.file_id, span) { if let Some(referenced) = self.interner.find_referenced(location) { + let include_colon = true; match referenced { ReferenceId::Global(global_id) => { let global_info = self.interner.get_global(global_id); let definition_id = global_info.definition_id; let typ = self.interner.definition_type(definition_id); - self.push_type_hint(lsp_location, &typ, editable); + self.push_type_hint(lsp_location, &typ, editable, include_colon); } ReferenceId::Local(definition_id) => { let typ = self.interner.definition_type(definition_id); - self.push_type_hint(lsp_location, &typ, editable); + self.push_type_hint(lsp_location, &typ, editable, include_colon); } ReferenceId::StructMember(struct_id, field_index) => { - let struct_type = self.interner.get_struct(struct_id); + let struct_type = self.interner.get_type(struct_id); let struct_type = struct_type.borrow(); let field = struct_type.field_at(field_index); - self.push_type_hint(lsp_location, &field.typ, false); + self.push_type_hint(lsp_location, &field.typ, false, include_colon); + } + ReferenceId::EnumVariant(type_id, variant_index) => { + let typ = self.interner.get_type(type_id); + let shared_type = typ.clone(); + let typ = typ.borrow(); + let variant_type = typ.variant_function_type(variant_index, shared_type); + self.push_type_hint(lsp_location, &variant_type, false, include_colon); } ReferenceId::Module(_) - | ReferenceId::Struct(_) + | ReferenceId::Type(_) | ReferenceId::Trait(_) | ReferenceId::Function(_) | ReferenceId::Alias(_) @@ -111,11 +119,21 @@ impl<'a> InlayHintCollector<'a> { } } - fn push_type_hint(&mut self, location: lsp_types::Location, typ: &Type, editable: bool) { + fn push_type_hint( + &mut self, + location: lsp_types::Location, + typ: &Type, + editable: bool, + include_colon: bool, + ) { let position = location.range.end; let mut parts = Vec::new(); - parts.push(string_part(": ")); + if include_colon { + parts.push(string_part(": ")); + } else { + parts.push(string_part(" ")); + } push_type_parts(typ, &mut parts, self.files); self.inlay_hints.push(InlayHint { @@ -155,6 +173,11 @@ impl<'a> InlayHintCollector<'a> { if let Some(ReferenceId::Function(func_id)) = referenced { let func_meta = self.interner.function_meta(&func_id); + // No hints for enum variants + if func_meta.enum_variant_index.is_some() { + return; + } + let mut parameters = func_meta.parameters.iter().peekable(); let mut parameters_count = func_meta.parameters.len(); @@ -209,6 +232,36 @@ impl<'a> InlayHintCollector<'a> { } } + fn collect_method_call_chain_hints(&mut self, method: &MethodCallExpression) { + let Some(object_lsp_location) = + to_lsp_location(self.files, self.file_id, method.object.span) + else { + return; + }; + + let Some(name_lsp_location) = + to_lsp_location(self.files, self.file_id, method.method_name.span()) + else { + return; + }; + + if object_lsp_location.range.end.line >= name_lsp_location.range.start.line { + return; + } + + let object_location = Location::new(method.object.span, self.file_id); + let Some(typ) = self.interner.type_at_location(object_location) else { + return; + }; + + self.push_type_hint( + object_lsp_location, + &typ, + false, // not editable + false, // don't include colon + ); + } + fn get_pattern_name(&self, pattern: &HirPattern) -> Option { match pattern { HirPattern::Identifier(ident) => { @@ -349,6 +402,10 @@ impl<'a> Visitor for InlayHintCollector<'a> { &method_call_expression.arguments, ); + if self.options.chaining_hints.enabled { + self.collect_method_call_chain_hints(method_call_expression); + } + true } @@ -410,7 +467,7 @@ fn push_type_parts(typ: &Type, parts: &mut Vec, files: &File } parts.push(string_part(")")); } - Type::Struct(struct_type, generics) => { + Type::DataType(struct_type, generics) => { let struct_type = struct_type.borrow(); let location = Location::new(struct_type.name.span(), struct_type.location.file); parts.push(text_part_with_location(struct_type.name.to_string(), location, files)); @@ -540,7 +597,9 @@ fn get_expression_name(expression: &Expression) -> Option { #[cfg(test)] mod inlay_hints_tests { use crate::{ - requests::{ClosingBraceHintsOptions, ParameterHintsOptions, TypeHintsOptions}, + requests::{ + ChainingHintsOptions, ClosingBraceHintsOptions, ParameterHintsOptions, TypeHintsOptions, + }, test_utils, }; @@ -577,6 +636,7 @@ mod inlay_hints_tests { type_hints: TypeHintsOptions { enabled: false }, parameter_hints: ParameterHintsOptions { enabled: false }, closing_brace_hints: ClosingBraceHintsOptions { enabled: false, min_lines: 25 }, + chaining_hints: ChainingHintsOptions { enabled: false }, } } @@ -585,6 +645,7 @@ mod inlay_hints_tests { type_hints: TypeHintsOptions { enabled: true }, parameter_hints: ParameterHintsOptions { enabled: false }, closing_brace_hints: ClosingBraceHintsOptions { enabled: false, min_lines: 25 }, + chaining_hints: ChainingHintsOptions { enabled: false }, } } @@ -593,6 +654,7 @@ mod inlay_hints_tests { type_hints: TypeHintsOptions { enabled: false }, parameter_hints: ParameterHintsOptions { enabled: true }, closing_brace_hints: ClosingBraceHintsOptions { enabled: false, min_lines: 25 }, + chaining_hints: ChainingHintsOptions { enabled: false }, } } @@ -601,6 +663,16 @@ mod inlay_hints_tests { type_hints: TypeHintsOptions { enabled: false }, parameter_hints: ParameterHintsOptions { enabled: false }, closing_brace_hints: ClosingBraceHintsOptions { enabled: true, min_lines }, + chaining_hints: ChainingHintsOptions { enabled: false }, + } + } + + fn chaining_hints() -> InlayHintsOptions { + InlayHintsOptions { + type_hints: TypeHintsOptions { enabled: false }, + parameter_hints: ParameterHintsOptions { enabled: false }, + closing_brace_hints: ClosingBraceHintsOptions { enabled: false, min_lines: 0 }, + chaining_hints: ChainingHintsOptions { enabled: true }, } } @@ -955,4 +1027,39 @@ mod inlay_hints_tests { panic!("Expected InlayHintLabel::String, got {:?}", inlay_hint.label); } } + + #[test] + async fn test_shows_receiver_type_in_multiline_method_call() { + let mut inlay_hints = get_inlay_hints(125, 130, chaining_hints()).await; + assert_eq!(inlay_hints.len(), 3); + + inlay_hints.sort_by_key(|hint| hint.position.line); + + let inlay_hint = &inlay_hints[0]; + assert_eq!(inlay_hint.position.line, 125); + assert_eq!(inlay_hint.position.character, 59); + let InlayHintLabel::LabelParts(parts) = &inlay_hint.label else { + panic!("Expected label parts"); + }; + let label = parts.iter().map(|part| part.value.clone()).collect::>().join(""); + assert_eq!(label, " [u32; 14]"); + + let inlay_hint = &inlay_hints[1]; + assert_eq!(inlay_hint.position.line, 126); + assert_eq!(inlay_hint.position.character, 37); + let InlayHintLabel::LabelParts(parts) = &inlay_hint.label else { + panic!("Expected label parts"); + }; + let label = parts.iter().map(|part| part.value.clone()).collect::>().join(""); + assert_eq!(label, " [u32; 14]"); + + let inlay_hint = &inlay_hints[2]; + assert_eq!(inlay_hint.position.line, 127); + assert_eq!(inlay_hint.position.character, 23); + let InlayHintLabel::LabelParts(parts) = &inlay_hint.label else { + panic!("Expected label parts"); + }; + let label = parts.iter().map(|part| part.value.clone()).collect::>().join(""); + assert_eq!(label, " bool"); + } } diff --git a/noir/noir-repo/tooling/lsp/src/requests/mod.rs b/noir/noir-repo/tooling/lsp/src/requests/mod.rs index 80f4a167a04..334599e8f3d 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/mod.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/mod.rs @@ -90,6 +90,9 @@ pub(crate) struct InlayHintsOptions { #[serde(rename = "closingBraceHints", default = "default_closing_brace_hints")] pub(crate) closing_brace_hints: ClosingBraceHintsOptions, + + #[serde(rename = "ChainingHints", default = "default_chaining_hints")] + pub(crate) chaining_hints: ChainingHintsOptions, } #[derive(Debug, Deserialize, Serialize, Copy, Clone)] @@ -113,6 +116,12 @@ pub(crate) struct ClosingBraceHintsOptions { pub(crate) min_lines: u32, } +#[derive(Debug, Deserialize, Serialize, Copy, Clone)] +pub(crate) struct ChainingHintsOptions { + #[serde(rename = "enabled", default = "default_chaining_hints_enabled")] + pub(crate) enabled: bool, +} + fn default_enable_code_lens() -> bool { true } @@ -126,6 +135,7 @@ fn default_inlay_hints() -> InlayHintsOptions { type_hints: default_type_hints(), parameter_hints: default_parameter_hints(), closing_brace_hints: default_closing_brace_hints(), + chaining_hints: default_chaining_hints(), } } @@ -160,6 +170,14 @@ fn default_closing_brace_min_lines() -> u32 { 25 } +fn default_chaining_hints() -> ChainingHintsOptions { + ChainingHintsOptions { enabled: default_chaining_hints_enabled() } +} + +fn default_chaining_hints_enabled() -> bool { + true +} + impl Default for LspInitializationOptions { fn default() -> Self { Self { diff --git a/noir/noir-repo/tooling/lsp/src/requests/signature_help.rs b/noir/noir-repo/tooling/lsp/src/requests/signature_help.rs index c0d40656c19..99bd463f44a 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/signature_help.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/signature_help.rs @@ -122,10 +122,22 @@ impl<'a> SignatureFinder<'a> { active_parameter: Option, has_self: bool, ) -> SignatureInformation { + let enum_type_id = match (func_meta.type_id, func_meta.enum_variant_index) { + (Some(type_id), Some(_)) => Some(type_id), + _ => None, + }; + let mut label = String::new(); let mut parameters = Vec::new(); - label.push_str("fn "); + if let Some(enum_type_id) = enum_type_id { + label.push_str("enum "); + label.push_str(&self.interner.get_type(enum_type_id).borrow().name.0.contents); + label.push_str("::"); + } else { + label.push_str("fn "); + } + label.push_str(name); label.push('('); for (index, (pattern, typ, _)) in func_meta.parameters.0.iter().enumerate() { @@ -142,8 +154,10 @@ impl<'a> SignatureFinder<'a> { } else { let parameter_start = label.chars().count(); - self.hir_pattern_to_argument(pattern, &mut label); - label.push_str(": "); + if enum_type_id.is_none() { + self.hir_pattern_to_argument(pattern, &mut label); + label.push_str(": "); + } label.push_str(&typ.to_string()); let parameter_end = label.chars().count(); @@ -159,11 +173,13 @@ impl<'a> SignatureFinder<'a> { } label.push(')'); - match &func_meta.return_type { - FunctionReturnType::Default(_) => (), - FunctionReturnType::Ty(typ) => { - label.push_str(" -> "); - label.push_str(&typ.to_string()); + if enum_type_id.is_none() { + match &func_meta.return_type { + FunctionReturnType::Default(_) => (), + FunctionReturnType::Ty(typ) => { + label.push_str(" -> "); + label.push_str(&typ.to_string()); + } } } @@ -224,7 +240,7 @@ impl<'a> SignatureFinder<'a> { self.hardcoded_signature_information( active_parameter, "assert", - &["predicate: bool", "[failure_message: str]"], + &["predicate: bool", "[failure_message: T]"], ) } @@ -235,7 +251,7 @@ impl<'a> SignatureFinder<'a> { self.hardcoded_signature_information( active_parameter, "assert_eq", - &["lhs: T", "rhs: T", "[failure_message: str]"], + &["lhs: T", "rhs: T", "[failure_message: U]"], ) } diff --git a/noir/noir-repo/tooling/lsp/src/requests/signature_help/tests.rs b/noir/noir-repo/tooling/lsp/src/requests/signature_help/tests.rs index 4b3f3c38156..a5cf7c32e1e 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/signature_help/tests.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/signature_help/tests.rs @@ -206,13 +206,13 @@ mod signature_help_tests { assert_eq!(signature_help.signatures.len(), 1); let signature = &signature_help.signatures[0]; - assert_eq!(signature.label, "assert(predicate: bool, [failure_message: str])"); + assert_eq!(signature.label, "assert(predicate: bool, [failure_message: T])"); let params = signature.parameters.as_ref().unwrap(); assert_eq!(params.len(), 2); check_label(&signature.label, ¶ms[0].label, "predicate: bool"); - check_label(&signature.label, ¶ms[1].label, "[failure_message: str]"); + check_label(&signature.label, ¶ms[1].label, "[failure_message: T]"); assert_eq!(signature.active_parameter, Some(0)); } @@ -229,14 +229,41 @@ mod signature_help_tests { assert_eq!(signature_help.signatures.len(), 1); let signature = &signature_help.signatures[0]; - assert_eq!(signature.label, "assert_eq(lhs: T, rhs: T, [failure_message: str])"); + assert_eq!(signature.label, "assert_eq(lhs: T, rhs: T, [failure_message: U])"); let params = signature.parameters.as_ref().unwrap(); assert_eq!(params.len(), 3); check_label(&signature.label, ¶ms[0].label, "lhs: T"); check_label(&signature.label, ¶ms[1].label, "rhs: T"); - check_label(&signature.label, ¶ms[2].label, "[failure_message: str]"); + check_label(&signature.label, ¶ms[2].label, "[failure_message: U]"); + + assert_eq!(signature.active_parameter, Some(0)); + } + + #[test] + async fn test_signature_help_for_enum_variant() { + let src = r#" + enum Enum { + Variant(Field, i32) + } + + fn bar() { + Enum::Variant(>|<(), ()); + } + "#; + + let signature_help = get_signature_help(src).await; + assert_eq!(signature_help.signatures.len(), 1); + + let signature = &signature_help.signatures[0]; + assert_eq!(signature.label, "enum Enum::Variant(Field, i32)"); + + let params = signature.parameters.as_ref().unwrap(); + assert_eq!(params.len(), 2); + + check_label(&signature.label, ¶ms[0].label, "Field"); + check_label(&signature.label, ¶ms[1].label, "i32"); assert_eq!(signature.active_parameter, Some(0)); } diff --git a/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs b/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs index 2ae0d526f3e..4e505eb5e12 100644 --- a/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs +++ b/noir/noir-repo/tooling/lsp/src/trait_impl_method_stub_generator.rs @@ -181,7 +181,7 @@ impl<'a> TraitImplMethodStubGenerator<'a> { } self.string.push(')'); } - Type::Struct(struct_type, generics) => { + Type::DataType(struct_type, generics) => { let struct_type = struct_type.borrow(); let current_module_data = @@ -361,7 +361,7 @@ impl<'a> TraitImplMethodStubGenerator<'a> { Type::Forall(_, _) => { panic!("Shouldn't get a Type::Forall"); } - Type::InfixExpr(left, op, right) => { + Type::InfixExpr(left, op, right, _) => { self.append_type(left); self.string.push(' '); self.string.push_str(&op.to_string()); diff --git a/noir/noir-repo/tooling/lsp/test_programs/inlay_hints/src/main.nr b/noir/noir-repo/tooling/lsp/test_programs/inlay_hints/src/main.nr index 46a6d3bc558..64eca72a667 100644 --- a/noir/noir-repo/tooling/lsp/test_programs/inlay_hints/src/main.nr +++ b/noir/noir-repo/tooling/lsp/test_programs/inlay_hints/src/main.nr @@ -119,4 +119,12 @@ mod some_module { contract some_contract { +}} + +use std::ops::Not; +pub fn chain() { + let _ = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + .map(|x| x + 123456789012345) + .any(|x| x > 5) + .not(); } diff --git a/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr b/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr index d18a663b276..0baeb83d5c1 100644 --- a/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr +++ b/noir/noir-repo/tooling/lsp/test_programs/workspace/two/src/lib.nr @@ -93,3 +93,19 @@ impl TraitWithDocs for Field { fn foo() {} } +impl Foo { + fn mut_self(&mut self) {} +} + +/// Red, blue, etc. +enum EmptyColor {} + +/// Red, blue, etc. +enum Color { + /// Like a tomato + Red(Field), +} + +fn test_enum() -> Color { + Color::Red(1) +} diff --git a/noir/noir-repo/tooling/nargo_cli/build.rs b/noir/noir-repo/tooling/nargo_cli/build.rs index 3b1aff88755..21399662449 100644 --- a/noir/noir-repo/tooling/nargo_cli/build.rs +++ b/noir/noir-repo/tooling/nargo_cli/build.rs @@ -66,11 +66,13 @@ const INLINER_MIN_OVERRIDES: [(&str, i64); 1] = [ /// Some tests are expected to have warnings /// These should be fixed and removed from this list. -const TESTS_WITH_EXPECTED_WARNINGS: [&str; 2] = [ +const TESTS_WITH_EXPECTED_WARNINGS: [&str; 3] = [ // TODO(https://github.com/noir-lang/noir/issues/6238): remove from list once issue is closed "brillig_cast", // TODO(https://github.com/noir-lang/noir/issues/6238): remove from list once issue is closed "macros_in_comptime", + // We issue a "experimental feature" warning for all enums until they're stabilized + "enums", ]; fn read_test_cases( diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter.rs index 4184ff288d7..2a8adb3fb28 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter.rs @@ -14,6 +14,7 @@ mod attribute; mod buffer; mod comments_and_whitespace; mod doc_comments; +mod enums; mod expression; mod function; mod generics; diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/enums.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/enums.rs new file mode 100644 index 00000000000..b596ec95c94 --- /dev/null +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/enums.rs @@ -0,0 +1,202 @@ +use noirc_frontend::{ + ast::NoirEnumeration, + token::{Keyword, Token}, +}; + +use super::Formatter; +use crate::chunks::ChunkGroup; + +impl<'a> Formatter<'a> { + pub(super) fn format_enum(&mut self, noir_enum: NoirEnumeration) { + self.format_secondary_attributes(noir_enum.attributes); + self.write_indentation(); + self.format_item_visibility(noir_enum.visibility); + self.write_keyword(Keyword::Enum); + self.write_space(); + self.write_identifier(noir_enum.name); + self.format_generics(noir_enum.generics); + self.skip_comments_and_whitespace(); + + // A case like `enum Foo;` + if self.is_at(Token::Semicolon) { + self.write_semicolon(); + return; + } + + // A case like `enum Foo { ... }` + self.write_space(); + self.write_left_brace(); + + if noir_enum.variants.is_empty() { + self.format_empty_block_contents(); + } else { + self.increase_indentation(); + self.write_line(); + + for (index, documented_variant) in noir_enum.variants.into_iter().enumerate() { + if index > 0 { + self.write_comma(); + self.write_line(); + } + + let doc_comments = documented_variant.doc_comments; + if !doc_comments.is_empty() { + self.format_outer_doc_comments(); + } + + let variant = documented_variant.item; + self.write_indentation(); + self.write_identifier(variant.name); + + if !variant.parameters.is_empty() { + self.write_token(Token::LeftParen); + for (i, parameter) in variant.parameters.into_iter().enumerate() { + if i != 0 { + self.write_comma(); + self.write_space(); + } + self.format_type(parameter); + } + self.write_token(Token::RightParen); + } else { + // Remove `()` from an empty `Variant()` + self.skip_comments_and_whitespace(); + if self.is_at(Token::LeftParen) { + self.bump(); + } + self.skip_comments_and_whitespace(); + if self.is_at(Token::RightParen) { + self.bump(); + } + } + } + + // Take the comment chunk so we can put it after a trailing comma we add, in case there's no comma + let mut group = ChunkGroup::new(); + let mut comments_and_whitespace_chunk = + self.chunk_formatter().skip_comments_and_whitespace_chunk(); + comments_and_whitespace_chunk.string = + comments_and_whitespace_chunk.string.trim_end().to_string(); + group.text(comments_and_whitespace_chunk); + + if self.is_at(Token::Comma) { + self.bump(); + } + self.write(","); + + self.format_chunk_group(group); + self.skip_comments_and_whitespace(); + + self.decrease_indentation(); + self.write_line(); + self.write_indentation(); + } + + self.write_right_brace(); + } +} + +#[cfg(test)] +mod tests { + use crate::assert_format; + + #[test] + fn format_empty_enum_with_generics() { + let src = " mod moo { enum Foo < A, B, let N : u32 > {} }"; + let expected = "mod moo { + enum Foo {} +} +"; + assert_format(src, expected); + } + + #[test] + fn format_enum_with_variants() { + let src = " mod moo { enum Foo { +// hello +/// comment + Variant ( Field , i32 ) , + // comment + Another ( ), + } }"; + let expected = "mod moo { + enum Foo { + // hello + /// comment + Variant(Field, i32), + // comment + Another, + } +} +"; + assert_format(src, expected); + } + + #[test] + fn format_enum_with_multiple_newlines() { + let src = " mod moo { + + + enum Foo { + + +X( Field) , + + +Y ( Field ) + + +} + + +}"; + let expected = "mod moo { + + enum Foo { + + X(Field), + + Y(Field), + } + +} +"; + assert_format(src, expected); + } + + #[test] + fn format_two_enums() { + let src = " enum Foo { } enum Bar {} + "; + let expected = "enum Foo {} +enum Bar {} +"; + assert_format(src, expected); + } + + #[test] + fn format_enum_field_without_trailing_comma_but_comment() { + let src = "enum Foo { + field(Field) // comment + }"; + let expected = "enum Foo { + field(Field), // comment +} +"; + assert_format(src, expected); + } + + #[test] + fn format_comment_after_last_enum_field() { + let src = "enum Foo { + field(Field) + /* comment */ + }"; + let expected = "enum Foo { + field(Field), + /* comment */ +} +"; + assert_format(src, expected); + } +} diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/item.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/item.rs index 3365e52ec29..499acb8415c 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/item.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/item.rs @@ -63,6 +63,7 @@ impl<'a> Formatter<'a> { false, // skip visibility ), ItemKind::Struct(noir_struct) => self.format_struct(noir_struct), + ItemKind::Enum(noir_enum) => self.format_enum(noir_enum), ItemKind::Trait(noir_trait) => self.format_trait(noir_trait), ItemKind::TraitImpl(noir_trait_impl) => self.format_trait_impl(noir_trait_impl), ItemKind::Impl(type_impl) => self.format_impl(type_impl), diff --git a/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs b/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs index 27d558ec92b..751bc419d4a 100644 --- a/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs +++ b/noir/noir-repo/tooling/nargo_fmt/src/formatter/statement.rs @@ -75,7 +75,7 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { StatementKind::For(for_loop_statement) => { group.group(self.format_for_loop(for_loop_statement)); } - StatementKind::Loop(block) => { + StatementKind::Loop(block, _) => { group.group(self.format_loop(block)); } StatementKind::Break => { diff --git a/noir/noir-repo/tooling/noirc_abi/src/printable_type.rs b/noir/noir-repo/tooling/noirc_abi/src/printable_type.rs index a81eb0ce8f6..e13cab06e9f 100644 --- a/noir/noir-repo/tooling/noirc_abi/src/printable_type.rs +++ b/noir/noir-repo/tooling/noirc_abi/src/printable_type.rs @@ -74,5 +74,15 @@ pub fn decode_value( decode_value(field_iterator, typ) } PrintableType::Unit => PrintableValue::Field(F::zero()), + PrintableType::Enum { name: _, variants } => { + let tag = field_iterator.next().unwrap(); + let tag_value = tag.to_u128() as usize; + + let (_name, variant_types) = &variants[tag_value]; + PrintableValue::Vec { + array_elements: vecmap(variant_types, |typ| decode_value(field_iterator, typ)), + is_slice: false, + } + } } } diff --git a/noir/noir-repo/yarn.lock b/noir/noir-repo/yarn.lock index e7529ce6f90..811dd9328af 100644 --- a/noir/noir-repo/yarn.lock +++ b/noir/noir-repo/yarn.lock @@ -906,13 +906,13 @@ __metadata: linkType: hard "@babel/parser@npm:^7.26.5": - version: 7.26.7 - resolution: "@babel/parser@npm:7.26.7" + version: 7.26.5 + resolution: "@babel/parser@npm:7.26.5" dependencies: - "@babel/types": ^7.26.7 + "@babel/types": ^7.26.5 bin: parser: ./bin/babel-parser.js - checksum: 22aafd7a6fb9ae577cf192141e5b879477fc087872b8953df0eed60ab7e2397d01aa8d92690eb7ba406f408035dd27c86fbf9967c9a37abd9bf4b1d7cf46a823 + checksum: 663aebf27c1dc04813e6c1d6e8e8fcb2954163ec297a95bdb3f1d0c2a0f04b504bddc09588fe4b176b43fad28c8a4b2914838a1edffdd426537a42f3ac644f1e languageName: node linkType: hard @@ -2863,11 +2863,11 @@ __metadata: linkType: hard "@babel/runtime@npm:^7.13.10, @babel/runtime@npm:^7.20.13, @babel/runtime@npm:^7.5.5, @babel/runtime@npm:^7.8.7": - version: 7.26.7 - resolution: "@babel/runtime@npm:7.26.7" + version: 7.26.0 + resolution: "@babel/runtime@npm:7.26.0" dependencies: regenerator-runtime: ^0.14.0 - checksum: a1664a08f3f4854b895b540cca2f5f5c6c1993b5fb788c9615d70fc201e16bb254df8e0550c83eaf2749a14d87775e11a7c9ded6161203e9da7a4a323d546925 + checksum: c8e2c0504ab271b3467a261a8f119bf2603eb857a0d71e37791f4e3fae00f681365073cc79f141ddaa90c6077c60ba56448004ad5429d07ac73532be9f7cf28a languageName: node linkType: hard @@ -2937,13 +2937,13 @@ __metadata: languageName: node linkType: hard -"@babel/types@npm:^7.14.5, @babel/types@npm:^7.20.7, @babel/types@npm:^7.26.5, @babel/types@npm:^7.26.7": - version: 7.26.7 - resolution: "@babel/types@npm:7.26.7" +"@babel/types@npm:^7.14.5, @babel/types@npm:^7.20.7, @babel/types@npm:^7.26.5": + version: 7.26.5 + resolution: "@babel/types@npm:7.26.5" dependencies: "@babel/helper-string-parser": ^7.25.9 "@babel/helper-validator-identifier": ^7.25.9 - checksum: cfb12e8794ebda6c95c92f3b90f14a9ec87ab532a247d887233068f72f8c287c0fa2e8d3d6ed5a4e512729844f7f73a613cb87d86077ae60a63a2e870e697307 + checksum: 65dc14aa32ace22655c5edadeb99df80776c09cd93c105feaf49cc0583f3116aff0581b7eab630888c39ba61151f251c1399ec982b93585b0d1d1bf4a45b54f9 languageName: node linkType: hard @@ -2966,8 +2966,8 @@ __metadata: linkType: hard "@cookbookdev/docsbot@npm:^4.24.9": - version: 4.24.10 - resolution: "@cookbookdev/docsbot@npm:4.24.10" + version: 4.24.9 + resolution: "@cookbookdev/docsbot@npm:4.24.9" dependencies: "@cookbookdev/sonner": 1.5.1 "@headlessui/react": ^1.7.18 @@ -3032,7 +3032,7 @@ __metadata: peerDependencies: react: ^16.8 || ^17.0 || ^18.0 react-dom: ^16.8 || ^17.0 || ^18.0 - checksum: 1a8ff0148acd29601a723bb0b7019adda1daab07374558c40d1ff9835d387227a40469ecd63b393baaed863074f040797fabd445864ecce72d03c75f81a07875 + checksum: d1264012469a410d7232c3fa82f2855a2d603806142746ea16c05c21fda54632475214cbe1a2c390ee001d37823277c9d4e2d821d6a826e3cd5acef5b1da1b66 languageName: node linkType: hard @@ -6487,13 +6487,13 @@ __metadata: linkType: hard "@radix-ui/react-dialog@npm:^1.0.5": - version: 1.1.5 - resolution: "@radix-ui/react-dialog@npm:1.1.5" + version: 1.1.4 + resolution: "@radix-ui/react-dialog@npm:1.1.4" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-focus-guards": 1.1.1 "@radix-ui/react-focus-scope": 1.1.1 "@radix-ui/react-id": 1.1.0 @@ -6502,8 +6502,8 @@ __metadata: "@radix-ui/react-primitive": 2.0.1 "@radix-ui/react-slot": 1.1.1 "@radix-ui/react-use-controllable-state": 1.1.0 - aria-hidden: ^1.2.4 - react-remove-scroll: ^2.6.2 + aria-hidden: ^1.1.1 + react-remove-scroll: ^2.6.1 peerDependencies: "@types/react": "*" "@types/react-dom": "*" @@ -6514,7 +6514,7 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: 0897fd319e9566fac87141ae74b91dee17202c5ef68850ce0f15702bfb7bd45dbacec0221b3e0a1ec25e43f79727ff88948098eb83e072c655215129dac72bc8 + checksum: 695b35c7283adfe2be4ad88d30f0ad08be099a55dfd54e49ede61074846255b426eb57734509b5fc6f349439a051b48e44e58b542c2605383aa721a1b2bd7861 languageName: node linkType: hard @@ -6555,9 +6555,9 @@ __metadata: languageName: node linkType: hard -"@radix-ui/react-dismissable-layer@npm:1.1.4": - version: 1.1.4 - resolution: "@radix-ui/react-dismissable-layer@npm:1.1.4" +"@radix-ui/react-dismissable-layer@npm:1.1.3": + version: 1.1.3 + resolution: "@radix-ui/react-dismissable-layer@npm:1.1.3" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 @@ -6574,19 +6574,19 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: 387060b412c8db474d2e0395c5ad8eb93544a81268be2a22ff1c9c8127f9477471bd8a6d026cfc9a541512223edbf3107f6603d574b46b554431819c3c68ff2e + checksum: 26d15726bdb274aeb8d801fd163051c270707fb19e9bac4e0e90b368e79063a5347a0b15dc3aadc0bbafa157674e9e796d785d720bd5132c059ac5294ac73a81 languageName: node linkType: hard "@radix-ui/react-dropdown-menu@npm:^2.0.6": - version: 2.1.5 - resolution: "@radix-ui/react-dropdown-menu@npm:2.1.5" + version: 2.1.4 + resolution: "@radix-ui/react-dropdown-menu@npm:2.1.4" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 "@radix-ui/react-id": 1.1.0 - "@radix-ui/react-menu": 2.1.5 + "@radix-ui/react-menu": 2.1.4 "@radix-ui/react-primitive": 2.0.1 "@radix-ui/react-use-controllable-state": 1.1.0 peerDependencies: @@ -6599,7 +6599,7 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: fc4d164dd25596987341608b357f8a0ed2a13c1b64836a057cbbf0de23d1180af56761c8d328724fe2f99ab5fb47e087414a19504abfbada2d7963332c328e20 + checksum: 51014b38daab21d32164813d228b182e5dbf90c77115e5c32b8bbd37696a303485fe30fa2e0a097e1ca3f474cad0dd15efcecc3f567b5edf0a79b5092fe3b4e0 languageName: node linkType: hard @@ -6675,13 +6675,13 @@ __metadata: linkType: hard "@radix-ui/react-hover-card@npm:^1.0.7": - version: 1.1.5 - resolution: "@radix-ui/react-hover-card@npm:1.1.5" + version: 1.1.4 + resolution: "@radix-ui/react-hover-card@npm:1.1.4" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-popper": 1.2.1 "@radix-ui/react-portal": 1.1.3 "@radix-ui/react-presence": 1.1.2 @@ -6697,7 +6697,7 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: a014491d8dd3b9284ecbe2b8b1672973d9bae97c0f0c2ee9a9ff6a15de1138a86c62b4e14ff81ebdb1dab2c7bae8f1b58bfa63d5c0b5ab9f415b3aa560950bb5 + checksum: 1c3e0d8edc01f714c1ca389fb28f4ed843726e0cb56b21dc7d79e9680f2c19138ce494eb3c333ef7ad524a94209494321b4ca70a44f56fdb01c6a90d30c02058 languageName: node linkType: hard @@ -6760,16 +6760,16 @@ __metadata: languageName: node linkType: hard -"@radix-ui/react-menu@npm:2.1.5": - version: 2.1.5 - resolution: "@radix-ui/react-menu@npm:2.1.5" +"@radix-ui/react-menu@npm:2.1.4": + version: 2.1.4 + resolution: "@radix-ui/react-menu@npm:2.1.4" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-collection": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 "@radix-ui/react-direction": 1.1.0 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-focus-guards": 1.1.1 "@radix-ui/react-focus-scope": 1.1.1 "@radix-ui/react-id": 1.1.0 @@ -6780,8 +6780,8 @@ __metadata: "@radix-ui/react-roving-focus": 1.1.1 "@radix-ui/react-slot": 1.1.1 "@radix-ui/react-use-callback-ref": 1.1.0 - aria-hidden: ^1.2.4 - react-remove-scroll: ^2.6.2 + aria-hidden: ^1.1.1 + react-remove-scroll: ^2.6.1 peerDependencies: "@types/react": "*" "@types/react-dom": "*" @@ -6792,18 +6792,18 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: 5546120ce96f707f6ddb3d571280e1143c59155757a03c3feea8c0aaf918c0ecb70b576c0315fcec13be141a491351ccb6426144b009ed3625e2b45936e203f4 + checksum: 2a20db7c017075d3ceb07b076dfdbdc3c4d26825642e2f52d4bda9947078725db129c2b41a93a6f1393526eca3bf586f54dc47a9a66e1eef621c94a55a216aa4 languageName: node linkType: hard "@radix-ui/react-popover@npm:^1.0.7": - version: 1.1.5 - resolution: "@radix-ui/react-popover@npm:1.1.5" + version: 1.1.4 + resolution: "@radix-ui/react-popover@npm:1.1.4" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-focus-guards": 1.1.1 "@radix-ui/react-focus-scope": 1.1.1 "@radix-ui/react-id": 1.1.0 @@ -6813,8 +6813,8 @@ __metadata: "@radix-ui/react-primitive": 2.0.1 "@radix-ui/react-slot": 1.1.1 "@radix-ui/react-use-controllable-state": 1.1.0 - aria-hidden: ^1.2.4 - react-remove-scroll: ^2.6.2 + aria-hidden: ^1.1.1 + react-remove-scroll: ^2.6.1 peerDependencies: "@types/react": "*" "@types/react-dom": "*" @@ -6825,7 +6825,7 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: d98dd028ce455a509faa7a09193f8e08803ea9a3e797316e9852cc027ddf7463ce2250bb4391a165424a0b5d19aaf72587d114c8d602fe0f1560793f3391cc81 + checksum: f4525ac9a2f5957ad709749daddb78e58d8b1471dfd8683ca713d1ade9aac202b30c7179b798471e90ab13a01f01a70a3bc4002a872c279ee383bd3ad8ad49e6 languageName: node linkType: hard @@ -7032,8 +7032,8 @@ __metadata: linkType: hard "@radix-ui/react-select@npm:^2.0.0": - version: 2.1.5 - resolution: "@radix-ui/react-select@npm:2.1.5" + version: 2.1.4 + resolution: "@radix-ui/react-select@npm:2.1.4" dependencies: "@radix-ui/number": 1.1.0 "@radix-ui/primitive": 1.1.1 @@ -7041,7 +7041,7 @@ __metadata: "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 "@radix-ui/react-direction": 1.1.0 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-focus-guards": 1.1.1 "@radix-ui/react-focus-scope": 1.1.1 "@radix-ui/react-id": 1.1.0 @@ -7054,8 +7054,8 @@ __metadata: "@radix-ui/react-use-layout-effect": 1.1.0 "@radix-ui/react-use-previous": 1.1.0 "@radix-ui/react-visually-hidden": 1.1.1 - aria-hidden: ^1.2.4 - react-remove-scroll: ^2.6.2 + aria-hidden: ^1.1.1 + react-remove-scroll: ^2.6.1 peerDependencies: "@types/react": "*" "@types/react-dom": "*" @@ -7066,7 +7066,7 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: 585be67916a0ae84a8ddb35ea64f969406bdb661a8239a2188a69a5ffe3af1dd3c33f372178e2227b24ef20c82f83062dc82e65ec61429c5c236f8764d4dbf2d + checksum: 68571da6bebc83f0b685b2b49d804227e787adc72e759523d2d5ce9ee63b51d61de2a9ea92455f33d386eced626d4b2bf73639f4c08d2dbb595f5ec3af47533c languageName: node linkType: hard @@ -7176,14 +7176,14 @@ __metadata: linkType: hard "@radix-ui/react-toast@npm:^1.1.5": - version: 1.2.5 - resolution: "@radix-ui/react-toast@npm:1.2.5" + version: 1.2.4 + resolution: "@radix-ui/react-toast@npm:1.2.4" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-collection": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-portal": 1.1.3 "@radix-ui/react-presence": 1.1.2 "@radix-ui/react-primitive": 2.0.1 @@ -7201,18 +7201,18 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: 4910807f136b98d6c152e73c6dd4b4e023860a0fd115fd1b21ca58109bfd3854a4651b712d130c527e380dc48ac2a5d524e190d61e7ef3223cb3a32ac186085b + checksum: 6479af12ec9ae4f3cdbb8ca66d2926c8323e60fdca57fc3de6af1f41f64bf4cb9e0ea70bdac4e9aac568f1abc2222ebfb67e048d77f36ec10bb4971ae4870b73 languageName: node linkType: hard "@radix-ui/react-tooltip@npm:^1.0.7": - version: 1.1.7 - resolution: "@radix-ui/react-tooltip@npm:1.1.7" + version: 1.1.6 + resolution: "@radix-ui/react-tooltip@npm:1.1.6" dependencies: "@radix-ui/primitive": 1.1.1 "@radix-ui/react-compose-refs": 1.1.1 "@radix-ui/react-context": 1.1.1 - "@radix-ui/react-dismissable-layer": 1.1.4 + "@radix-ui/react-dismissable-layer": 1.1.3 "@radix-ui/react-id": 1.1.0 "@radix-ui/react-popper": 1.2.1 "@radix-ui/react-portal": 1.1.3 @@ -7231,7 +7231,7 @@ __metadata: optional: true "@types/react-dom": optional: true - checksum: 05d1167f1a65ae211b78f69be12ba35b05e568f00c5862fa8c092f726622642a2393ae4fa82a0b3c4e036b1caa3ba7fef8ec316738a5c6d3131be8a8f9fdf7db + checksum: aabbb2c3a7592419fcf41d306582c57307e9518ee29d80e0a8f811bb29ade72144ee35a4f4a120e5143dee46813017fe4e087f351503929553a94e3af986305f languageName: node linkType: hard @@ -7686,80 +7686,80 @@ __metadata: languageName: node linkType: hard -"@shikijs/core@npm:1.29.1": - version: 1.29.1 - resolution: "@shikijs/core@npm:1.29.1" +"@shikijs/core@npm:1.26.2": + version: 1.26.2 + resolution: "@shikijs/core@npm:1.26.2" dependencies: - "@shikijs/engine-javascript": 1.29.1 - "@shikijs/engine-oniguruma": 1.29.1 - "@shikijs/types": 1.29.1 + "@shikijs/engine-javascript": 1.26.2 + "@shikijs/engine-oniguruma": 1.26.2 + "@shikijs/types": 1.26.2 "@shikijs/vscode-textmate": ^10.0.1 "@types/hast": ^3.0.4 hast-util-to-html: ^9.0.4 - checksum: e26de2c33f6f00984718aa77d59d726659e4f1bc8d7d07d8e8870e3b29c6d4f70261d628f8a703c4770e78441fd94edac1532165a4597b00d79f1c98dd561e39 + checksum: b7bad4c281102bdd74f0974aa780efca06117208419c205005f172b247221d42685608d96dba97bd215eb6af99463d914517bad77fdc507e3254527f08f95975 languageName: node linkType: hard -"@shikijs/engine-javascript@npm:1.29.1": - version: 1.29.1 - resolution: "@shikijs/engine-javascript@npm:1.29.1" +"@shikijs/engine-javascript@npm:1.26.2": + version: 1.26.2 + resolution: "@shikijs/engine-javascript@npm:1.26.2" dependencies: - "@shikijs/types": 1.29.1 + "@shikijs/types": 1.26.2 "@shikijs/vscode-textmate": ^10.0.1 - oniguruma-to-es: ^2.2.0 - checksum: fe268745cb0078efdb17011d66687bf3e9cf6c3f379df036073bbd4d0b8ccddb3d6469e3f6be5e70fa7b629ff1f72c767a76286a9fd498e726d55c7510427fad + oniguruma-to-es: ^1.0.0 + checksum: 8df3d284033b17f50625d07e0cdcd26d24ad8e821bb8e58cc91aebffd28befe7cf80169aadbc5bcc038433d68a2e18148b5098ef4e243f18c9cb634ba20ea034 languageName: node linkType: hard -"@shikijs/engine-oniguruma@npm:1.29.1": - version: 1.29.1 - resolution: "@shikijs/engine-oniguruma@npm:1.29.1" +"@shikijs/engine-oniguruma@npm:1.26.2": + version: 1.26.2 + resolution: "@shikijs/engine-oniguruma@npm:1.26.2" dependencies: - "@shikijs/types": 1.29.1 + "@shikijs/types": 1.26.2 "@shikijs/vscode-textmate": ^10.0.1 - checksum: 099d38a9b14b8a0252be8b0cad6e0ed9b197bee58e159e3009c8d188a469bf79f591b7e162518ee50a3aa1d4f7579e43bd269fe127cc9a2c6096384f832144e1 + checksum: d2d4978be0b4e8b3b26fd01ea0480568ac2264135704442e54ee14fb3b05f564492f95cae521378d6114641f249e892a76dc01fb0e4ae64522e56b3b353ce08d languageName: node linkType: hard -"@shikijs/langs@npm:1.29.1": - version: 1.29.1 - resolution: "@shikijs/langs@npm:1.29.1" +"@shikijs/langs@npm:1.26.2": + version: 1.26.2 + resolution: "@shikijs/langs@npm:1.26.2" dependencies: - "@shikijs/types": 1.29.1 - checksum: 4cc7c27724d6aaa43e79606299accffa8f134b8cabc640387803060de59ecb2e1428edb533acd85e44f242245c9ecddc9097256198fc3bc177372063dfff5d6c + "@shikijs/types": 1.26.2 + checksum: c3f1882401f19bc50cbf5fb3d62b287c2fc07bdbc98b281d1d63b51dac8602d00b8446918aae73d8e76e455b205542cc9a73526b64fe3c1b020952af8ebe26c5 languageName: node linkType: hard "@shikijs/rehype@npm:^1.12.1": - version: 1.29.1 - resolution: "@shikijs/rehype@npm:1.29.1" + version: 1.26.2 + resolution: "@shikijs/rehype@npm:1.26.2" dependencies: - "@shikijs/types": 1.29.1 + "@shikijs/types": 1.26.2 "@types/hast": ^3.0.4 hast-util-to-string: ^3.0.1 - shiki: 1.29.1 + shiki: 1.26.2 unified: ^11.0.5 unist-util-visit: ^5.0.0 - checksum: 7083ea03ec03c2b40ee1234bc01382a09aa6cbe6c78798baa0c3b1061781cb3f8f25a278fbad6cd36cadc4bf051984647eb3a183824320fc0bc05984cd48a1b3 + checksum: b1ed82152275ebf8deadd97308b1ead12b91ce451409fce70307b6d6953c9a9c60d18dd1c3d6317039b30df14f33e9c01c42c7f0e149662f94221b110b3a569c languageName: node linkType: hard -"@shikijs/themes@npm:1.29.1": - version: 1.29.1 - resolution: "@shikijs/themes@npm:1.29.1" +"@shikijs/themes@npm:1.26.2": + version: 1.26.2 + resolution: "@shikijs/themes@npm:1.26.2" dependencies: - "@shikijs/types": 1.29.1 - checksum: 48274e3edef02eed1b5ec781e15a9625503e26b88ce2d84244fe47b7b7740e7d4ea930e381a690d05610186a66df70f215c7b8e5723579a337a5c3d50c8aff85 + "@shikijs/types": 1.26.2 + checksum: 1e11093ae6e4fbf3573ca626dea19e78f07e1eb33a720ef56503e830743e7789e2203203738162f9b6f1bb273fe23b51b545ea3acdd5818b39fb7052475993fb languageName: node linkType: hard -"@shikijs/types@npm:1.29.1": - version: 1.29.1 - resolution: "@shikijs/types@npm:1.29.1" +"@shikijs/types@npm:1.26.2": + version: 1.26.2 + resolution: "@shikijs/types@npm:1.26.2" dependencies: "@shikijs/vscode-textmate": ^10.0.1 "@types/hast": ^3.0.4 - checksum: 22b8e893033728c9aa5139ebab6c73e188c111eba52828d55d9cb3b0d96f6ecacc1670ddf6f8d2e9d25126e1998711a3b35b17012b9fce5cc9fec128a49e50a4 + checksum: 77e0823a60dce4f37b85b2648ae75f00750fc897b6efaf7f8d765b3e57849456c5bfbe6ad9f1faa269c2ccddf7e4d7ab6dc5b43ab69c70ee09b4bde37f360cc4 languageName: node linkType: hard @@ -7904,31 +7904,31 @@ __metadata: languageName: node linkType: hard -"@statsig/client-core@npm:3.11.0": - version: 3.11.0 - resolution: "@statsig/client-core@npm:3.11.0" - checksum: 2083915132b3a2d94b90a948800d3fad47da1e11a0454cd83a22ac9a3946ed2b045071ef182541562e6a54fcd9339dc87f399a8dfc5aa6e7a0f8b3b37142e76a +"@statsig/client-core@npm:3.9.0": + version: 3.9.0 + resolution: "@statsig/client-core@npm:3.9.0" + checksum: 2c8b8a99ac4ce1d24fe4e9263f55af37d4861455450dc348787bfa3cf2cf523a12e75a3aaa4fb676a853493174481f7e29d99196bcfd1ce30ffcb48137ff6453 languageName: node linkType: hard -"@statsig/js-client@npm:3.11.0, @statsig/js-client@npm:^3.1.0": - version: 3.11.0 - resolution: "@statsig/js-client@npm:3.11.0" +"@statsig/js-client@npm:3.9.0, @statsig/js-client@npm:^3.1.0": + version: 3.9.0 + resolution: "@statsig/js-client@npm:3.9.0" dependencies: - "@statsig/client-core": 3.11.0 - checksum: b7ea9889a19ed317d34b29b49e2a81389a386530501e124cb784a7881547177363ef991da36f937e10d233b949dec25ea6cba44804cd5904364d8678d651303e + "@statsig/client-core": 3.9.0 + checksum: 0f715c5043fb529baeadd256f8487f63aa05502cbe51d1f2053bcedf306638dc831eec05b0ad31522c05d28edabd854f9da4a4844e48b9fc27109e089679fb47 languageName: node linkType: hard "@statsig/react-bindings@npm:^3.1.0": - version: 3.11.0 - resolution: "@statsig/react-bindings@npm:3.11.0" + version: 3.9.0 + resolution: "@statsig/react-bindings@npm:3.9.0" dependencies: - "@statsig/client-core": 3.11.0 - "@statsig/js-client": 3.11.0 + "@statsig/client-core": 3.9.0 + "@statsig/js-client": 3.9.0 peerDependencies: react: ^16.6.3 || ^17.0.0 || ^18.0.0 || ^19.0.0 - checksum: 7d03220ff37cbb2c247473d70f8ee8f819246ed07a7b841c17e32ca3c4cc4ce72e7bfbfec98c1c5f080732636c2c774d96f96a046716dc7d26976136f76e2560 + checksum: e6b2cb68e3ca720c1714c5b8bc1f1b67fbb2fff7eab07b7395b4c32e64957877233158e29f38129a404ba3ede670490db665b22b8d4c1ce3fda71a95866008e7 languageName: node linkType: hard @@ -8266,14 +8266,14 @@ __metadata: linkType: hard "@tanstack/react-virtual@npm:^3.0.0-beta.60": - version: 3.11.3 - resolution: "@tanstack/react-virtual@npm:3.11.3" + version: 3.11.2 + resolution: "@tanstack/react-virtual@npm:3.11.2" dependencies: - "@tanstack/virtual-core": 3.11.3 + "@tanstack/virtual-core": 3.11.2 peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - checksum: eb39f8a015f4dc98070f0c18bbb1f9c094b7182133554ef3ee31d2678cd3a66edd28ce854d533e830f88f1f0ad1d5b065de184438a08fe774a9acc1dc62da436 + checksum: a1136da0ec4c2ecbd4f996d8b84f228f0b8d851b15806e01049a160ad1d9b2eef0e0a491035fe017c6f84a0e125334f69ea23b32c180df23614ea4a8eeb7490c languageName: node linkType: hard @@ -8284,10 +8284,10 @@ __metadata: languageName: node linkType: hard -"@tanstack/virtual-core@npm:3.11.3": - version: 3.11.3 - resolution: "@tanstack/virtual-core@npm:3.11.3" - checksum: 43dacdd27802de45c0ccdb34a4cba5f47c7b5aab1cc920a9851e382074fc9099419cc4b014afab74791672e4837afe0ef6ecf47e344267b457858dd982ffac57 +"@tanstack/virtual-core@npm:3.11.2": + version: 3.11.2 + resolution: "@tanstack/virtual-core@npm:3.11.2" + checksum: b5c91662461e3edd1cba0efbaa89e1d061c8bb605bb78d1e87e2a687335c740a731c96a81798b05491df4882ff2fbd27b312f5e7440e4f9d553a81fb2283156a languageName: node linkType: hard @@ -10218,7 +10218,7 @@ __metadata: languageName: node linkType: hard -"aria-hidden@npm:^1.1.1, aria-hidden@npm:^1.2.4": +"aria-hidden@npm:^1.1.1": version: 1.2.4 resolution: "aria-hidden@npm:1.2.4" dependencies: @@ -16991,9 +16991,9 @@ __metadata: linkType: hard "iso-639-1@npm:^3.1.3": - version: 3.1.4 - resolution: "iso-639-1@npm:3.1.4" - checksum: 3471a685d01fc2fa46b72e0c0b0679710073b17f0a234127c646050ac391931156b28fdfd42e2ef871d70fa05661e5f100da377869753fe178c5260f2464c2f3 + version: 3.1.3 + resolution: "iso-639-1@npm:3.1.3" + checksum: 9a4cf417a91f638af247328e2b92ca135ec82eedbab139246cfd0a53d2dee052a8abc1639ca997e84fbebb0cd536cb7fb88910433c922122bcb7250a6a16d8e9 languageName: node linkType: hard @@ -19721,14 +19721,14 @@ __metadata: languageName: node linkType: hard -"oniguruma-to-es@npm:^2.2.0": - version: 2.3.0 - resolution: "oniguruma-to-es@npm:2.3.0" +"oniguruma-to-es@npm:^1.0.0": + version: 1.0.0 + resolution: "oniguruma-to-es@npm:1.0.0" dependencies: emoji-regex-xs: ^1.0.0 regex: ^5.1.1 regex-recursion: ^5.1.1 - checksum: b9af262ecad9d8b0817203efceed25f2675c6e4018b4778bbe3c4092506924d726f1e2f9116d7321c2bd08110d1ddef5bbbeab863d6ef2937ce554087adb6938 + checksum: 2d88b3f0c670b1b7c87bf5c4caefea12771748c5970f691f04284604f3dce745107f7558573395c9103bea56154062b421d0a2a7005ada93968a7071316f5d1e languageName: node linkType: hard @@ -21204,14 +21204,14 @@ __metadata: linkType: hard "posthog-js@npm:^1.136.8": - version: 1.212.1 - resolution: "posthog-js@npm:1.212.1" + version: 1.205.1 + resolution: "posthog-js@npm:1.205.1" dependencies: core-js: ^3.38.1 fflate: ^0.4.8 preact: ^10.19.3 web-vitals: ^4.2.0 - checksum: 18e1574364c5ff431b852f7c499615bc3b42684d7ab25d6be31eb94babfda05377bfe55217e3280a461dde2522a258a17549cf646e6688ad4a6538801d3dd50a + checksum: d91ce45d2d3d5784b6db8d2f4408619db6512bd3d5fae8022f6fc91864b3b2d2c14c0081b670aefa1a8b63d7a498b2a9e3e58e58583e419893b54276832f1635 languageName: node linkType: hard @@ -21803,22 +21803,22 @@ __metadata: languageName: node linkType: hard -"react-remove-scroll@npm:^2.6.2": - version: 2.6.3 - resolution: "react-remove-scroll@npm:2.6.3" +"react-remove-scroll@npm:^2.6.1": + version: 2.6.2 + resolution: "react-remove-scroll@npm:2.6.2" dependencies: react-remove-scroll-bar: ^2.3.7 - react-style-singleton: ^2.2.3 + react-style-singleton: ^2.2.1 tslib: ^2.1.0 use-callback-ref: ^1.3.3 - use-sidecar: ^1.1.3 + use-sidecar: ^1.1.2 peerDependencies: "@types/react": "*" react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc peerDependenciesMeta: "@types/react": optional: true - checksum: a4afd320435cc25a6ee39d7cef2f605dca14cc7618e1cdab24ed0924fa71d8c3756626334dedc9a578945d7ba6f8f87d7b8b66b48034853dc4dbfbda0a1b228b + checksum: 310e6e6d2f28226a1751dc5084a2dce49167f0b69e3d78d6510f329f423ee313d4f6477f5e1adccb68baef40a7af75541e980a8c398cb82ea0d3573e514e8124 languageName: node linkType: hard @@ -21880,7 +21880,7 @@ __metadata: languageName: node linkType: hard -"react-smooth@npm:^4.0.4": +"react-smooth@npm:^4.0.0": version: 4.0.4 resolution: "react-smooth@npm:4.0.4" dependencies: @@ -21904,7 +21904,7 @@ __metadata: languageName: node linkType: hard -"react-style-singleton@npm:^2.2.1, react-style-singleton@npm:^2.2.2, react-style-singleton@npm:^2.2.3": +"react-style-singleton@npm:^2.2.1, react-style-singleton@npm:^2.2.2": version: 2.2.3 resolution: "react-style-singleton@npm:2.2.3" dependencies: @@ -22033,21 +22033,21 @@ __metadata: linkType: hard "recharts@npm:^2.12.4": - version: 2.15.1 - resolution: "recharts@npm:2.15.1" + version: 2.15.0 + resolution: "recharts@npm:2.15.0" dependencies: clsx: ^2.0.0 eventemitter3: ^4.0.1 lodash: ^4.17.21 react-is: ^18.3.1 - react-smooth: ^4.0.4 + react-smooth: ^4.0.0 recharts-scale: ^0.4.4 tiny-invariant: ^1.3.1 victory-vendor: ^36.6.8 peerDependencies: react: ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - checksum: c288012087424e7067bf4a48a3d50a7c0ea231188ecaaf17685d93711ef1846aedd37509dfc25c046052f671e7e19dd2de5d8b9c1e497a64c39995c9c338b40a + checksum: b49337c1df995f60c302101e44e638660ca4ea5370467ffbc61c0ee829770ffae0337564c0d1dc8b46786c91d9201f6bcad9e0155173c4c6fca44a5f3a6decc3 languageName: node linkType: hard @@ -23226,19 +23226,19 @@ __metadata: languageName: node linkType: hard -"shiki@npm:1.29.1": - version: 1.29.1 - resolution: "shiki@npm:1.29.1" +"shiki@npm:1.26.2": + version: 1.26.2 + resolution: "shiki@npm:1.26.2" dependencies: - "@shikijs/core": 1.29.1 - "@shikijs/engine-javascript": 1.29.1 - "@shikijs/engine-oniguruma": 1.29.1 - "@shikijs/langs": 1.29.1 - "@shikijs/themes": 1.29.1 - "@shikijs/types": 1.29.1 + "@shikijs/core": 1.26.2 + "@shikijs/engine-javascript": 1.26.2 + "@shikijs/engine-oniguruma": 1.26.2 + "@shikijs/langs": 1.26.2 + "@shikijs/themes": 1.26.2 + "@shikijs/types": 1.26.2 "@shikijs/vscode-textmate": ^10.0.1 "@types/hast": ^3.0.4 - checksum: a4ae65ddd427e81593d11d086136018071fb9b8faae613d47c99c9de095a1a078f3e68d5aa4a017dfbe0290ca76dba4f52ea1e3a122b910041aac5ed6be2218c + checksum: 0cc7af769eb57de4bc1423ab60a62e8c68071914456b0bd94c6188051d770f65b8e76d94341f552274bb76221c734c0ccc802106fa2c969b701df5dffaf26a14 languageName: node linkType: hard @@ -25226,7 +25226,7 @@ __metadata: languageName: node linkType: hard -"use-sidecar@npm:^1.1.2, use-sidecar@npm:^1.1.3": +"use-sidecar@npm:^1.1.2": version: 1.1.3 resolution: "use-sidecar@npm:1.1.3" dependencies: diff --git a/yarn-project/bb-prover/src/avm_proving_tests/avm_proving.test.ts b/yarn-project/bb-prover/src/avm_proving_tests/avm_proving.test.ts index 898899e2195..7f0b6a3a723 100644 --- a/yarn-project/bb-prover/src/avm_proving_tests/avm_proving.test.ts +++ b/yarn-project/bb-prover/src/avm_proving_tests/avm_proving.test.ts @@ -386,7 +386,7 @@ describe('AVM WitGen & Circuit', () => { const sender = AztecAddress.fromNumber(42); const feePayer = sender; - const initialFeeJuiceBalance = new Fr(10000); + const initialFeeJuiceBalance = new Fr(20000); let feeJuice: ProtocolContract; let feeJuiceContractClassPublic: ContractClassPublic; diff --git a/yarn-project/bb-prover/src/prover/bb_private_kernel_prover.ts b/yarn-project/bb-prover/src/prover/bb_private_kernel_prover.ts index e8d530d7043..320c07c09d8 100644 --- a/yarn-project/bb-prover/src/prover/bb_private_kernel_prover.ts +++ b/yarn-project/bb-prover/src/prover/bb_private_kernel_prover.ts @@ -166,7 +166,15 @@ export abstract class BBPrivateKernelProver implements PrivateKernelProver { const witnessMap = convertInputs(inputs, compiledCircuit.abi); const timer = new Timer(); - const outputWitness = await this.simulationProvider.executeProtocolCircuit(witnessMap, compiledCircuit); + const outputWitness = await this.simulationProvider + .executeProtocolCircuit(witnessMap, compiledCircuit) + .catch((err: Error) => { + this.log.debug(`Failed to simulate ${circuitType}`, { + circuitName: mapProtocolArtifactNameToCircuitName(circuitType), + error: err, + }); + throw err; + }); const output = convertOutputs(outputWitness, compiledCircuit.abi); this.log.debug(`Simulated ${circuitType}`, { diff --git a/yarn-project/simulator/src/providers/acvm_wasm.ts b/yarn-project/simulator/src/providers/acvm_wasm.ts index 5656db2aa4d..163031132d9 100644 --- a/yarn-project/simulator/src/providers/acvm_wasm.ts +++ b/yarn-project/simulator/src/providers/acvm_wasm.ts @@ -1,3 +1,4 @@ +import { createLogger } from '@aztec/foundation/log'; import { foreignCallHandler } from '@aztec/noir-protocol-circuits-types/client'; import { type NoirCompiledCircuit } from '@aztec/types/noir'; @@ -10,6 +11,8 @@ import { type ACVMWitness } from '../acvm/acvm_types.js'; import { type SimulationProvider, parseErrorPayload } from '../common/simulation_provider.js'; export class WASMSimulator implements SimulationProvider { + constructor(protected log = createLogger('wasm-simulator')) {} + async init(): Promise { // If these are available, then we are in the // web environment. For the node environment, this @@ -21,6 +24,7 @@ export class WASMSimulator implements SimulationProvider { } async executeProtocolCircuit(input: WitnessMap, compiledCircuit: NoirCompiledCircuit): Promise { + this.log.debug('init', { hash: compiledCircuit.hash }); await this.init(); // Execute the circuit on those initial witness values // @@ -34,13 +38,20 @@ export class WASMSimulator implements SimulationProvider { input, foreignCallHandler, // handle calls to debug_log ); - + this.log.debug('execution successful', { hash: compiledCircuit.hash }); return _witnessMap; } catch (err) { // Typescript types catched errors as unknown or any, so we need to narrow its type to check if it has raw assertion payload. if (typeof err === 'object' && err !== null && 'rawAssertionPayload' in err) { - throw parseErrorPayload(compiledCircuit.abi, err as ExecutionError); + const parsed = parseErrorPayload(compiledCircuit.abi, err as ExecutionError); + this.log.debug('execution failed', { + hash: compiledCircuit.hash, + error: parsed, + message: parsed.message, + }); + throw parsed; } + this.log.debug('execution failed', { hash: compiledCircuit.hash, error: err }); throw new Error(`Circuit execution failed: ${err}`); } }