Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve custom operators #2125

Merged
merged 5 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions src/linker/Linker.Steps/DiscoverCustomOperatorsHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using Mono.Cecil;

namespace Mono.Linker.Steps
{
public class DiscoverOperatorsHandler : IMarkHandler
{
LinkContext _context;
bool seenLinqExpressions;
readonly HashSet<TypeDefinition> _trackedTypesWithOperators;
Dictionary<TypeDefinition, HashSet<MethodDefinition>> _pendingOperatorsForType;

Dictionary<TypeDefinition, HashSet<MethodDefinition>> PendingOperatorsForType {
get {
if (_pendingOperatorsForType == null)
_pendingOperatorsForType = new Dictionary<TypeDefinition, HashSet<MethodDefinition>> ();
return _pendingOperatorsForType;
}
}

public DiscoverOperatorsHandler ()
{
_trackedTypesWithOperators = new HashSet<TypeDefinition> ();
}

public void Initialize (LinkContext context, MarkContext markContext)
{
_context = context;
markContext.RegisterMarkTypeAction (ProcessType);
}

void ProcessType (TypeDefinition type)
{
CheckForLinqExpressions (type);

// Check for custom operators and either:
// - mark them, if Linq.Expressions was already marked, or
// - track them to be marked in case Linq.Expressions is marked later
var hasOperators = ProcessCustomOperators (type, mark: seenLinqExpressions);
if (!seenLinqExpressions) {
if (hasOperators)
_trackedTypesWithOperators.Add (type);
return;
}

// Mark pending operators defined on other types that reference this type
// (these are only tracked if we have already seen Linq.Expressions)
if (PendingOperatorsForType.TryGetValue (type, out var pendingOperators)) {
foreach (var customOperator in pendingOperators)
MarkOperator (customOperator);
PendingOperatorsForType.Remove (type);
}
}

void CheckForLinqExpressions (TypeDefinition type)
{
if (seenLinqExpressions)
return;

if (type.Namespace != "System.Linq.Expressions" || type.Name != "Expression")
return;

seenLinqExpressions = true;

foreach (var markedType in _trackedTypesWithOperators)
ProcessCustomOperators (markedType, mark: true);

_trackedTypesWithOperators.Clear ();
}

void MarkOperator (MethodDefinition method)
{
_context.Annotations.Mark (method, new DependencyInfo (DependencyKind.PreservedOperator, method.DeclaringType));
}

bool ProcessCustomOperators (TypeDefinition type, bool mark)
{
if (!type.HasMethods)
return false;

bool hasCustomOperators = false;
foreach (var method in type.Methods) {
if (!IsOperator (method, out var otherType))
continue;

if (!mark)
return true;

Debug.Assert (seenLinqExpressions);
hasCustomOperators = true;

if (otherType == null || _context.Annotations.IsMarked (otherType)) {
MarkOperator (method);
continue;
}

// Wait until otherType gets marked to mark the operator.
if (!PendingOperatorsForType.TryGetValue (otherType, out var pendingOperators)) {
pendingOperators = new HashSet<MethodDefinition> ();
PendingOperatorsForType.Add (otherType, pendingOperators);
}
pendingOperators.Add (method);
}
return hasCustomOperators;
}

TypeDefinition _nullableOfT;
TypeDefinition NullableOfT {
get {
if (_nullableOfT == null)
_nullableOfT = BCL.FindPredefinedType ("System", "Nullable`1", _context);
return _nullableOfT;
}
}

TypeDefinition NonNullableType (TypeReference type)
{
var typeDef = _context.TryResolve (type);
if (typeDef == null)
return null;

if (!typeDef.IsValueType || typeDef != NullableOfT)
return typeDef;

// Unwrap Nullable<T>
Debug.Assert (typeDef.HasGenericParameters);
var nullableType = type as GenericInstanceType;
Debug.Assert (nullableType != null && nullableType.HasGenericArguments && nullableType.GenericArguments.Count == 1);
return _context.TryResolve (nullableType.GenericArguments[0]);
}

bool IsOperator (MethodDefinition method, out TypeDefinition otherType)
{
otherType = null;

if (!method.IsStatic || !method.IsPublic || !method.IsSpecialName || !method.Name.StartsWith ("op_"))
return false;

var operatorName = method.Name.Substring (3);
var self = method.DeclaringType;

switch (operatorName) {
// Unary operators
case "UnaryPlus":
case "UnaryNegation":
case "LogicalNot":
case "OnesComplement":
case "Increment":
case "Decrement":
case "True":
case "False":
// Parameter type of a unary operator must be the declaring type
if (method.Parameters.Count != 1 || NonNullableType (method.Parameters[0].ParameterType) != self)
return false;
// ++ and -- must return the declaring type
if (operatorName is "Increment" or "Decrement" && NonNullableType (method.ReturnType) != self)
return false;
return true;
// Binary operators
case "Addition":
case "Subtraction":
case "Multiply":
case "Division":
case "Modulus":
case "BitwiseAnd":
case "BitwiseOr":
case "ExclusiveOr":
case "LeftShift":
case "RightShift":
case "Equality":
case "Inequality":
case "LessThan":
case "GreaterThan":
case "LessThanOrEqual":
case "GreaterThanOrEqual":
if (method.Parameters.Count != 2)
return false;
var nnLeft = NonNullableType (method.Parameters[0].ParameterType);
var nnRight = NonNullableType (method.Parameters[1].ParameterType);
if (nnLeft == null || nnRight == null)
return false;
// << and >> must take the declaring type and int
if (operatorName is "LeftShift" or "RightShift" && (nnLeft != self || nnRight.MetadataType != MetadataType.Int32))
return false;
// At least one argument must be the declaring type
if (nnLeft != self && nnRight != self)
return false;
if (nnLeft != self)
otherType = nnLeft;
if (nnRight != self)
otherType = nnRight;
return true;
// Conversion operators
case "Implicit":
case "Explicit":
if (method.Parameters.Count != 1)
return false;
var nnSource = NonNullableType (method.Parameters[0].ParameterType);
var nnTarget = NonNullableType (method.ReturnType);
// Exactly one of source/target must be the declaring type
if (nnSource == self == (nnTarget == self))
return false;
otherType = nnSource == self ? nnTarget : nnSource;
return true;
default:
return false;
}
}
}
}
2 changes: 2 additions & 0 deletions src/linker/Linker/DependencyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ public enum DependencyKind
XmlSerialized = 84, // entry type or member for XML serialization
SerializedRecursiveType = 85, // recursive type kept due to serialization handling
SerializedMember = 86, // field or property kept on a type for serialization

