diff --git a/tools/slicec-cs/src/builders.rs b/tools/slicec-cs/src/builders.rs index a55c8d4f5..bfbdb570b 100644 --- a/tools/slicec-cs/src/builders.rs +++ b/tools/slicec-cs/src/builders.rs @@ -142,9 +142,7 @@ impl ContainerBuilder { pub fn add_fields(&mut self, fields: &[&Field]) -> &mut Self { for field in fields { - let type_string = field - .data_type() - .cs_type_string(&field.namespace(), TypeContext::Field, false); + let type_string = field.data_type().field_type_string(&field.namespace(), false); self.add_field( &field.field_name(), @@ -338,7 +336,7 @@ impl FunctionBuilder { }; for (index, parameter) in parameters.iter().enumerate() { - let parameter_type = parameter.cs_type_string(&operation.namespace(), context, false); + let parameter_type = parameter.cs_type_string(&operation.namespace(), context); let parameter_name = parameter.parameter_name(); let default_value = if context == TypeContext::OutgoingParam && (index >= trailing_optional_parameters_index) { diff --git a/tools/slicec-cs/src/decoding.rs b/tools/slicec-cs/src/decoding.rs index bba3d8877..6ee317dd4 100644 --- a/tools/slicec-cs/src/decoding.rs +++ b/tools/slicec-cs/src/decoding.rs @@ -8,7 +8,7 @@ use crate::slicec_ext::*; use convert_case::Case; use slicec::code_block::CodeBlock; use slicec::grammar::*; -use slicec::utils::code_gen_util::{get_bit_sequence_size, TypeContext}; +use slicec::utils::code_gen_util::get_bit_sequence_size; /// Compute how many bits are needed to decode the provided members, and if more than 0 bits are needed, /// this generates code that creates a new `BitSequenceReader` with the necessary capacity. @@ -26,9 +26,7 @@ pub fn decode_fields(fields: &[&Field], encoding: Encoding) -> CodeBlock { let mut code = CodeBlock::default(); initialize_bit_sequence_reader_for(fields, &mut code, encoding); - let action = |field_name, field_value| { - writeln!(code, "this.{field_name} = {field_value};") - }; + let action = |field_name, field_value| writeln!(code, "this.{field_name} = {field_value};"); decode_fields_core(fields, encoding, action); code @@ -75,7 +73,7 @@ pub fn default_activator(encoding: Encoding) -> &'static str { fn decode_member(member: &impl Member, namespace: &str, encoding: Encoding) -> CodeBlock { let mut code = CodeBlock::default(); let data_type = member.data_type(); - let type_string = data_type.cs_type_string(namespace, TypeContext::IncomingParam, true); + let type_string = data_type.field_type_string(namespace, true); if data_type.is_optional { match data_type.concrete_type() { @@ -173,14 +171,10 @@ fn decode_dictionary(dictionary_ref: &TypeRef, namespace: &str, enco // decode value let mut decode_value = decode_func(value_type, namespace, encoding); if matches!(value_type.concrete_type(), Types::Sequence(_) | Types::Dictionary(_)) { - write!( - decode_value, - " as {}", - value_type.cs_type_string(namespace, TypeContext::Field, true), - ); + write!(decode_value, " as {}", value_type.field_type_string(namespace, true)); } - let dictionary_type = dictionary_ref.cs_type_string(namespace, TypeContext::IncomingParam, true); + let dictionary_type = dictionary_ref.incoming_parameter_type_string(namespace, true); let decode_key = decode_key.indent(); let decode_value = decode_value.indent(); @@ -231,15 +225,11 @@ fn decode_sequence(sequence_ref: &TypeRef, namespace: &str, encoding: if !has_cs_type_attribute && matches!(element_type.concrete_type(), Types::Sequence(_)) { // For nested sequences we want to cast Foo[][] returned by DecodeSequence to IList[] // used in the request and response decode methods. - write!( - code, - "({}[])", - element_type.cs_type_string(namespace, TypeContext::Field, false), - ); + write!(code, "({}[])", element_type.field_type_string(namespace, false)); }; if has_cs_type_attribute { - let sequence_type = sequence_ref.cs_type_string(namespace, TypeContext::IncomingParam, true); + let sequence_type = sequence_ref.incoming_parameter_type_string(namespace, true); let arg: Option = match element_type.concrete_type() { Types::Primitive(primitive) if primitive.fixed_wire_size().is_some() && !element_type.is_optional => { @@ -247,7 +237,7 @@ fn decode_sequence(sequence_ref: &TypeRef, namespace: &str, encoding: // faster than decoding the collection elements one by one. Some(format!( "decoder.DecodeSequence<{}>({})", - element_type.cs_type_string(namespace, TypeContext::IncomingParam, true), + element_type.incoming_parameter_type_string(namespace, true), if matches!(primitive, Primitive::Bool) { "checkElement: SliceDecoder.CheckBoolValue" } else { @@ -265,14 +255,14 @@ fn decode_sequence(sequence_ref: &TypeRef, namespace: &str, encoding: if enum_def.is_unchecked { Some(format!( "decoder.DecodeSequence<{}>()", - element_type.cs_type_string(namespace, TypeContext::IncomingParam, true), + element_type.incoming_parameter_type_string(namespace, true), )) } else { Some(format!( "\ decoder.DecodeSequence( ({enum_type_name} e) => _ = {underlying_extensions_class}.As{name}(({underlying_type})e))", - enum_type_name = element_type.cs_type_string(namespace, TypeContext::IncomingParam, false), + enum_type_name = element_type.incoming_parameter_type_string(namespace, false), underlying_extensions_class = enum_def.escape_scoped_identifier_with_suffix( &format!( "{}Extensions", @@ -332,7 +322,7 @@ decoder.DecodeSequenceOfOptionals( write!( code, "decoder.DecodeSequence<{}>({})", - element_type.cs_type_string(namespace, TypeContext::IncomingParam, true), + element_type.incoming_parameter_type_string(namespace, true), if matches!(primitive, Primitive::Bool) { "checkElement: SliceDecoder.CheckBoolValue" } else { @@ -345,7 +335,7 @@ decoder.DecodeSequenceOfOptionals( write!( code, "decoder.DecodeSequence<{}>()", - element_type.cs_type_string(namespace, TypeContext::IncomingParam, true), + element_type.incoming_parameter_type_string(namespace, true), ) } else { write!( @@ -353,7 +343,7 @@ decoder.DecodeSequenceOfOptionals( "\ decoder.DecodeSequence( ({enum_type} e) => _ = {underlying_extensions_class}.As{name}(({underlying_type})e))", - enum_type = element_type.cs_type_string(namespace, TypeContext::IncomingParam, false), + enum_type = element_type.incoming_parameter_type_string(namespace, false), underlying_extensions_class = enum_def.escape_scoped_identifier_with_suffix( &format!( "{}Extensions", @@ -400,11 +390,7 @@ fn decode_result_field(type_ref: &TypeRef, namespace: &str, encoding: Encoding) // TODO: it's lame to have to do this here. We should provide a better API. if matches!(type_ref.concrete_type(), Types::Sequence(_) | Types::Dictionary(_)) { - write!( - decode_func, - " as {}", - type_ref.cs_type_string(namespace, TypeContext::Field, false), - ); + write!(decode_func, " as {}", type_ref.field_type_string(namespace, false)); } if type_ref.is_optional { @@ -427,7 +413,7 @@ pub fn decode_func(type_ref: &TypeRef, namespace: &str, encoding: Encoding) -> C fn decode_func_body(type_ref: &TypeRef, namespace: &str, encoding: Encoding) -> CodeBlock { let mut code = CodeBlock::default(); - let type_name = type_ref.cs_type_string(namespace, TypeContext::IncomingParam, true); + let type_name = type_ref.incoming_parameter_type_string(namespace, true); // When we decode the type, we decode it as a non-optional. // If the type is supposed to be optional, we cast it after decoding. @@ -503,7 +489,7 @@ pub fn decode_operation(operation: &Operation, dispatch: bool) -> CodeBlock { // For optional value types we have to use the full type as the compiler cannot // disambiguate between null and the actual value type. let param_type_string = match param_type.is_optional && param_type.is_value_type() { - true => param_type.cs_type_string(&namespace, TypeContext::IncomingParam, false), + true => param_type.incoming_parameter_type_string(&namespace, false), false => "var".to_owned(), }; @@ -530,7 +516,7 @@ pub fn decode_operation_stream( ) -> CodeBlock { let cs_encoding = encoding.to_cs_encoding(); let param_type = stream_member.data_type(); - let param_type_str = param_type.cs_type_string(namespace, TypeContext::IncomingParam, false); + let param_type_str = param_type.incoming_parameter_type_string(namespace, false); let fixed_wire_size = param_type.fixed_wire_size(); match param_type.concrete_type() { diff --git a/tools/slicec-cs/src/encoding.rs b/tools/slicec-cs/src/encoding.rs index a1fc463d9..e701cbdaf 100644 --- a/tools/slicec-cs/src/encoding.rs +++ b/tools/slicec-cs/src/encoding.rs @@ -229,7 +229,7 @@ fn encode_tagged_type( let null_check = if read_only_memory { format!("{param}.Span != null") } else { - let unwrapped_type = data_type.cs_type_string(namespace, type_context, true); + let unwrapped_type = get_type_string(data_type, namespace, type_context, true); format!("{param} is {unwrapped_type} {unwrapped_name}") }; @@ -332,7 +332,7 @@ fn encode_action( ) -> CodeBlock { CodeBlock::from(format!( "(ref SliceEncoder encoder, {value_type} value) => {encode_action_body}", - value_type = type_ref.cs_type_string(namespace, type_context, is_tagged), + value_type = get_type_string(type_ref, namespace, type_context, is_tagged), encode_action_body = encode_action_body(type_ref, type_context, namespace, encoding, is_tagged), )) } @@ -457,7 +457,7 @@ fn encode_type_with_bit_sequence_optimization( namespace: &str, encoding: Encoding, ) -> CodeBlock { - let value_type = type_ref.cs_type_string(namespace, type_context, false); + let value_type = get_type_string(type_ref, namespace, type_context, false); if type_ref.is_optional { CodeBlock::from(format!( "\ @@ -549,3 +549,12 @@ int startPos_ = encoder_.EncodedByteCount;", ) .into() } + +// TODO temporary bridging code while cleaning up the type_string functions. +fn get_type_string(type_ref: &TypeRef, namespace: &str, context: TypeContext, ignore_optional: bool) -> String { + match context { + TypeContext::OutgoingParam => type_ref.outgoing_parameter_type_string(namespace, ignore_optional), + TypeContext::Field => type_ref.field_type_string(namespace, ignore_optional), + TypeContext::IncomingParam => unreachable!(), + } +} diff --git a/tools/slicec-cs/src/generators/class_generator.rs b/tools/slicec-cs/src/generators/class_generator.rs index 076ccdde2..9f6733742 100644 --- a/tools/slicec-cs/src/generators/class_generator.rs +++ b/tools/slicec-cs/src/generators/class_generator.rs @@ -9,7 +9,6 @@ use crate::member_util::*; use crate::slicec_ext::*; use slicec::code_block::CodeBlock; use slicec::grammar::{Class, Encoding, Field}; -use slicec::utils::code_gen_util::TypeContext; pub fn generate_class(class_def: &Class) -> CodeBlock { let class_name = class_def.escape_identifier(); @@ -129,7 +128,7 @@ fn constructor( for field in base_fields.iter().chain(fields.iter()) { builder.add_parameter( - &field.data_type.cs_type_string(namespace, TypeContext::Field, false), + &field.data_type.field_type_string(namespace, false), &field.parameter_name(), None, field.formatted_doc_comment_summary(), diff --git a/tools/slicec-cs/src/generators/dispatch_generator.rs b/tools/slicec-cs/src/generators/dispatch_generator.rs index 4680ac282..efc947937 100644 --- a/tools/slicec-cs/src/generators/dispatch_generator.rs +++ b/tools/slicec-cs/src/generators/dispatch_generator.rs @@ -108,7 +108,7 @@ fn request_class(interface_def: &Interface) -> CodeBlock { }, &format!( "global::System.Threading.Tasks.ValueTask<{}>", - ¶meters.to_tuple_type(namespace, TypeContext::IncomingParam, false), + ¶meters.to_tuple_type(namespace, TypeContext::IncomingParam), ), &operation.escape_identifier_with_prefix_and_suffix("Decode", "Async"), function_type, @@ -192,7 +192,7 @@ fn response_class(interface_def: &Interface) -> CodeBlock { match non_streamed_returns.as_slice() { [param] => { builder.add_parameter( - ¶m.cs_type_string(namespace, TypeContext::OutgoingParam, false), + ¶m.cs_type_string(namespace, TypeContext::OutgoingParam), "returnValue", None, Some("The operation return value.".to_owned()), @@ -201,7 +201,7 @@ fn response_class(interface_def: &Interface) -> CodeBlock { _ => { for param in &non_streamed_returns { builder.add_parameter( - ¶m.cs_type_string(namespace, TypeContext::OutgoingParam, false), + ¶m.cs_type_string(namespace, TypeContext::OutgoingParam), ¶m.parameter_name(), None, param.formatted_param_doc_comment(), diff --git a/tools/slicec-cs/src/generators/exception_generator.rs b/tools/slicec-cs/src/generators/exception_generator.rs index cd6745427..7960ca7a9 100644 --- a/tools/slicec-cs/src/generators/exception_generator.rs +++ b/tools/slicec-cs/src/generators/exception_generator.rs @@ -7,7 +7,6 @@ use crate::member_util::*; use crate::slicec_ext::*; use slicec::code_block::CodeBlock; use slicec::grammar::{Encoding, Exception, Member}; -use slicec::utils::code_gen_util::TypeContext; pub fn generate_exception(exception_def: &Exception) -> CodeBlock { let exception_name = exception_def.escape_identifier(); @@ -144,7 +143,7 @@ fn one_shot_constructor(exception_def: &Exception) -> CodeBlock { for field in &all_fields { ctor_builder.add_parameter( - &field.data_type().cs_type_string(namespace, TypeContext::Field, false), + &field.data_type().field_type_string(namespace, false), field.parameter_name().as_str(), None, field.formatted_doc_comment_summary(), diff --git a/tools/slicec-cs/src/generators/proxy_generator.rs b/tools/slicec-cs/src/generators/proxy_generator.rs index c9b75273e..b797a58fc 100644 --- a/tools/slicec-cs/src/generators/proxy_generator.rs +++ b/tools/slicec-cs/src/generators/proxy_generator.rs @@ -347,7 +347,7 @@ if ({features_parameter}?.Get() is null) invocation_builder.add_argument( FunctionCallBuilder::new(format!( "{stream_parameter_name}.ToPipeReader<{}>", - stream_type.cs_type_string(namespace, TypeContext::OutgoingParam, false), + stream_type.outgoing_parameter_type_string(namespace, false), )) .use_semicolon(false) .add_argument(encode_stream_parameter(stream_type, namespace, operation.encoding).indent()) @@ -486,7 +486,7 @@ fn request_class(interface_def: &Interface) -> CodeBlock { for param in ¶ms { builder.add_parameter( - ¶m.cs_type_string(namespace, TypeContext::OutgoingParam, false), + ¶m.cs_type_string(namespace, TypeContext::OutgoingParam), ¶m.parameter_name(), None, param.formatted_param_doc_comment(), @@ -554,7 +554,7 @@ fn response_class(interface_def: &Interface) -> CodeBlock { } else { format!( "global::System.Threading.Tasks.ValueTask<{}>", - members.to_tuple_type(namespace, TypeContext::IncomingParam, false), + members.to_tuple_type(namespace, TypeContext::IncomingParam), ) }; diff --git a/tools/slicec-cs/src/generators/struct_generator.rs b/tools/slicec-cs/src/generators/struct_generator.rs index 765f99917..5378cab6e 100644 --- a/tools/slicec-cs/src/generators/struct_generator.rs +++ b/tools/slicec-cs/src/generators/struct_generator.rs @@ -10,7 +10,6 @@ use crate::member_util::*; use crate::slicec_ext::{CommentExt, EntityExt, MemberExt, TypeRefExt}; use slicec::code_block::CodeBlock; use slicec::grammar::*; -use slicec::utils::code_gen_util::*; pub fn generate_struct(struct_def: &Struct) -> CodeBlock { let escaped_identifier = struct_def.escape_identifier(); @@ -54,7 +53,7 @@ pub fn generate_struct(struct_def: &Struct) -> CodeBlock { for field in &fields { main_constructor.add_parameter( - &field.data_type().cs_type_string(&namespace, TypeContext::Field, false), + &field.data_type().field_type_string(&namespace, false), field.parameter_name().as_str(), None, field.formatted_doc_comment_summary(), diff --git a/tools/slicec-cs/src/member_util.rs b/tools/slicec-cs/src/member_util.rs index 72abaf8ce..7e91be6dc 100644 --- a/tools/slicec-cs/src/member_util.rs +++ b/tools/slicec-cs/src/member_util.rs @@ -4,7 +4,6 @@ use crate::comments::CommentTag; use crate::slicec_ext::*; use slicec::code_block::CodeBlock; use slicec::grammar::{Contained, Field, Member}; -use slicec::utils::code_gen_util::TypeContext; /// Takes a list of members and sorts them in the following order: [required members][tagged members] /// Required members are left in the provided order. Tagged members are sorted so tag values are in increasing order. @@ -23,9 +22,7 @@ pub fn escape_parameter_name(parameters: &[&impl Member], name: &str) -> String } pub fn field_declaration(field: &Field) -> String { - let type_string = field - .data_type() - .cs_type_string(&field.namespace(), TypeContext::Field, false); + let type_string = field.data_type().field_type_string(&field.namespace(), false); let mut prelude = CodeBlock::default(); if let Some(summary) = field.formatted_doc_comment_summary() { diff --git a/tools/slicec-cs/src/slicec_ext/member_ext.rs b/tools/slicec-cs/src/slicec_ext/member_ext.rs index 3f0e715fa..7a6204fbc 100644 --- a/tools/slicec-cs/src/slicec_ext/member_ext.rs +++ b/tools/slicec-cs/src/slicec_ext/member_ext.rs @@ -63,7 +63,7 @@ impl FieldExt for Field { } pub trait ParameterExt { - fn cs_type_string(&self, namespace: &str, context: TypeContext, ignore_optional: bool) -> String; + fn cs_type_string(&self, namespace: &str, context: TypeContext) -> String; /// Returns the message of the `@param` tag corresponding to this parameter from the operation it's part of. /// If the operation has no doc comment, or a matching `@param` tag, this returns `None`. @@ -71,8 +71,14 @@ pub trait ParameterExt { } impl ParameterExt for Parameter { - fn cs_type_string(&self, namespace: &str, context: TypeContext, ignore_optional: bool) -> String { - let type_str = self.data_type().cs_type_string(namespace, context, ignore_optional); + fn cs_type_string(&self, namespace: &str, context: TypeContext) -> String { + // TODO this can be further simplified. + let type_str = match context { + TypeContext::OutgoingParam => self.data_type().outgoing_parameter_type_string(namespace, false), + TypeContext::IncomingParam => self.data_type().incoming_parameter_type_string(namespace, false), + TypeContext::Field => unreachable!(), + }; + if self.is_streamed { if type_str == "byte" { "global::System.IO.Pipelines.PipeReader".to_owned() @@ -99,7 +105,7 @@ impl ParameterExt for Parameter { pub trait ParameterSliceExt { fn to_argument_tuple(&self, prefix: &str) -> String; - fn to_tuple_type(&self, namespace: &str, context: TypeContext, ignore_optional: bool) -> String; + fn to_tuple_type(&self, namespace: &str, context: TypeContext) -> String; } impl ParameterSliceExt for [&Parameter] { @@ -117,14 +123,14 @@ impl ParameterSliceExt for [&Parameter] { } } - fn to_tuple_type(&self, namespace: &str, context: TypeContext, ignore_optional: bool) -> String { + fn to_tuple_type(&self, namespace: &str, context: TypeContext) -> String { match self { [] => panic!("tuple type with no members"), - [member] => member.cs_type_string(namespace, context, ignore_optional), + [member] => member.cs_type_string(namespace, context), _ => format!( "({})", self.iter() - .map(|m| m.cs_type_string(namespace, context, ignore_optional) + " " + &m.field_name()) + .map(|m| m.cs_type_string(namespace, context) + " " + &m.field_name()) .collect::>() .join(", "), ), diff --git a/tools/slicec-cs/src/slicec_ext/operation_ext.rs b/tools/slicec-cs/src/slicec_ext/operation_ext.rs index 7be2df330..8d3a42003 100644 --- a/tools/slicec-cs/src/slicec_ext/operation_ext.rs +++ b/tools/slicec-cs/src/slicec_ext/operation_ext.rs @@ -58,7 +58,7 @@ fn operation_return_type(operation: &Operation, is_dispatch: bool, context: Type if let Some(stream_member) = operation.streamed_return_member() { format!( "(global::System.IO.Pipelines.PipeReader Payload, {} {})", - stream_member.cs_type_string(&namespace, context, false), + stream_member.cs_type_string(&namespace, context), stream_member.field_name(), ) } else { @@ -67,8 +67,8 @@ fn operation_return_type(operation: &Operation, is_dispatch: bool, context: Type } else { match operation.return_members().as_slice() { [] => "void".to_owned(), - [member] => member.cs_type_string(&namespace, context, false), - members => members.to_tuple_type(&namespace, context, false), + [member] => member.cs_type_string(&namespace, context), + members => members.to_tuple_type(&namespace, context), } } } diff --git a/tools/slicec-cs/src/slicec_ext/type_ref_ext.rs b/tools/slicec-cs/src/slicec_ext/type_ref_ext.rs index b7e47b70d..3a0e283e9 100644 --- a/tools/slicec-cs/src/slicec_ext/type_ref_ext.rs +++ b/tools/slicec-cs/src/slicec_ext/type_ref_ext.rs @@ -3,14 +3,14 @@ use super::{EntityExt, EnumExt, PrimitiveExt}; use crate::cs_attributes::CsType; use slicec::grammar::*; -use slicec::utils::code_gen_util::TypeContext; pub trait TypeRefExt { /// Is this type known to map to a C# value type? fn is_value_type(&self) -> bool; - /// The C# mapped type for this type reference. - fn cs_type_string(&self, namespace: &str, context: TypeContext, ignore_optional: bool) -> String; + fn field_type_string(&self, namespace: &str, ignore_optional: bool) -> String; + fn incoming_parameter_type_string(&self, namespace: &str, ignore_optional: bool) -> String; + fn outgoing_parameter_type_string(&self, namespace: &str, ignore_optional: bool) -> String; } impl TypeRefExt for TypeRef { @@ -23,100 +23,94 @@ impl TypeRefExt for TypeRef { } } - fn cs_type_string(&self, namespace: &str, context: TypeContext, mut ignore_optional: bool) -> String { - let type_str = match &self.concrete_typeref() { + fn field_type_string(&self, namespace: &str, ignore_optional: bool) -> String { + let type_string = match &self.concrete_typeref() { + TypeRefs::Primitive(primitive_ref) => primitive_ref.cs_type().to_owned(), TypeRefs::Struct(struct_ref) => struct_ref.escape_scoped_identifier(namespace), TypeRefs::Class(class_ref) => class_ref.escape_scoped_identifier(namespace), TypeRefs::Enum(enum_ref) => enum_ref.escape_scoped_identifier(namespace), - TypeRefs::ResultType(result_type_ref) => result_type_to_string(result_type_ref, namespace), + TypeRefs::ResultType(result_type_ref) => { + let success_type = result_type_ref.success_type.field_type_string(namespace, false); + let failure_type = result_type_ref.failure_type.field_type_string(namespace, false); + format!("Result<{success_type}, {failure_type}>") + } TypeRefs::CustomType(custom_type_ref) => { let attribute = custom_type_ref.definition().find_attribute::(); - attribute.unwrap().type_string.clone() + let attribute = attribute.expect("called 'type_string' on custom type with no 'cs::type' attribute!"); + attribute.type_string.clone() } TypeRefs::Sequence(sequence_ref) => { - // For readonly sequences of fixed size numeric elements the mapping is the - // same for optional an non optional types. - if context == TypeContext::OutgoingParam - && sequence_ref.has_fixed_size_primitive_elements() - && !self.has_attribute::() - { - ignore_optional = true; - } - sequence_type_to_string(sequence_ref, namespace, context) + let element_type = sequence_ref.element_type.field_type_string(namespace, false); + format!("global::System.Collections.Generic.IList<{element_type}>") + } + TypeRefs::Dictionary(dictionary_ref) => { + let key_type = dictionary_ref.key_type.field_type_string(namespace, false); + let value_type = dictionary_ref.value_type.field_type_string(namespace, false); + format!("global::System.Collections.Generic.IDictionary<{key_type}, {value_type}>") } - TypeRefs::Dictionary(dictionary_ref) => dictionary_type_to_string(dictionary_ref, namespace, context), - TypeRefs::Primitive(primitive_ref) => primitive_ref.cs_type().to_owned(), }; - if self.is_optional && !ignore_optional { - type_str + "?" - } else { - type_str - } + set_optional_modifier_for(type_string, self.is_optional && !ignore_optional) } -} -/// Helper method to convert a sequence type into a string -fn sequence_type_to_string(sequence_ref: &TypeRef, namespace: &str, context: TypeContext) -> String { - let element_type = sequence_ref - .element_type - .cs_type_string(namespace, TypeContext::Field, false); - - let cs_type_attribute = sequence_ref.find_attribute::(); - - match context { - TypeContext::Field => { - format!("global::System.Collections.Generic.IList<{element_type}>") - } - TypeContext::IncomingParam => match cs_type_attribute { - Some(arg) => arg.type_string.clone(), - None => format!("{element_type}[]"), - }, - TypeContext::OutgoingParam => { - // If the underlying type is of fixed size, we map to `ReadOnlyMemory` instead. - if sequence_ref.has_fixed_size_primitive_elements() && cs_type_attribute.is_none() { - format!("global::System.ReadOnlyMemory<{element_type}>") - } else { - format!("global::System.Collections.Generic.IEnumerable<{element_type}>") + fn incoming_parameter_type_string(&self, namespace: &str, ignore_optional: bool) -> String { + let type_string = match &self.concrete_typeref() { + TypeRefs::Sequence(sequence_ref) => { + match sequence_ref.find_attribute::() { + Some(argument) => argument.type_string.clone(), + None => { + let element_type = sequence_ref.element_type.field_type_string(namespace, false); + format!("{element_type}[]") + } + } } - } - } -} + TypeRefs::Dictionary(dictionary_ref) => { + match dictionary_ref.find_attribute::() { + Some(argument) => argument.type_string.clone(), + None => { + let key_type = dictionary_ref.key_type.field_type_string(namespace, false); + let value_type = dictionary_ref.value_type.field_type_string(namespace, false); + format!("global::System.Collections.Generic.Dictionary<{key_type}, {value_type}>") + } + } + } + _ => self.field_type_string(namespace, true), + }; -/// Helper method to convert a dictionary type into a string -fn dictionary_type_to_string(dictionary_ref: &TypeRef, namespace: &str, context: TypeContext) -> String { - let key_type = dictionary_ref - .key_type - .cs_type_string(namespace, TypeContext::Field, false); - let value_type = dictionary_ref - .value_type - .cs_type_string(namespace, TypeContext::Field, false); + set_optional_modifier_for(type_string, self.is_optional && !ignore_optional) + } - let cs_type_attribute = dictionary_ref.find_attribute::(); + fn outgoing_parameter_type_string(&self, namespace: &str, mut ignore_optional: bool) -> String { + let type_string = match &self.concrete_typeref() { + TypeRefs::Sequence(sequence_ref) => { + let element_type = sequence_ref.element_type.field_type_string(namespace, false); + let has_cs_type_attribute = self.has_attribute::(); + if sequence_ref.has_fixed_size_primitive_elements() && !has_cs_type_attribute { + // If the underlying type is of fixed size, we map to `ReadOnlyMemory` instead, + // and the mapping is the same for optional and non-optional types. + ignore_optional = true; + format!("global::System.ReadOnlyMemory<{element_type}>") + } else { + format!("global::System.Collections.Generic.IEnumerable<{element_type}>") + } + } + TypeRefs::Dictionary(dictionary_ref) => { + let key_type = dictionary_ref.key_type.field_type_string(namespace, false); + let value_type = dictionary_ref.value_type.field_type_string(namespace, false); + format!( + "global::System.Collections.Generic.IEnumerable>" + ) + } + _ => self.field_type_string(namespace, true), + }; - match context { - TypeContext::Field => { - format!("global::System.Collections.Generic.IDictionary<{key_type}, {value_type}>") - } - TypeContext::IncomingParam => match cs_type_attribute { - Some(arg) => arg.type_string.clone(), - None => format!("global::System.Collections.Generic.Dictionary<{key_type}, {value_type}>"), - }, - TypeContext::OutgoingParam => - format!( - "global::System.Collections.Generic.IEnumerable>" - ) + set_optional_modifier_for(type_string, self.is_optional && !ignore_optional) } } -/// Helper method to convert a result type into a string -fn result_type_to_string(result_type_ref: &TypeRef, namespace: &str) -> String { - let success_type = result_type_ref - .success_type - .cs_type_string(namespace, TypeContext::Field, false); - let failure_type = result_type_ref - .failure_type - .cs_type_string(namespace, TypeContext::Field, false); - - format!("Result<{success_type}, {failure_type}>") +fn set_optional_modifier_for(type_string: String, is_optional: bool) -> String { + match is_optional { + true => type_string + "?", + false => type_string, + } }