Skip to content

Commit

Permalink
Do not pass non-thread-safe logger to parallel computations in NodeCo…
Browse files Browse the repository at this point in the history
…de.Parallel (#16717)

* don't pass non-threadsafe logger to parallel computations

* no need to restore in forks

---------

Co-authored-by: Petr <psfinaki@users.noreply.github.com>
  • Loading branch information
majocha and psfinaki authored Feb 16, 2024
1 parent 11e6619 commit 0f520bc
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
35 changes: 23 additions & 12 deletions src/Compiler/Facilities/BuildGraph.fs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ let wrapThreadStaticInfo computation =
DiagnosticsThreadStatics.BuildPhase <- phase
}

let unwrapNode (Node(computation)) = computation

type Async<'T> with

static member AwaitNodeCode(node: NodeCode<'T>) =
Expand Down Expand Up @@ -193,19 +195,28 @@ type NodeCode private () =
}

static member Parallel(computations: NodeCode<'T> seq) =
let diagnosticsLogger = DiagnosticsThreadStatics.DiagnosticsLogger
let phase = DiagnosticsThreadStatics.BuildPhase
node {
let concurrentLogging = new CaptureDiagnosticsConcurrently()
let phase = DiagnosticsThreadStatics.BuildPhase
// Why does it return just IDisposable?
use _ = concurrentLogging

computations
|> Seq.map (fun (Node x) ->
async {
DiagnosticsThreadStatics.DiagnosticsLogger <- diagnosticsLogger
DiagnosticsThreadStatics.BuildPhase <- phase
return! x
})
|> Async.Parallel
|> wrapThreadStaticInfo
|> Node
let injectLogger i computation =
let logger = concurrentLogging.GetLoggerForTask($"NodeCode.Parallel {i}")

async {
DiagnosticsThreadStatics.DiagnosticsLogger <- logger
DiagnosticsThreadStatics.BuildPhase <- phase
return! unwrapNode computation
}

return!
computations
|> Seq.mapi injectLogger
|> Async.Parallel
|> wrapThreadStaticInfo
|> Node
}

[<RequireQualifiedAccess>]
module GraphNode =
Expand Down
15 changes: 15 additions & 0 deletions src/Compiler/Facilities/DiagnosticsLogger.fs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ open System.Reflection
open System.Threading
open Internal.Utilities.Library
open Internal.Utilities.Library.Extras
open System.Collections.Concurrent

/// Represents the style being used to format errors
[<RequireQualifiedAccess>]
Expand Down Expand Up @@ -883,3 +884,17 @@ type StackGuard(maxDepth: int, name: string) =

static member GetDepthOption(name: string) =
GetEnvInteger ("FSHARP_" + name + "StackGuardDepth") StackGuard.DefaultDepth

type CaptureDiagnosticsConcurrently() =
let target = DiagnosticsThreadStatics.DiagnosticsLogger
let loggers = ResizeArray()

member _.GetLoggerForTask(name) : DiagnosticsLogger =
let logger = CapturingDiagnosticsLogger(name)
loggers.Add logger
logger

interface IDisposable with
member _.Dispose() =
for logger in loggers do
logger.CommitDelayedDiagnostics target
7 changes: 7 additions & 0 deletions src/Compiler/Facilities/DiagnosticsLogger.fsi
Original file line number Diff line number Diff line change
Expand Up @@ -463,3 +463,10 @@ type CompilationGlobalsScope =
member DiagnosticsLogger: DiagnosticsLogger

member BuildPhase: BuildPhase

type CaptureDiagnosticsConcurrently =
new: unit -> CaptureDiagnosticsConcurrently

member GetLoggerForTask: string -> DiagnosticsLogger

interface IDisposable
22 changes: 18 additions & 4 deletions tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,36 @@ module BuildGraphTests =
Assert.shouldBeTrue graphNode.HasValue
Assert.shouldBe (ValueSome 1) (graphNode.TryPeekValue())


type ExampleException(msg) = inherit System.Exception(msg)

[<Fact>]
let internal ``NodeCode preserves DiagnosticsThreadStatics`` () =
let random =
let rng = Random()
fun n -> rng.Next n

let job phase _ = node {
let job phase i = node {
do! random 10 |> Async.Sleep |> NodeCode.AwaitAsync
Assert.Equal(phase, DiagnosticsThreadStatics.BuildPhase)
DiagnosticsThreadStatics.DiagnosticsLogger.DebugDisplay()
|> Assert.shouldBe $"DiagnosticsLogger(NodeCode.Parallel {i})"

errorR (ExampleException $"job {i}")
}

let work (phase: BuildPhase) =
node {
use _ = new CompilationGlobalsScope(DiscardErrorsLogger, phase)
let! _ = Seq.init 8 (job phase) |> NodeCode.Parallel
let n = 8
let logger = CapturingDiagnosticsLogger("test NodeCode")
use _ = new CompilationGlobalsScope(logger, phase)
let! _ = Seq.init n (job phase) |> NodeCode.Parallel

let diags = logger.Diagnostics |> List.map fst

diags |> List.map _.Phase |> Set |> Assert.shouldBe (Set.singleton phase)
diags |> List.map _.Exception.Message
|> Assert.shouldBe (List.init n <| sprintf "job %d")

Assert.Equal(phase, DiagnosticsThreadStatics.BuildPhase)
}

Expand Down

0 comments on commit 0f520bc

Please sign in to comment.