PreservedOperator = 87 // operator method preserved on a type
}

public readonly struct DependencyInfo : IEquatable<DependencyInfo>
Expand Down
9 changes: 9 additions & 0 deletions src/linker/Linker/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,12 @@ protected int SetupContext (ILogger customLogger = null)

continue;

case "--disable-operator-discovery":
if (!GetBoolParam (token, l => context.DisableOperatorDiscovery = l))
return -1;

continue;

case "--ignore-descriptors":
if (!GetBoolParam (token, l => context.IgnoreDescriptors = l))
return -1;
Expand Down Expand Up @@ -732,6 +738,9 @@ protected int SetupContext (ILogger customLogger = null)
if (!context.DisableSerializationDiscovery)
p.MarkHandlers.Add (new DiscoverSerializationHandler ());

if (!context.DisableOperatorDiscovery)
p.MarkHandlers.Add (new DiscoverOperatorsHandler ());

foreach (string custom_step in custom_steps) {
if (!AddCustomStep (p, custom_step))
return -1;
Expand Down
2 changes: 2 additions & 0 deletions src/linker/Linker/LinkContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ public bool IgnoreUnresolved {

public bool DisableSerializationDiscovery { get; set; }

public bool DisableOperatorDiscovery { get; set; }

public bool IgnoreDescriptors { get; set; }

public bool IgnoreSubstitutions { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
[SetupLinkerArgument ("--disable-operator-discovery")]
public class CanDisableOperatorDiscovery
{
public static void Main ()
{
var c = new CustomOperators ();
var expression = typeof (System.Linq.Expressions.Expression);
c = -c;
var t = typeof (TargetType);
}

[KeptMember (".ctor()")]
class CustomOperators
{
[Kept]
public static CustomOperators operator - (CustomOperators c) => null;

public static CustomOperators operator + (CustomOperators c) => null;
public static CustomOperators operator + (CustomOperators left, CustomOperators right) => null;
public static explicit operator TargetType (CustomOperators self) => null;
}

[Kept]
class TargetType { }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
public class CanPreserveCustomOperators
{
public static void Main ()
{
var t = typeof (CustomOperators);
var expression = typeof (System.Linq.Expressions.Expression);

var t3 = typeof (TargetTypeImplicit);
var t4 = typeof (SourceTypeImplicit);
var t5 = typeof (TargetTypeExplicit);
var t6 = typeof (SourceTypeExplicit);
}

class CustomOperators
{
// Unary operators
[Kept]
public static CustomOperators operator + (CustomOperators c) => null;
[Kept]
public static CustomOperators operator - (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ! (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ~ (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ++ (CustomOperators c) => null;
[Kept]
public static CustomOperators operator -- (CustomOperators c) => null;
[Kept]
public static bool operator true (CustomOperators c) => true;
[Kept]
public static bool operator false (CustomOperators c) => true;

// Binary operators
[Kept]
public static CustomOperators operator + (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator - (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator * (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator / (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator % (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator & (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator | (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator ^ (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator << (CustomOperators value, int shift) => null;
[Kept]
public static CustomOperators operator >> (CustomOperators value, int shift) => null;
[Kept]
public static CustomOperators operator == (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator != (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator < (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator > (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator <= (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator >= (CustomOperators left, CustomOperators right) => null;

// conversion operators
[Kept]
public static implicit operator TargetTypeImplicit (CustomOperators self) => null;
[Kept]
public static implicit operator CustomOperators (SourceTypeImplicit other) => null;
[Kept]
public static explicit operator TargetTypeExplicit (CustomOperators self) => null;
[Kept]
public static explicit operator CustomOperators (SourceTypeExplicit other) => null;
}

[Kept]
class TargetTypeImplicit { }
[Kept]
class SourceTypeImplicit { }
[Kept]
class TargetTypeExplicit { }
[Kept]
class SourceTypeExplicit { }
}
}
Loading