Skip to content

Commit

Permalink
Merge pull request #20 from nsubstitute/GH-12-detect-reentrant-calls
Browse files Browse the repository at this point in the history
[Gh-12] detect reentrant calls
  • Loading branch information
tpodolak authored Jul 18, 2018
2 parents 1a911be + a92b428 commit a43e3db
Show file tree
Hide file tree
Showing 32 changed files with 8,197 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using NSubstitute.Analyzers.Shared.DiagnosticAnalyzers;

namespace NSubstitute.Analyzers.CSharp.DiagnosticAnalyzers
{
internal class ReEntrantCallFinder : AbstractReEntrantCallFinder
{
protected override ImmutableList<ISymbol> GetReEntrantSymbols(SemanticModel semanticModel, SyntaxNode rootNode)
{
var visitor = new ReEntrantCallVisitor(this, semanticModel);
visitor.Visit(rootNode);
return visitor.InvocationSymbols;
}

private class ReEntrantCallVisitor : CSharpSyntaxWalker
{
private readonly ReEntrantCallFinder _reEntrantCallFinder;
private readonly SemanticModel _semanticModel;
private readonly HashSet<SyntaxNode> _visitedNodes = new HashSet<SyntaxNode>();
private readonly List<ISymbol> _invocationSymbols = new List<ISymbol>();

public ImmutableList<ISymbol> InvocationSymbols => _invocationSymbols.ToImmutableList();

public ReEntrantCallVisitor(ReEntrantCallFinder reEntrantCallFinder, SemanticModel semanticModel)
{
_reEntrantCallFinder = reEntrantCallFinder;
_semanticModel = semanticModel;
}

public override void VisitInvocationExpression(InvocationExpressionSyntax node)
{
var symbolInfo = _semanticModel.GetSymbolInfo(node);
if (_reEntrantCallFinder.IsReturnsLikeMethod(_semanticModel, symbolInfo.Symbol))
{
_invocationSymbols.Add(symbolInfo.Symbol);
}

base.VisitInvocationExpression(node);
}

public override void DefaultVisit(SyntaxNode node)
{
VisitRelatedSymbols(node);
base.DefaultVisit(node);
}

private void VisitRelatedSymbols(SyntaxNode syntaxNode)
{
if (_visitedNodes.Contains(syntaxNode) == false &&
(syntaxNode.IsKind(SyntaxKind.IdentifierName) ||
syntaxNode.IsKind(SyntaxKind.ElementAccessExpression) ||
syntaxNode.IsKind(SyntaxKind.SimpleMemberAccessExpression)))
{
_visitedNodes.Add(syntaxNode);
foreach (var relatedNode in _reEntrantCallFinder.GetRelatedNodes(_semanticModel, syntaxNode))
{
Visit(relatedNode);
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using NSubstitute.Analyzers.Shared.DiagnosticAnalyzers;

namespace NSubstitute.Analyzers.CSharp.DiagnosticAnalyzers
{
[DiagnosticAnalyzer(LanguageNames.CSharp)]
internal class ReEntrantSetupAnalyzer : AbstractReEntrantSetupAnalyzer<SyntaxKind, InvocationExpressionSyntax>
{
public ReEntrantSetupAnalyzer()
: base(new DiagnosticDescriptorsProvider())
{
}

protected override AbstractReEntrantCallFinder GetReEntrantCallFinder()
{
return new ReEntrantCallFinder();
}

protected override SyntaxKind InvocationExpressionKind { get; } = SyntaxKind.InvocationExpression;

protected override IEnumerable<SyntaxNode> ExtractArguments(InvocationExpressionSyntax invocationExpressionSyntax)
{
return invocationExpressionSyntax.ArgumentList.Arguments.Select(arg => arg.Expression);
}
}
}
27 changes: 27 additions & 0 deletions src/NSubstitute.Analyzers.CSharp/Resources.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions src/NSubstitute.Analyzers.CSharp/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,16 @@
<value>Non-virtual setup specification.</value>
<comment>The title of the diagnostic.</comment>
</data>
<data name="ReEntrantSubstituteCallDescription" xml:space="preserve">
<value>Re-entrant substitute call.</value>
<comment>An optional longer localizable description of the diagnostic.</comment>
</data>
<data name="ReEntrantSubstituteCallMessageFormat" xml:space="preserve">
<value>{0}() is set with a method that itself calls {1}. This can cause problems with NSubstitute. Consider replacing with a lambda: {0}(x => {2}).</value>
<comment>The format-able message the diagnostic displays.</comment>
</data>
<data name="ReEntrantSubstituteCallTitle" xml:space="preserve">
<value>Re-entrant substitute call.</value>
<comment>The title of the diagnostic.</comment>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ internal class AbstractDiagnosticDescriptorsProvider<T> : IDiagnosticDescriptors
public DiagnosticDescriptor NonVirtualReceivedSetupSpecification { get; } = DiagnosticDescriptors<T>.NonVirtualReceivedSetupSpecification;

public DiagnosticDescriptor NonVirtualWhenSetupSpecification { get; } = DiagnosticDescriptors<T>.NonVirtualWhenSetupSpecification;

public DiagnosticDescriptor ReEntrantSubstituteCall { get; } = DiagnosticDescriptors<T>.ReEntrantSubstituteCall;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace NSubstitute.Analyzers.Shared.DiagnosticAnalyzers
{
internal abstract class AbstractReEntrantCallFinder
{
private static readonly ImmutableDictionary<string, string> MethodNames = new Dictionary<string, string>()
{
[MetadataNames.NSubstituteReturnsMethod] = MetadataNames.NSubstituteSubstituteExtensionsFullTypeName,
[MetadataNames.NSubstituteReturnsForAnyArgsMethod] = MetadataNames.NSubstituteSubstituteExtensionsFullTypeName,
[MetadataNames.NSubstituteDoMethod] = MetadataNames.NSubstituteWhenCalledType
}.ToImmutableDictionary();

public ImmutableList<ISymbol> GetReEntrantCalls(SemanticModel semanticModel, SyntaxNode rootNode)
{
var typeInfo = semanticModel.GetTypeInfo(rootNode);
if (IsCalledViaDelegate(semanticModel, typeInfo))
{
return ImmutableList<ISymbol>.Empty;
}

return GetReEntrantSymbols(semanticModel, rootNode);
}

protected abstract ImmutableList<ISymbol> GetReEntrantSymbols(SemanticModel semanticModel, SyntaxNode rootNode);

protected IEnumerable<SyntaxNode> GetRelatedNodes(SemanticModel semanticModel, SyntaxNode syntaxNode)
{
var symbol = semanticModel.GetSymbolInfo(syntaxNode);
if (symbol.Symbol != null && symbol.Symbol.Locations.Any())
{
foreach (var symbolLocation in symbol.Symbol.Locations.Where(location => location.SourceTree != null))
{
var root = symbolLocation.SourceTree.GetRoot();
var relatedNode = root.FindNode(symbolLocation.SourceSpan);
if (relatedNode != null)
{
yield return relatedNode;
}
}
}
}

protected bool IsReturnsLikeMethod(SemanticModel semanticModel, ISymbol symbol)
{
if (symbol == null || MethodNames.TryGetValue(symbol.Name, out var containingType) == false)
{
return false;
}

return symbol.ContainingAssembly?.Name.Equals(MetadataNames.NSubstituteAssemblyName, StringComparison.OrdinalIgnoreCase) == true &&
(symbol.ContainingType?.ToString().Equals(containingType, StringComparison.OrdinalIgnoreCase) == true ||
(symbol.ContainingType?.ConstructedFrom.Name)?.Equals(containingType, StringComparison.OrdinalIgnoreCase) == true);
}

private static bool IsCalledViaDelegate(SemanticModel semanticModel, TypeInfo typeInfo)
{
var typeSymbol = typeInfo.Type ?? typeInfo.ConvertedType;
var isCalledViaDelegate = typeSymbol != null &&
typeSymbol.TypeKind == TypeKind.Delegate &&
typeSymbol is INamedTypeSymbol namedTypeSymbol &&
namedTypeSymbol.ConstructedFrom.Equals(semanticModel.Compilation.GetTypeByMetadataName("System.Func`2")) &&
IsCallInfoParameter(namedTypeSymbol.TypeArguments.First());

return isCalledViaDelegate;
}

private static bool IsCallInfoParameter(ITypeSymbol symbol)
{
return symbol.ContainingAssembly?.Name.Equals(MetadataNames.NSubstituteAssemblyName, StringComparison.OrdinalIgnoreCase) == true &&
symbol.ToString().Equals(MetadataNames.NSubstituteCoreFullTypeName, StringComparison.OrdinalIgnoreCase) == true;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;

namespace NSubstitute.Analyzers.Shared.DiagnosticAnalyzers
{
internal abstract class AbstractReEntrantSetupAnalyzer<TSyntaxKind, TInvocationExpressionSyntax> : AbstractDiagnosticAnalyzer
where TInvocationExpressionSyntax : SyntaxNode
where TSyntaxKind : struct
{
private static readonly ImmutableHashSet<string> MethodNames = ImmutableHashSet.Create(
MetadataNames.NSubstituteReturnsMethod,
MetadataNames.NSubstituteReturnsForAnyArgsMethod);

private AbstractReEntrantCallFinder ReEntrantCallFinder => _reEntrantCallFinderProxy.Value;

private readonly Lazy<AbstractReEntrantCallFinder> _reEntrantCallFinderProxy;

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics =>
ImmutableArray.Create(DiagnosticDescriptorsProvider.ReEntrantSubstituteCall);

protected AbstractReEntrantSetupAnalyzer(IDiagnosticDescriptorsProvider diagnosticDescriptorsProvider)
: base(diagnosticDescriptorsProvider)
{
_reEntrantCallFinderProxy = new Lazy<AbstractReEntrantCallFinder>(GetReEntrantCallFinder);
}

protected abstract AbstractReEntrantCallFinder GetReEntrantCallFinder();

protected abstract TSyntaxKind InvocationExpressionKind { get; }

public override void Initialize(AnalysisContext context)
{
context.RegisterSyntaxNodeAction(AnalyzeInvocation, InvocationExpressionKind);
}

protected abstract IEnumerable<SyntaxNode> ExtractArguments(TInvocationExpressionSyntax invocationExpressionSyntax);

private void AnalyzeInvocation(SyntaxNodeAnalysisContext syntaxNodeContext)
{
var invocationExpression = (TInvocationExpressionSyntax)syntaxNodeContext.Node;
var methodSymbolInfo = syntaxNodeContext.SemanticModel.GetSymbolInfo(invocationExpression);

if (methodSymbolInfo.Symbol?.Kind != SymbolKind.Method)
{
return;
}

var methodSymbol = (IMethodSymbol)methodSymbolInfo.Symbol;

if (IsReturnsLikeMethod(syntaxNodeContext, invocationExpression, methodSymbol.Name) == false)
{
return;
}

var allArguments = ExtractArguments(invocationExpression);
var argumentsForAnalysis = methodSymbol.MethodKind == MethodKind.ReducedExtension ? allArguments : allArguments.Skip(1);

foreach (var argument in argumentsForAnalysis)
{
var reentrantSymbol = ReEntrantCallFinder.GetReEntrantCalls(syntaxNodeContext.SemanticModel, argument).FirstOrDefault();
if (reentrantSymbol != null)
{
var diagnostic = Diagnostic.Create(
DiagnosticDescriptorsProvider.ReEntrantSubstituteCall,
argument.GetLocation(),
methodSymbol.Name,
reentrantSymbol.Name,
argument.ToString());

syntaxNodeContext.ReportDiagnostic(diagnostic);
}
}
}

private bool IsReturnsLikeMethod(SyntaxNodeAnalysisContext syntaxNodeContext, SyntaxNode syntax, string memberName)
{
if (MethodNames.Contains(memberName) == false)
{
return false;
}

var symbol = syntaxNodeContext.SemanticModel.GetSymbolInfo(syntax);

return symbol.Symbol?.ContainingAssembly?.Name.Equals(MetadataNames.NSubstituteAssemblyName, StringComparison.OrdinalIgnoreCase) == true &&
symbol.Symbol?.ContainingType?.ToString().Equals(MetadataNames.NSubstituteSubstituteExtensionsFullTypeName, StringComparison.OrdinalIgnoreCase) == true;
}
}
}
8 changes: 8 additions & 0 deletions src/NSubstitute.Analyzers.Shared/DiagnosticDescriptors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ internal class DiagnosticDescriptors<T>
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);

public static DiagnosticDescriptor ReEntrantSubstituteCall { get; } =
CreateDiagnosticDescriptor(
name: nameof(ReEntrantSubstituteCall),
id: DiagnosticIdentifiers.ReEntrantSubstituteCall,
category: DiagnosticCategories.Usage,
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);

private static DiagnosticDescriptor CreateDiagnosticDescriptor(
string name, string id, string category, DiagnosticSeverity defaultSeverity, bool isEnabledByDefault)
{
Expand Down
1 change: 1 addition & 0 deletions src/NSubstitute.Analyzers.Shared/DiagnosticIdentifiers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ internal class DiagnosticIdentifiers
public static readonly string SubstituteConstructorArgumentsForDelegate = "NS010";
public static readonly string NonVirtualReceivedSetupSpecification = "NS011";
public static readonly string NonVirtualWhenSetupSpecification = "NS012";
public static readonly string ReEntrantSubstituteCall = "NS013";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ internal interface IDiagnosticDescriptorsProvider
DiagnosticDescriptor NonVirtualReceivedSetupSpecification { get; }

DiagnosticDescriptor NonVirtualWhenSetupSpecification { get; }

DiagnosticDescriptor ReEntrantSubstituteCall { get; }
}
}
3 changes: 3 additions & 0 deletions src/NSubstitute.Analyzers.Shared/MetadataNames.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ internal class MetadataNames
{
public const string NSubstituteAssemblyName = "NSubstitute";
public const string NSubstituteSubstituteExtensionsFullTypeName = "NSubstitute.SubstituteExtensions";
public const string NSubstituteCoreFullTypeName = "NSubstitute.Core.CallInfo";
public const string NSubstituteSubstituteFullTypeName = "NSubstitute.Substitute";
public const string NSubstituteReturnsMethod = "Returns";
public const string NSubstituteReturnsForAnyArgsMethod = "ReturnsForAnyArgs";
public const string NSubstituteDoMethod = "Do";
public const string NSubstituteReceivedMethod = "Received";
public const string NSubstituteReceivedWithAnyArgsMethod = "ReceivedWithAnyArgs";
public const string NSubstituteDidNotReceiveMethod = "DidNotReceive";
Expand All @@ -17,5 +19,6 @@ internal class MetadataNames
public const string CastleDynamicProxyGenAssembly2Name = "DynamicProxyGenAssembly2";
public const string NSubstituteWhenMethod = "When";
public const string NSubstituteWhenForAnyArgsMethod = "WhenForAnyArgs";
public const string NSubstituteWhenCalledType = "WhenCalled";
}
}
Loading

0 comments on commit a43e3db

Please sign in to comment.