diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index 2e65fc34cd1..809775141f2 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -21,6 +21,7 @@ using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { @@ -57,7 +58,8 @@ public static AggregationExpression Translate( var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType); var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter); EnsureDefaultValue(memberMap); - memberMap.SetSerializer(constructorArgumentSerializer); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, constructorArgumentSerializer); + memberMap.SetSerializer(memberSerializer); computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast)); } } @@ -69,7 +71,8 @@ public static AggregationExpression Translate( var memberMap = FindMemberMap(expression, classMap, member.Name); var valueExpression = memberAssignment.Expression; var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - memberMap.SetSerializer(valueTranslation.Serializer); + var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); + memberMap.SetSerializer(memberSerializer); computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast)); } @@ -107,6 +110,27 @@ private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo const return classMap; } + private static IBsonSerializer CoerceSourceSerializerToMemberSerializer(BsonMemberMap memberMap, IBsonSerializer sourceSerializer) + { + var memberType = memberMap.MemberType; + var memberSerializer = memberMap.GetSerializer(); + var sourceType = sourceSerializer.ValueType; + + if (memberType != sourceType && + memberType.ImplementsIEnumerable(out var memberItemType) && + sourceType.ImplementsIEnumerable(out var sourceItemType) && + sourceItemType == memberItemType && + sourceSerializer is IBsonArraySerializer sourceArraySerializer && + sourceArraySerializer.TryGetItemSerializationInfo(out var sourceItemSerializationInfo) && + memberSerializer is IChildSerializerConfigurable memberChildSerializerConfigurable) + { + var sourceItemSerializer = sourceItemSerializationInfo.Serializer; + return memberChildSerializerConfigurable.WithChildSerializer(sourceItemSerializer); + } + + return sourceSerializer; + } + private static BsonMemberMap EnsureMemberMap(Expression expression, BsonClassMap classMap, MemberInfo creatorMapParameter) { var declaringClassMap = classMap; diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4731Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4731Tests.cs new file mode 100644 index 00000000000..a676d613b26 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4731Tests.cs @@ -0,0 +1,119 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Linq; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp4731Tests : Linq3IntegrationTest + { + [Theory] + [ParameterAttributeData] + public void Select_setting_IList_from_List_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection + .AsQueryable() + .Select(x => new P { IList = x.List }) + .Where(x => x.IList.Contains(E.A)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $project : { IList : '$List', _id : 0 } }", + "{ $match : { IList : 'A' } }"); + + var result = queryable.Single(); + result.IList.Should().Equal(E.A, E.B); + } + } + + [Theory] + [ParameterAttributeData] + public void Select_setting_IReadOnlyList_from_List_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection + .AsQueryable() + .Select(x => new Q { IReadOnlyList = x.List }) + .Where(x => x.IReadOnlyList.Contains(E.A)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages( + stages, + "{ $project : { IReadOnlyList : '$List', _id : 0 } }", + "{ $match : { IReadOnlyList : 'A' } }"); + + var result = queryable.Single(); + result.IReadOnlyList.Should().Equal(E.A, E.B); + } + } + + private IMongoCollection GetCollection(LinqProvider linqProvider) + { + var collection = GetCollection("test", linqProvider); + CreateCollection( + collection, + new Test { Id = 1, List = new List { E.A, E.B } }, + new Test { Id = 2, List = new List { E.C, E.D } }); + return collection; + } + + private class Test + { + public int Id { get; set; } + [BsonRepresentation(BsonType.String)] + public List List { get; set; } + } + + private class P + { + public IList IList { get; set; } + } + + private class Q + { + public IReadOnlyList IReadOnlyList { get; set; } + } + + private enum E { A, B, C, D }; + } +}