Skip to content

Commit

Permalink
feat(experimental): Support struct constructors in match patterns (no…
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Feb 24, 2025
1 parent e26e993 commit 6f79fd1
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 9 deletions.
87 changes: 79 additions & 8 deletions compiler/noirc_frontend/src/elaborator/enums.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::collections::BTreeMap;

use fm::FileId;
use fxhash::FxHashMap as HashMap;
use iter_extended::{try_vecmap, vecmap};
use noirc_errors::Location;

use crate::{
ast::{
EnumVariant, Expression, ExpressionKind, FunctionKind, Ident, Literal, NoirEnumeration,
StatementKind, UnresolvedType, Visibility,
ConstructorExpression, EnumVariant, Expression, ExpressionKind, FunctionKind, Ident,
Literal, NoirEnumeration, StatementKind, UnresolvedType, Visibility,
},
elaborator::path_resolution::PathResolutionItem,
hir::{comptime::Value, resolution::errors::ResolverError, type_check::TypeCheckError},
Expand Down Expand Up @@ -356,11 +358,14 @@ impl Elaborator<'_> {
if let Some(existing) =
variables_defined.iter().find(|elem| *elem == &last_ident)
{
let error = ResolverError::VariableAlreadyDefinedInPattern {
existing: existing.clone(),
new_span: last_ident.span(),
};
self.push_err(error, self.file);
// Allow redefinition of `_` only, to ignore variables
if last_ident.0.contents != "_" {
let error = ResolverError::VariableAlreadyDefinedInPattern {
existing: existing.clone(),
new_span: last_ident.span(),
};
self.push_err(error, self.file);
}
} else {
variables_defined.push(last_ident.clone());
}
Expand All @@ -381,7 +386,9 @@ impl Elaborator<'_> {
expected_type,
variables_defined,
),
ExpressionKind::Constructor(_) => todo!("handle constructors"),
ExpressionKind::Constructor(constructor) => {
self.constructor_to_pattern(*constructor, variables_defined)
}
ExpressionKind::Tuple(fields) => {
let field_types = vecmap(0..fields.len(), |_| self.interner.next_type_variable());
let actual = Type::Tuple(field_types.clone());
Expand Down Expand Up @@ -434,6 +441,53 @@ impl Elaborator<'_> {
}
}

fn constructor_to_pattern(
&mut self,
constructor: ConstructorExpression,
variables_defined: &mut Vec<Ident>,
) -> Pattern {
let location = constructor.typ.location;
let typ = self.resolve_type(constructor.typ);

let Some((struct_name, mut expected_field_types)) =
self.struct_name_and_field_types(&typ, location)
else {
return Pattern::Error;
};

let mut fields = BTreeMap::default();
for (field_name, field) in constructor.fields {
let Some(field_index) =
expected_field_types.iter().position(|(name, _)| *name == field_name.0.contents)
else {
let error = if fields.contains_key(&field_name.0.contents) {
ResolverError::DuplicateField { field: field_name }
} else {
let struct_definition = struct_name.clone();
ResolverError::NoSuchField { field: field_name, struct_definition }
};
self.push_err(error, self.file);
continue;
};

let (field_name, expected_field_type) = expected_field_types.swap_remove(field_index);
let pattern =
self.expression_to_pattern(field, &expected_field_type, variables_defined);
fields.insert(field_name, pattern);
}

if !expected_field_types.is_empty() {
let struct_definition = struct_name;
let span = location.span;
let missing_fields = vecmap(expected_field_types, |(name, _)| name);
let error = ResolverError::MissingFields { span, missing_fields, struct_definition };
self.push_err(error, self.file);
}

let args = vecmap(fields, |(_name, field)| field);
Pattern::Constructor(Constructor::Variant(typ, 0), args)
}

fn expression_to_constructor(
&mut self,
name: Expression,
Expand Down Expand Up @@ -555,6 +609,23 @@ impl Elaborator<'_> {
Pattern::Constructor(constructor, args)
}

fn struct_name_and_field_types(
&mut self,
typ: &Type,
location: Location,
) -> Option<(Ident, Vec<(String, Type)>)> {
if let Type::DataType(typ, generics) = typ.follow_bindings_shallow().as_ref() {
if let Some(fields) = typ.borrow().get_fields(generics) {
return Some((typ.borrow().name.clone(), fields));
}
}

let error =
ResolverError::NonStructUsedInConstructor { typ: typ.to_string(), span: location.span };
self.push_err(error, location.file);
None
}

/// Compiles the rows of a match expression, outputting a decision tree for the match.
///
/// This is an adaptation of https://github.com/yorickpeterse/pattern-matching-in-rust/tree/main/jacobs2021
Expand Down
75 changes: 75 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4494,3 +4494,78 @@ fn errors_on_repeated_match_variables_in_pattern() {
CompilationError::ResolverError(ResolverError::VariableAlreadyDefinedInPattern { .. })
));
}

#[test]
fn duplicate_field_in_match_struct_pattern() {
let src = r#"
fn main() {
let foo = Foo { x: 10, y: 20 };
match foo {
Foo { x: _, x: _, y: _ } => {}
}
}
struct Foo {
x: i32,
y: Field,
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
&errors[0].0,
CompilationError::ResolverError(ResolverError::DuplicateField { .. })
));
}

#[test]
fn missing_field_in_match_struct_pattern() {
let src = r#"
fn main() {
let foo = Foo { x: 10, y: 20 };
match foo {
Foo { x: _ } => {}
}
}
struct Foo {
x: i32,
y: Field,
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
&errors[0].0,
CompilationError::ResolverError(ResolverError::MissingFields { .. })
));
}

#[test]
fn no_such_field_in_match_struct_pattern() {
let src = r#"
fn main() {
let foo = Foo { x: 10, y: 20 };
match foo {
Foo { x: _, y: _, z: _ } => {}
}
}
struct Foo {
x: i32,
y: Field,
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
&errors[0].0,
CompilationError::ResolverError(ResolverError::NoSuchField { .. })
));
}
27 changes: 26 additions & 1 deletion test_programs/compile_success_empty/enums/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ fn main() {
foo_tests();
option_tests();
abc_tests();
match_on_structs();
}

fn primitive_tests() {
Expand Down Expand Up @@ -103,7 +104,7 @@ fn abc_tests() {
// Mut is only to throw the optimizer off a bit so we can see
// the `eq`s that get generated before they're removed because each of these are constant
let mut tuple = (ABC::A, ABC::B);
match tuple {
let _ = match tuple {
(ABC::A, _) => 1,
(_, ABC::A) => 2,
(_, ABC::B) => 3,
Expand All @@ -114,3 +115,27 @@ fn abc_tests() {
_ => 0,
};
}

fn match_on_structs() {
let foo = MyStruct { x: 10, y: 20 };
match foo {
MyStruct { x, y } => {
assert_eq(x, 10);
assert_eq(y, 20);
},
}

match MyOption::Some(foo) {
MyOption::Some(MyStruct { x: x2, y: y2 }) => {
assert_eq(x2, 10);
assert_eq(y2, 20);
},
MyOption::None => fail(),
MyOption::Maybe => fail(),
}
}

struct MyStruct {
x: i32,
y: Field,
}

0 comments on commit 6f79fd1

Please sign in to comment.