Skip to content

Commit

Permalink
Merge pull request #73017 from CyrusNajmabadi/callbackStyle
Browse files Browse the repository at this point in the history
Switch to a callback style in the asset finding code
  • Loading branch information
CyrusNajmabadi authored Apr 14, 2024
2 parents bfb6efb + dda41de commit 91f2a8b
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ internal static async Task FindAsync<TState>(
AssetPath assetPath,
TextDocumentStates<TState> documentStates,
HashSet<Checksum> searchingChecksumsLeft,
Dictionary<Checksum, object> result,
Action<Checksum, object> onAssetFound,
CancellationToken cancellationToken) where TState : TextDocumentState
{
var hintDocument = assetPath.DocumentId;
Expand All @@ -68,7 +68,7 @@ internal static async Task FindAsync<TState>(
if (state != null)
{
Contract.ThrowIfFalse(state.TryGetStateChecksums(out var stateChecksums));
await stateChecksums.FindAsync(assetPath, state, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await stateChecksums.FindAsync(assetPath, state, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}
else
Expand All @@ -81,7 +81,7 @@ internal static async Task FindAsync<TState>(

Contract.ThrowIfFalse(state.TryGetStateChecksums(out var stateChecksums));

await stateChecksums.FindAsync(assetPath, state, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await stateChecksums.FindAsync(assetPath, state, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}
}
Expand All @@ -90,7 +90,7 @@ internal static void Find<T>(
IReadOnlyList<T> values,
ChecksumCollection checksums,
HashSet<Checksum> searchingChecksumsLeft,
Dictionary<Checksum, object> result,
Action<Checksum, object> onAssetFound,
CancellationToken cancellationToken) where T : class
{
Contract.ThrowIfFalse(values.Count == checksums.Children.Length);
Expand All @@ -103,7 +103,7 @@ internal static void Find<T>(

var checksum = checksums.Children[i];
if (searchingChecksumsLeft.Remove(checksum))
result[checksum] = values[i];
onAssetFound(checksum, values[i]);
}
}

Expand Down
65 changes: 37 additions & 28 deletions src/Workspaces/Core/Portable/Workspace/Solution/StateChecksums.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public async Task FindAsync(
ProjectCone? projectCone,
AssetPath assetPath,
HashSet<Checksum> searchingChecksumsLeft,
Dictionary<Checksum, object> result,
Action<Checksum, object> onAssetFound,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
Expand All @@ -128,10 +128,10 @@ public async Task FindAsync(
if (assetPath.IncludeSolutionCompilationState)
{
if (assetPath.IncludeSolutionCompilationStateChecksums && searchingChecksumsLeft.Remove(this.Checksum))
result[this.Checksum] = this;
onAssetFound(this.Checksum, this);

if (assetPath.IncludeSolutionSourceGeneratorExecutionVersionMap && searchingChecksumsLeft.Remove(this.SourceGeneratorExecutionVersionMap))
result[this.SourceGeneratorExecutionVersionMap] = compilationState.SourceGeneratorExecutionVersionMap;
onAssetFound(this.SourceGeneratorExecutionVersionMap, compilationState.SourceGeneratorExecutionVersionMap);

if (compilationState.FrozenSourceGeneratedDocumentStates != null)
{
Expand All @@ -143,7 +143,7 @@ public async Task FindAsync(
{
await ChecksumCollection.FindAsync(
new AssetPath(AssetPathKind.DocumentText, assetPath.ProjectId, assetPath.DocumentId),
compilationState.FrozenSourceGeneratedDocumentStates, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
compilationState.FrozenSourceGeneratedDocumentStates, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}

// ... or one of the identities. In this case, we'll use the fact that there's a 1:1 correspondence between the
Expand All @@ -161,7 +161,7 @@ await ChecksumCollection.FindAsync(
if (searchingChecksumsLeft.Remove(identityChecksum))
{
Contract.ThrowIfFalse(compilationState.FrozenSourceGeneratedDocumentStates.TryGetState(documentId, out var state));
result[identityChecksum] = state.Identity;
onAssetFound(identityChecksum, state.Identity);
}
}
}
Expand All @@ -175,7 +175,7 @@ await ChecksumCollection.FindAsync(
{
var id = FrozenSourceGeneratedDocuments.Value.Ids[i];
Contract.ThrowIfFalse(compilationState.FrozenSourceGeneratedDocumentStates.TryGetState(id, out var state));
result[identityChecksum] = state.Identity;
onAssetFound(identityChecksum, state.Identity);
}
}
}
Expand All @@ -189,13 +189,13 @@ await ChecksumCollection.FindAsync(
// If we're not in a project cone, start the search at the top most state-checksum corresponding to the
// entire solution.
Contract.ThrowIfFalse(solutionState.TryGetStateChecksums(out var solutionChecksums));
await solutionChecksums.FindAsync(solutionState, projectCone, assetPath, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await solutionChecksums.FindAsync(solutionState, projectCone, assetPath, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
else
{
// Otherwise, grab the top-most state checksum for this cone and search within that.
Contract.ThrowIfFalse(solutionState.TryGetStateChecksums(projectCone.RootProjectId, out var solutionChecksums));
await solutionChecksums.FindAsync(solutionState, projectCone, assetPath, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await solutionChecksums.FindAsync(solutionState, projectCone, assetPath, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down Expand Up @@ -269,7 +269,7 @@ public async Task FindAsync(
ProjectCone? projectCone,
AssetPath assetPath,
HashSet<Checksum> searchingChecksumsLeft,
Dictionary<Checksum, object> result,
Action<Checksum, object> onAssetFound,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
Expand All @@ -279,13 +279,13 @@ public async Task FindAsync(
if (assetPath.IncludeSolutionState)
{
if (assetPath.IncludeSolutionStateChecksums && searchingChecksumsLeft.Remove(Checksum))
result[Checksum] = this;
onAssetFound(Checksum, this);

if (assetPath.IncludeSolutionAttributes && searchingChecksumsLeft.Remove(Attributes))
result[Attributes] = solution.SolutionAttributes;
onAssetFound(Attributes, solution.SolutionAttributes);

if (assetPath.IncludeSolutionAnalyzerReferences)
ChecksumCollection.Find(solution.AnalyzerReferences, AnalyzerReferences, searchingChecksumsLeft, result, cancellationToken);
ChecksumCollection.Find(solution.AnalyzerReferences, AnalyzerReferences, searchingChecksumsLeft, onAssetFound, cancellationToken);
}

if (searchingChecksumsLeft.Count == 0)
Expand All @@ -304,7 +304,7 @@ public async Task FindAsync(
if (projectState != null &&
projectState.TryGetStateChecksums(out var projectStateChecksums))
{
await projectStateChecksums.FindAsync(projectState, assetPath, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await projectStateChecksums.FindAsync(projectState, assetPath, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}
else
Expand All @@ -327,7 +327,7 @@ public async Task FindAsync(
if (!projectState.TryGetStateChecksums(out var projectStateChecksums))
continue;

await projectStateChecksums.FindAsync(projectState, assetPath, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await projectStateChecksums.FindAsync(projectState, assetPath, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down Expand Up @@ -435,7 +435,7 @@ public async Task FindAsync(
ProjectState state,
AssetPath assetPath,
HashSet<Checksum> searchingChecksumsLeft,
Dictionary<Checksum, object> result,
Action<Checksum, object> onAssetFound,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
Expand All @@ -449,32 +449,38 @@ public async Task FindAsync(
if (assetPath.IncludeProjects)
{
if (assetPath.IncludeProjectStateChecksums && searchingChecksumsLeft.Remove(Checksum))
result[Checksum] = this;
onAssetFound(Checksum, this);

if (assetPath.IncludeProjectAttributes && searchingChecksumsLeft.Remove(Info))
result[Info] = state.ProjectInfo.Attributes;
onAssetFound(Info, state.ProjectInfo.Attributes);

if (assetPath.IncludeProjectCompilationOptions && searchingChecksumsLeft.Remove(CompilationOptions))
result[CompilationOptions] = state.CompilationOptions ?? throw new InvalidOperationException("We should not be trying to serialize a project with no compilation options; RemoteSupportedLanguages.IsSupported should have filtered it out.");
{
var compilationOptions = state.CompilationOptions ?? throw new InvalidOperationException("We should not be trying to serialize a project with no compilation options; RemoteSupportedLanguages.IsSupported should have filtered it out.");
onAssetFound(CompilationOptions, compilationOptions);
}

if (assetPath.IncludeProjectParseOptions && searchingChecksumsLeft.Remove(ParseOptions))
result[ParseOptions] = state.ParseOptions ?? throw new InvalidOperationException("We should not be trying to serialize a project with no parse options; RemoteSupportedLanguages.IsSupported should have filtered it out.");
{
var parseOptions = state.ParseOptions ?? throw new InvalidOperationException("We should not be trying to serialize a project with no parse options; RemoteSupportedLanguages.IsSupported should have filtered it out.");
onAssetFound(ParseOptions, parseOptions);
}

if (assetPath.IncludeProjectProjectReferences)
ChecksumCollection.Find(state.ProjectReferences, ProjectReferences, searchingChecksumsLeft, result, cancellationToken);
ChecksumCollection.Find(state.ProjectReferences, ProjectReferences, searchingChecksumsLeft, onAssetFound, cancellationToken);

if (assetPath.IncludeProjectMetadataReferences)
ChecksumCollection.Find(state.MetadataReferences, MetadataReferences, searchingChecksumsLeft, result, cancellationToken);
ChecksumCollection.Find(state.MetadataReferences, MetadataReferences, searchingChecksumsLeft, onAssetFound, cancellationToken);

if (assetPath.IncludeProjectAnalyzerReferences)
ChecksumCollection.Find(state.AnalyzerReferences, AnalyzerReferences, searchingChecksumsLeft, result, cancellationToken);
ChecksumCollection.Find(state.AnalyzerReferences, AnalyzerReferences, searchingChecksumsLeft, onAssetFound, cancellationToken);
}

if (assetPath.IncludeDocuments)
{
await ChecksumCollection.FindAsync(assetPath, state.DocumentStates, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await ChecksumCollection.FindAsync(assetPath, state.AdditionalDocumentStates, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await ChecksumCollection.FindAsync(assetPath, state.AnalyzerConfigDocumentStates, searchingChecksumsLeft, result, cancellationToken).ConfigureAwait(false);
await ChecksumCollection.FindAsync(assetPath, state.DocumentStates, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
await ChecksumCollection.FindAsync(assetPath, state.AdditionalDocumentStates, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
await ChecksumCollection.FindAsync(assetPath, state.AnalyzerConfigDocumentStates, searchingChecksumsLeft, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}
}
Expand All @@ -500,18 +506,21 @@ public async Task FindAsync(
AssetPath assetPath,
TextDocumentState state,
HashSet<Checksum> searchingChecksumsLeft,
Dictionary<Checksum, object> result,
Action<Checksum, object> onAssetFound,
CancellationToken cancellationToken)
{
Debug.Assert(state.TryGetStateChecksums(out var stateChecksum) && this == stateChecksum);

cancellationToken.ThrowIfCancellationRequested();

if (assetPath.IncludeDocumentAttributes && searchingChecksumsLeft.Remove(Info))
result[Info] = state.Attributes;
onAssetFound(Info, state.Attributes);

if (assetPath.IncludeDocumentText && searchingChecksumsLeft.Remove(Text))
result[Text] = await SerializableSourceText.FromTextDocumentStateAsync(state, cancellationToken).ConfigureAwait(false);
{
var text = await SerializableSourceText.FromTextDocumentStateAsync(state, cancellationToken).ConfigureAwait(false);
onAssetFound(Text, text);
}
}
}

Expand Down
15 changes: 9 additions & 6 deletions src/Workspaces/Remote/Core/SolutionAssetStorage.Scope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ public async Task AddAssetsAsync(
var numberOfChecksumsToSearch = checksumsToFind.Count;
Contract.ThrowIfTrue(checksumsToFind.Contains(Checksum.Null));

await FindAssetsAsync(assetPath, checksumsToFind, assetMap, cancellationToken).ConfigureAwait(false);
await FindAssetsAsync(
assetPath, checksumsToFind,
(checksum, asset) => assetMap[checksum] = asset,
cancellationToken).ConfigureAwait(false);

Contract.ThrowIfTrue(checksumsToFind.Count > 0);
Contract.ThrowIfTrue(assetMap.Count != numberOfChecksumsToSearch);
}

private async Task FindAssetsAsync(
AssetPath assetPath, HashSet<Checksum> remainingChecksumsToFind, Dictionary<Checksum, object> result, CancellationToken cancellationToken)
AssetPath assetPath, HashSet<Checksum> remainingChecksumsToFind, Action<Checksum, object> onAssetFound, CancellationToken cancellationToken)
{
var solutionState = this.CompilationState;

Expand All @@ -74,13 +77,13 @@ private async Task FindAssetsAsync(
// If we're not in a project cone, start the search at the top most state-checksum corresponding to the
// entire solution.
Contract.ThrowIfFalse(solutionState.TryGetStateChecksums(out var stateChecksums));
await stateChecksums.FindAsync(solutionState, this.ProjectCone, assetPath, remainingChecksumsToFind, result, cancellationToken).ConfigureAwait(false);
await stateChecksums.FindAsync(solutionState, this.ProjectCone, assetPath, remainingChecksumsToFind, onAssetFound, cancellationToken).ConfigureAwait(false);
}
else
{
// Otherwise, grab the top-most state checksum for this cone and search within that.
Contract.ThrowIfFalse(solutionState.TryGetStateChecksums(this.ProjectCone.RootProjectId, out var stateChecksums));
await stateChecksums.FindAsync(solutionState, this.ProjectCone, assetPath, remainingChecksumsToFind, result, cancellationToken).ConfigureAwait(false);
await stateChecksums.FindAsync(solutionState, this.ProjectCone, assetPath, remainingChecksumsToFind, onAssetFound, cancellationToken).ConfigureAwait(false);
}
}

Expand All @@ -97,10 +100,10 @@ public async ValueTask<object> GetAssetAsync(Checksum checksum, CancellationToke
{
Contract.ThrowIfTrue(checksum == Checksum.Null);

using var checksumPool = Creator.CreateChecksumSet(checksum);
using var _ = Creator.CreateResultMap(out var resultPool);

await scope.FindAssetsAsync(AssetPath.FullLookupForTesting, checksumPool.Object, resultPool, cancellationToken).ConfigureAwait(false);
var checksums = new ReadOnlyMemory<Checksum>([checksum]);
await scope.AddAssetsAsync(AssetPath.FullLookupForTesting, checksums, resultPool, cancellationToken).ConfigureAwait(false);
Contract.ThrowIfTrue(resultPool.Count != 1);

var (resultingChecksum, value) = resultPool.First();
Expand Down
Loading

0 comments on commit 91f2a8b

Please sign in to comment.