From f81b8070c09891f1914e063248a892a5370f2dfe Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Mon, 23 Oct 2023 11:17:39 +0100 Subject: [PATCH] Add async version of NIOThreadPool.runIfActive --- Sources/NIOPosix/NIOThreadPool.swift | 42 +++++++++++++++++++++ Tests/NIOPosixTests/NIOThreadPoolTest.swift | 16 ++++++++ 2 files changed, 58 insertions(+) diff --git a/Sources/NIOPosix/NIOThreadPool.swift b/Sources/NIOPosix/NIOThreadPool.swift index e6b631d1aa..dd41f02c28 100644 --- a/Sources/NIOPosix/NIOThreadPool.swift +++ b/Sources/NIOPosix/NIOThreadPool.swift @@ -292,6 +292,48 @@ extension NIOThreadPool { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOThreadPool { + #if swift(>=5.7) + /// Runs the submitted closure if the thread pool is still active, otherwise throw an error. + /// The closure will be run on the thread pool so can do blocking work. + /// + /// - parameters: + /// - body: The closure which performs some blocking work to be done on the thread pool. + /// - returns: result of the passed closure. + @preconcurrency + public func runIfActive(_ body: @escaping @Sendable () throws -> T) async throws -> T { + try await self._runIfActive(body) + } + #else + /// Runs the submitted closure if the thread pool is still active, otherwise throw an error. + /// The closure will be run on the thread pool so can do blocking work. + /// + /// - parameters: + /// - body: The closure which performs some blocking work to be done on the thread pool. + /// - returns: result of the passed closure. + public func runIfActive(_ body: @escaping () throws -> T) async throws -> T { + try await self._runIfActive(body) + } + #endif + + private func _runIfActive(_ body: @escaping () throws -> T) async throws -> T { + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + self.submit { shouldRun in + guard case shouldRun = NIOThreadPool.WorkItemState.active else { + cont.resume(throwing: NIOThreadPoolError.ThreadPoolInactive()) + return + } + do { + try cont.resume(returning: body()) + } catch { + cont.resume(throwing: error) + } + } + } + } +} + extension NIOThreadPool { @preconcurrency public func shutdownGracefully(_ callback: @escaping @Sendable (Error?) -> Void) { diff --git a/Tests/NIOPosixTests/NIOThreadPoolTest.swift b/Tests/NIOPosixTests/NIOThreadPoolTest.swift index a36e4794ac..121747cd56 100644 --- a/Tests/NIOPosixTests/NIOThreadPoolTest.swift +++ b/Tests/NIOPosixTests/NIOThreadPoolTest.swift @@ -14,6 +14,7 @@ import XCTest @testable import NIOPosix +import Atomics import Dispatch import NIOConcurrencyHelpers import NIOEmbedded @@ -110,6 +111,21 @@ class NIOThreadPoolTest: XCTestCase { } } + func testAsyncThreadPool() async throws { + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + pool.start() + do { + let hitCount = ManagedAtomic(false) + try await pool.runIfActive { + hitCount.store(true, ordering: .relaxed) + } + XCTAssertEqual(hitCount.load(ordering: .relaxed), true) + } catch {} + try await pool.shutdownGracefully() + } + func testAsyncShutdownWorks() async throws { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let threadPool = NIOThreadPool(numberOfThreads: 17)