From 60f1f10cfcd14d181020e1e0813422ca7d5e2a7d Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 Mar 2024 08:36:57 +0000 Subject: [PATCH] WebSocket Router (#41) * WebSocket Router * reorder parameters * Changed comment in testNotWebSocket * Catch websocket upgrade fail and write methodNotAllowed * Cleaner websocket fail * Move maxFrameSize into configuration type * Added WebSocketUpgradeMiddleware --- Package.swift | 3 +- Snippets/WebsocketTest.swift | 30 ++- .../Client/WebSocketClient.swift | 39 +--- .../Client/WebSocketClientChannel.swift | 18 +- .../Client/WebSocketClientConfiguration.swift | 34 +++ .../NIOWebSocketServerUpgrade+ext.swift | 19 +- ...elHandler.swift => WebSocketChannel.swift} | 78 +++++-- .../Server/WebSocketHTTPChannelBuilder.swift | 19 +- .../Server/WebSocketRouter.swift | 210 ++++++++++++++++++ .../Server/WebSocketServerConfiguration.swift | 29 +++ .../WebSocketContext.swift | 6 +- .../WebSocketDataHandler.swift | 4 +- .../WebSocketTests.swift | 145 +++++++++++- 13 files changed, 522 insertions(+), 112 deletions(-) create mode 100644 Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift rename Sources/HummingbirdWebSocket/Server/{WebSocketChannelHandler.swift => WebSocketChannel.swift} (70%) create mode 100644 Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift create mode 100644 Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift diff --git a/Package.swift b/Package.swift index 5b888ae..bda3e6a 100644 --- a/Package.swift +++ b/Package.swift @@ -16,7 +16,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"), - .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.21.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"), .package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"), .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"), @@ -41,6 +41,7 @@ let package = Package( // .byName(name: "HummingbirdWSCompression"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Hummingbird", package: "hummingbird"), + .product(name: "HummingbirdTesting", package: "hummingbird"), .product(name: "HummingbirdTLS", package: "hummingbird"), ]), ] diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index da2b947..fad3fce 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -1,28 +1,24 @@ +import HTTPTypes import Hummingbird import HummingbirdWebSocket -import NIOHTTP1 -let router = Router() +let router = Router(context: BasicWebSocketRequestContext.self) +router.middlewares.add(FileMiddleware("Snippets/public")) router.get { _, _ in "Hello" } -router.middlewares.add(FileMiddleware("Snippets/public")) -let app = Application( - router: router, - server: .webSocketUpgrade { _, head in - if head.uri == "/ws" { - return .upgrade(HTTPHeaders()) { inbound, outbound, _ in - for try await packet in inbound { - if case .text("disconnect") = packet { - break - } - try await outbound.write(.custom(packet.webSocketFrame)) - } - } - } else { - return .dontUpgrade +router.ws("/ws") { inbound, outbound, _ in + for try await packet in inbound { + if case .text("disconnect") = packet { + break } + try await outbound.write(.custom(packet.webSocketFrame)) } +} + +let app = Application( + router: router, + server: .webSocketUpgrade(webSocketRouter: router) ) try await app.runService() diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift index 4e1f626..e5b3720 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClient.swift @@ -40,25 +40,6 @@ import ServiceLifecycle /// } /// ``` public struct WebSocketClient { - public struct Configuration: Sendable { - /// Max websocket frame size that can be sent/received - public var maxFrameSize: Int - /// Additional headers to be sent with the initial HTTP request - public var additionalHeaders: HTTPFields - - /// Initialize WebSocketClient configuration - /// - Paramters - /// - maxFrameSize: Max websocket frame size that can be sent/received - /// - additionalHeaders: Additional headers to be sent with the initial HTTP request - public init( - maxFrameSize: Int = (1 << 14), - additionalHeaders: HTTPFields = .init() - ) { - self.maxFrameSize = maxFrameSize - self.additionalHeaders = additionalHeaders - } - } - enum MultiPlatformTLSConfiguration: Sendable { case niossl(TLSConfiguration) #if canImport(Network) @@ -71,7 +52,7 @@ public struct WebSocketClient { /// WebSocket data handler let handler: WebSocketDataCallbackHandler /// configuration - let configuration: Configuration + let configuration: WebSocketClientConfiguration /// EventLoopGroup to use let eventLoopGroup: EventLoopGroup /// Logger @@ -90,7 +71,7 @@ public struct WebSocketClient { /// - logger: Logger public init( url: URI, - configuration: Configuration = .init(), + configuration: WebSocketClientConfiguration = .init(), tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, @@ -116,7 +97,7 @@ public struct WebSocketClient { /// - logger: Logger public init( url: URI, - configuration: Configuration = .init(), + configuration: WebSocketClientConfiguration = .init(), transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, @@ -143,7 +124,7 @@ public struct WebSocketClient { case .niossl(let tlsConfiguration): let client = try ClientConnection( TLSClientChannel( - WebSocketClientChannel(handler: handler, url: urlPath, maxFrameSize: self.configuration.maxFrameSize), + WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration), tlsConfiguration: tlsConfiguration ), address: .hostname(host, port: port), @@ -155,7 +136,7 @@ public struct WebSocketClient { #if canImport(Network) case .ts(let tlsOptions): let client = try ClientConnection( - WebSocketClientChannel(handler: handler, url: urlPath, maxFrameSize: self.configuration.maxFrameSize), + WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration), address: .hostname(host, port: port), transportServicesTLSOptions: tlsOptions, eventLoopGroup: self.eventLoopGroup, @@ -170,8 +151,7 @@ public struct WebSocketClient { WebSocketClientChannel( handler: handler, url: urlPath, - maxFrameSize: self.configuration.maxFrameSize, - additionalHeaders: self.configuration.additionalHeaders + configuration: self.configuration ), tlsConfiguration: TLSConfiguration.makeClientConfiguration() ), @@ -186,8 +166,7 @@ public struct WebSocketClient { WebSocketClientChannel( handler: handler, url: urlPath, - maxFrameSize: self.configuration.maxFrameSize, - additionalHeaders: self.configuration.additionalHeaders + configuration: self.configuration ), address: .hostname(host, port: port), eventLoopGroup: self.eventLoopGroup, @@ -210,7 +189,7 @@ extension WebSocketClient { /// - process: Closure handling webSocket public static func connect( url: URI, - configuration: Configuration = .init(), + configuration: WebSocketClientConfiguration = .init(), tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger, @@ -239,7 +218,7 @@ extension WebSocketClient { /// - process: WebSocket data handler public static func connect( url: URI, - configuration: Configuration = .init(), + configuration: WebSocketClientConfiguration = .init(), transportServicesTLSOptions: TSTLSOptions, eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, logger: Logger, diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index 12b5942..dfed093 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -30,20 +30,18 @@ public struct WebSocketClientChannel: ClientConne let url: String let handler: Handler - let maxFrameSize: Int - let additionalHeaders: HTTPFields + let configuration: WebSocketClientConfiguration - init(handler: Handler, url: String, maxFrameSize: Int = 1 << 14, additionalHeaders: HTTPFields = .init()) { + init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) { self.url = url self.handler = handler - self.maxFrameSize = maxFrameSize - self.additionalHeaders = additionalHeaders + self.configuration = configuration } public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture { channel.eventLoop.makeCompletedFuture { let upgrader = NIOTypedWebSocketClientUpgrader( - maxFrameSize: maxFrameSize, + maxFrameSize: self.configuration.maxFrameSize, upgradePipelineHandler: { channel, _ in channel.eventLoop.makeCompletedFuture { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) @@ -55,7 +53,7 @@ public struct WebSocketClientChannel: ClientConne var headers = HTTPHeaders() headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") headers.add(name: "Content-Length", value: "0") - let additionalHeaders = HTTPHeaders(self.additionalHeaders) + let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) let requestHead = HTTPRequestHead( @@ -85,9 +83,9 @@ public struct WebSocketClientChannel: ClientConne public func handle(value: Value, logger: Logger) async throws { switch try await value.get() { - case .websocket(let websocketChannel): - let webSocket = WebSocketHandler(asyncChannel: websocketChannel, type: .client) - let context = self.handler.alreadySetupContext ?? .init(logger: logger, allocator: websocketChannel.channel.allocator) + case .websocket(let webSocketChannel): + let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client) + let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger) await webSocket.handle(handler: self.handler, context: context) case .notUpgraded: // The upgrade to websocket did not succeed. diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift new file mode 100644 index 0000000..d4085fe --- /dev/null +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes + +public struct WebSocketClientConfiguration: Sendable { + /// Max websocket frame size that can be sent/received + public var maxFrameSize: Int + /// Additional headers to be sent with the initial HTTP request + public var additionalHeaders: HTTPFields + + /// Initialize WebSocketClient configuration + /// - Paramters + /// - maxFrameSize: Max websocket frame size that can be sent/received + /// - additionalHeaders: Additional headers to be sent with the initial HTTP request + public init( + maxFrameSize: Int = (1 << 14), + additionalHeaders: HTTPFields = .init() + ) { + self.maxFrameSize = maxFrameSize + self.additionalHeaders = additionalHeaders + } +} diff --git a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift index 2110e28..9994496 100644 --- a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift +++ b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift @@ -12,14 +12,17 @@ // //===----------------------------------------------------------------------===// +import HTTPTypes import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 +import NIOHTTPTypesHTTP1 import NIOWebSocket +/// Should HTTP channel upgrade to WebSocket public enum ShouldUpgradeResult: Sendable { case dontUpgrade - case upgrade(HTTPHeaders, Value) + case upgrade(HTTPFields, Value) } extension NIOTypedWebSocketServerUpgrader { @@ -47,21 +50,27 @@ extension NIOTypedWebSocketServerUpgrader { public convenience init( maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, - shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture>, + shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture>, upgradePipelineHandler: @escaping @Sendable (Channel, Value) -> EventLoopFuture ) { let shouldUpgradeResult = NIOLockedValueBox(nil) self.init( maxFrameSize: maxFrameSize, enableAutomaticErrorHandling: enableAutomaticErrorHandling, - shouldUpgrade: { channel, head in - shouldUpgrade(channel, head).map { result in + shouldUpgrade: { (channel, head: HTTPRequestHead) in + let request: HTTPRequest + do { + request = try HTTPRequest(head, secure: false, splitCookie: false) + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + return shouldUpgrade(channel, request).map { result in switch result { case .dontUpgrade: return nil case .upgrade(let headers, let value): shouldUpgradeResult.withLockedValue { $0 = value } - return headers + return .init(headers) } } }, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannelHandler.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift similarity index 70% rename from Sources/HummingbirdWebSocket/Server/WebSocketChannelHandler.swift rename to Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index c2ed5c4..0f55d9b 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannelHandler.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -27,12 +27,12 @@ public struct HTTP1AndWebSocketChannel: ServerChi /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { case websocket(NIOAsyncChannel, Handler) - case notUpgraded(NIOAsyncChannel) + case notUpgraded(NIOAsyncChannel, failed: Bool) } public typealias Value = EventLoopFuture - /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function + /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add /// - responder: HTTP responder @@ -40,22 +40,22 @@ public struct HTTP1AndWebSocketChannel: ServerChi /// - shouldUpgrade: Function returning whether upgrade should be allowed /// - Returns: Upgrade result future public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - responder: @escaping @Sendable (Request, Channel) async throws -> Response = { _, _ in throw HTTPError(.notImplemented) }, - maxFrameSize: Int = (1 << 14), - shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult ) { self.additionalChannelHandlers = additionalChannelHandlers - self.maxFrameSize = maxFrameSize - self.shouldUpgrade = { channel, head in + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in channel.eventLoop.makeCompletedFuture { - try shouldUpgrade(channel, head) + try shouldUpgrade(head, channel, logger) } } self.responder = responder } - /// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function + /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add /// - responder: HTTP responder @@ -63,17 +63,17 @@ public struct HTTP1AndWebSocketChannel: ServerChi /// - shouldUpgrade: Function returning whether upgrade should be allowed /// - Returns: Upgrade result future public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + configuration: WebSocketServerConfiguration, additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] }, - responder: @escaping @Sendable (Request, Channel) async throws -> Response = { _, _ in throw HTTPError(.notImplemented) }, - maxFrameSize: Int = (1 << 14), - shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) async throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult ) { self.additionalChannelHandlers = additionalChannelHandlers - self.maxFrameSize = maxFrameSize - self.shouldUpgrade = { channel, head in + self.configuration = configuration + self.shouldUpgrade = { head, channel, logger in let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) promise.completeWithTask { - try await shouldUpgrade(channel, head) + try await shouldUpgrade(head, channel, logger) } return promise.futureResult } @@ -88,10 +88,12 @@ public struct HTTP1AndWebSocketChannel: ServerChi /// - Returns: Negotiated result future public func setup(channel: Channel, logger: Logger) -> EventLoopFuture { return channel.eventLoop.makeCompletedFuture { + let upgradeAttempted = NIOLoopBoundBox(false, eventLoop: channel.eventLoop) let upgrader = NIOTypedWebSocketServerUpgrader( - maxFrameSize: self.maxFrameSize, + maxFrameSize: self.configuration.maxFrameSize, shouldUpgrade: { channel, head in - self.shouldUpgrade(channel, head) + upgradeAttempted.value = true + return self.shouldUpgrade(head, channel, logger) }, upgradePipelineHandler: { channel, handler in channel.eventLoop.makeCompletedFuture { @@ -111,7 +113,7 @@ public struct HTTP1AndWebSocketChannel: ServerChi return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandlers(childChannelHandlers) let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.notUpgraded(asyncChannel) + return UpgradeResult.notUpgraded(asyncChannel, failed: upgradeAttempted.value) } } ) @@ -132,11 +134,15 @@ public struct HTTP1AndWebSocketChannel: ServerChi do { let result = try await upgradeResult.get() switch result { - case .notUpgraded(let http1): - await handleHTTP(asyncChannel: http1, logger: logger) + case .notUpgraded(let http1, let failed): + if failed { + await self.write405(asyncChannel: http1, logger: logger) + } else { + await self.handleHTTP(asyncChannel: http1, logger: logger) + } case .websocket(let asyncChannel, let handler): let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) - let context = handler.alreadySetupContext ?? .init(logger: logger, allocator: asyncChannel.channel.allocator) + let context = handler.alreadySetupContext ?? .init(channel: asyncChannel.channel, logger: logger) await webSocket.handle(handler: handler, context: context) } } catch { @@ -144,8 +150,34 @@ public struct HTTP1AndWebSocketChannel: ServerChi } } + /// Upgrade failed we should write a 405 + private func write405(asyncChannel: NIOAsyncChannel, logger: Logger) async { + do { + try await asyncChannel.executeThenClose { _, outbound in + let headers: HTTPFields = [ + .connection: "close", + .contentLength: "0", + ] + let head = HTTPResponse( + status: .methodNotAllowed, + headerFields: headers + ) + + try await outbound.write( + contentsOf: [ + .head(head), + .end(nil), + ] + ) + } + } catch { + // we got here because we failed to either read or write to the channel + logger.trace("Failed to write to Channel. Error: \(error)") + } + } + public var responder: @Sendable (Request, Channel) async throws -> Response - let shouldUpgrade: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture> - let maxFrameSize: Int + let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> + let configuration: WebSocketServerConfiguration let additionalChannelHandlers: @Sendable () -> [any RemovableChannelHandler] } diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index 7173d05..c02d4f2 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -12,23 +12,24 @@ // //===----------------------------------------------------------------------===// +import HTTPTypes import HummingbirdCore +import Logging import NIOCore -import NIOHTTP1 extension HTTPChannelBuilder { /// HTTP1 channel builder supporting a websocket upgrade /// - parameters public static func webSocketUpgrade( + configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - maxFrameSize: Int = 1 << 14, - shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) async throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult ) -> HTTPChannelBuilder> { return .init { responder in return HTTP1AndWebSocketChannel( - additionalChannelHandlers: additionalChannelHandlers, responder: responder, - maxFrameSize: maxFrameSize, + configuration: configuration, + additionalChannelHandlers: additionalChannelHandlers, shouldUpgrade: shouldUpgrade ) } @@ -36,15 +37,15 @@ extension HTTPChannelBuilder { /// HTTP1 channel builder supporting a websocket upgrade public static func webSocketUpgrade( + configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], - maxFrameSize: Int = 1 << 14, - shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) throws -> ShouldUpgradeResult + shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult ) -> HTTPChannelBuilder> { return .init { responder in return HTTP1AndWebSocketChannel( - additionalChannelHandlers: additionalChannelHandlers, responder: responder, - maxFrameSize: maxFrameSize, + configuration: configuration, + additionalChannelHandlers: additionalChannelHandlers, shouldUpgrade: shouldUpgrade ) } diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift new file mode 100644 index 0000000..eee133a --- /dev/null +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -0,0 +1,210 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Atomics +import HTTPTypes +import Hummingbird +import HummingbirdCore +import Logging +import NIOConcurrencyHelpers +import NIOCore + +/// WebSocket Router context type. +/// +/// Includes reference to optional websocket handler +public struct WebSocketRouterContext: Sendable { + public init() { + self.handler = .init(nil) + } + + let handler: NIOLockedValueBox +} + +/// Request context protocol requirement for routers that support websockets +public protocol WebSocketRequestContext: RequestContext, WebSocketContextProtocol { + var webSocket: WebSocketRouterContext { get } +} + +/// Default implementation of a request context that supports WebSockets +public struct BasicWebSocketRequestContext: WebSocketRequestContext { + public var coreContext: CoreRequestContext + public let webSocket: WebSocketRouterContext + + public init(channel: Channel, logger: Logger) { + self.coreContext = .init(allocator: channel.allocator, logger: logger) + self.webSocket = .init() + } +} + +/// Enum indicating whether a router `shouldUpgrade` function expects a +/// WebSocket upgrade or not +public enum RouterShouldUpgrade: Sendable { + case dontUpgrade + case upgrade(HTTPFields) +} + +extension RouterMethods { + /// Add path to router that support WebSocket upgrade + /// - Parameters: + /// - path: Path to match + /// - shouldUpgrade: Should request be upgraded + /// - handle: WebSocket channel handler + @discardableResult public func ws( + _ path: String = "", + shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, + handle: @escaping WebSocketDataCallbackHandler.Callback + ) -> Self where Context: WebSocketRequestContext { + return on(path, method: .get) { request, context -> Response in + let result = try await shouldUpgrade(request, context) + switch result { + case .dontUpgrade: + return .init(status: .methodNotAllowed) + case .upgrade(let headers): + context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(handle) } + return .init(status: .ok, headers: headers) + } + } + } +} + +/// An alternative way to add a WebSocket upgrade to a router via Middleware +/// +/// This is primarily designed to be used with ``HummingbirdRouter/RouterBuilder`` but can be used +/// with ``Hummingbird/Router`` if you add a route immediately after it. +public struct WebSocketUpgradeMiddleware: RouterMiddleware { + let shouldUpgrade: @Sendable (Request, Context) async throws -> RouterShouldUpgrade + let handle: WebSocketDataCallbackHandler.Callback + + /// Initialize WebSocketUpgradeMiddleare + /// - Parameters: + /// - shouldUpgrade: Return whether the WebSocket upgrade should occur + /// - handle: WebSocket handler + public init( + shouldUpgrade: @Sendable @escaping (Request, Context) async throws -> RouterShouldUpgrade = { _, _ in .upgrade([:]) }, + handle: @escaping WebSocketDataCallbackHandler.Callback + ) { + self.shouldUpgrade = shouldUpgrade + self.handle = handle + } + + /// WebSocketUpgradeMiddleware handler + public func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response { + let result = try await shouldUpgrade(request, context) + switch result { + case .dontUpgrade: + return .init(status: .methodNotAllowed) + case .upgrade(let headers): + context.webSocket.handler.withLockedValue { $0 = WebSocketDataCallbackHandler(self.handle) } + return .init(status: .ok, headers: headers) + } + } +} + +extension HTTP1AndWebSocketChannel { + /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function + /// - Parameters: + /// - additionalChannelHandlers: Additional channel handlers to add + /// - responder: HTTP responder + /// - maxFrameSize: Max frame size WebSocket will allow + /// - webSocketRouter: WebSocket router + /// - Returns: Upgrade result future + public init( + responder: @escaping @Sendable (Request, Channel) async throws -> Response, + webSocketResponder: WSResponder, + configuration: WebSocketServerConfiguration, + additionalChannelHandlers: @escaping @Sendable () -> [any RemovableChannelHandler] = { [] } + ) where Handler == WebSocketDataCallbackHandler, WSResponder.Context == Context { + self.init(responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers) { head, channel, logger in + let request = Request(head: head, body: .init(buffer: .init())) + let context = Context(channel: channel, logger: logger.with(metadataKey: "hb_id", value: .stringConvertible(RequestID()))) + do { + let response = try await webSocketResponder.respond(to: request, context: context) + if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { + return .upgrade(response.headers, webSocketHandler) + } else { + return .dontUpgrade + } + } catch { + return .dontUpgrade + } + } + } +} + +extension HTTPChannelBuilder { + /// HTTP1 channel builder supporting a websocket upgrade + /// + /// With this function you provide a separate router from the one you have supplied + /// to ``Hummingbird/Application``. You can provide the same router as is used for + /// standard HTTP routing, but it is preferable that you supply a separate one to + /// avoid attempting to match against paths which will never produce a WebSocket upgrade. + /// - Parameters: + /// - webSocketRouter: Router used for testing whether a WebSocket upgrade should occur + /// - configuration: WebSocket server configuration + /// - additionalChannelHandlers: Additional channel handlers to add to channel pipeline + /// - Returns: + public static func webSocketUpgrade( + webSocketRouter: WSResponderBuilder, + configuration: WebSocketServerConfiguration = .init(), + additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [] + ) -> HTTPChannelBuilder> where WSResponderBuilder.Responder.Context: WebSocketRequestContext { + let webSocketReponder = webSocketRouter.buildResponder() + return .init { responder in + return HTTP1AndWebSocketChannel( + responder: responder, + webSocketResponder: webSocketReponder, + configuration: configuration, + additionalChannelHandlers: additionalChannelHandlers + ) + } + } +} + +extension Logger { + /// Create new Logger with additional metadata value + /// - Parameters: + /// - metadataKey: Metadata key + /// - value: Metadata value + /// - Returns: Logger + func with(metadataKey: String, value: MetadataValue) -> Logger { + var logger = self + logger[metadataKey: metadataKey] = value + return logger + } +} + +/// Generate Unique ID for each request. This is a duplicate of the RequestID in Hummingbird +package struct RequestID: CustomStringConvertible { + let low: UInt64 + + package init() { + self.low = Self.globalRequestID.loadThenWrappingIncrement(by: 1, ordering: .relaxed) + } + + package var description: String { + Self.high + self.formatAsHexWithLeadingZeros(self.low) + } + + func formatAsHexWithLeadingZeros(_ value: UInt64) -> String { + let string = String(value, radix: 16) + if string.count < 16 { + return String(repeating: "0", count: 16 - string.count) + string + } else { + return string + } + } + + private static let high = String(UInt64.random(in: .min ... .max), radix: 16) + private static let globalRequestID = ManagedAtomic(UInt64.random(in: .min ... .max)) +} diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift new file mode 100644 index 0000000..8d7e471 --- /dev/null +++ b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023-2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// Configuration for a WebSocket server +public struct WebSocketServerConfiguration: Sendable { + /// Max websocket frame size that can be sent/received + public var maxFrameSize: Int + + /// Initialize WebSocketClient configuration + /// - Paramters + /// - maxFrameSize: Max websocket frame size that can be sent/received + /// - additionalHeaders: Additional headers to be sent with the initial HTTP request + public init( + maxFrameSize: Int = (1 << 14) + ) { + self.maxFrameSize = maxFrameSize + } +} diff --git a/Sources/HummingbirdWebSocket/WebSocketContext.swift b/Sources/HummingbirdWebSocket/WebSocketContext.swift index 141e2f3..24952bd 100644 --- a/Sources/HummingbirdWebSocket/WebSocketContext.swift +++ b/Sources/HummingbirdWebSocket/WebSocketContext.swift @@ -19,7 +19,7 @@ import NIOCore public protocol WebSocketContextProtocol: Sendable { var logger: Logger { get } var allocator: ByteBufferAllocator { get } - init(logger: Logger, allocator: ByteBufferAllocator) + init(channel: Channel, logger: Logger) } /// Default implementation of ``WebSocketContextProtocol`` @@ -27,8 +27,8 @@ public struct WebSocketContext: WebSocketContextProtocol { public let logger: Logger public let allocator: ByteBufferAllocator - public init(logger: Logger, allocator: ByteBufferAllocator) { + public init(channel: Channel, logger: Logger) { self.logger = logger - self.allocator = allocator + self.allocator = channel.allocator } } diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 4205374..9c3321a 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// import AsyncAlgorithms +import HTTPTypes import NIOCore -import NIOHTTP1 import NIOWebSocket /// Protocol for web socket data handling @@ -63,7 +63,7 @@ public struct WebSocketDataCallbackHandler: WebSocketDataHandler { extension ShouldUpgradeResult where Value == WebSocketDataCallbackHandler { /// Extension to ShouldUpgradeResult that takes just a callback - public static func upgrade(_ headers: HTTPHeaders, _ callback: @escaping WebSocketDataCallbackHandler.Callback) -> Self { + public static func upgrade(_ headers: HTTPFields, _ callback: @escaping WebSocketDataCallbackHandler.Callback) -> Self { .upgrade(headers, WebSocketDataCallbackHandler(callback)) } } diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 88caadf..e286221 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -15,11 +15,11 @@ import HTTPTypes import Hummingbird import HummingbirdCore +import HummingbirdTesting import HummingbirdTLS import HummingbirdWebSocket import Logging import NIOCore -import NIOHTTP1 import NIOPosix import ServiceLifecycle import XCTest @@ -77,7 +77,7 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, server serverHandler: @escaping WebSocketDataCallbackHandler.Callback, - shouldUpgrade: @escaping @Sendable (HTTPRequestHead) throws -> HTTPHeaders? = { _ in return [:] }, + shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient ) async throws { try await withThrowingTaskGroup(of: Void.self) { group in @@ -89,7 +89,7 @@ final class HummingbirdWebSocketTests: XCTestCase { }() let router = Router() let serviceGroup: ServiceGroup - let webSocketUpgrade: HTTPChannelBuilder = .webSocketUpgrade { _, head in + let webSocketUpgrade: HTTPChannelBuilder = .webSocketUpgrade { head, _, _ in if let headers = try shouldUpgrade(head) { return .upgrade(headers, WebSocketDataCallbackHandler(serverHandler)) } else { @@ -145,7 +145,7 @@ final class HummingbirdWebSocketTests: XCTestCase { func testClientAndServer( serverTLSConfiguration: TLSConfiguration? = nil, server serverHandler: @escaping WebSocketDataCallbackHandler.Callback, - shouldUpgrade: @escaping @Sendable (HTTPRequestHead) throws -> HTTPHeaders? = { _ in return [:] }, + shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] }, client clientHandler: @escaping WebSocketDataCallbackHandler.Callback ) async throws { try await self.testClientAndServer( @@ -162,6 +162,52 @@ final class HummingbirdWebSocketTests: XCTestCase { ) } + func testClientAndServerWithRouter( + webSocketRouter: Router, + uri: URI, + getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient + ) async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + let promise = Promise() + let logger = { + var logger = Logger(label: "WebSocketTest") + logger.logLevel = .debug + return logger + }() + let router = Router() + let serviceGroup: ServiceGroup + let app = Application( + router: router, + server: .webSocketUpgrade(webSocketRouter: webSocketRouter), + onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, + logger: logger + ) + serviceGroup = ServiceGroup( + configuration: .init( + services: [app], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: app.logger + ) + ) + group.addTask { + try await serviceGroup.run() + } + group.addTask { + let client = try await getClient(promise.wait(), logger) + try await client.run() + } + do { + try await group.next() + await serviceGroup.triggerGracefulShutdown() + } catch { + await serviceGroup.triggerGracefulShutdown() + throw error + } + } + } + + // MARK: Tests + func testServerToClientMessage() async throws { try await self.testClientAndServer { _, outbound, _ in try await outbound.write(.text("Hello")) @@ -211,8 +257,6 @@ final class HummingbirdWebSocketTests: XCTestCase { } func testNotWebSocket() async throws { - // currently disabled as NIO websocket code doesnt shutdown correctly here - try XCTSkipIf(true) do { try await self.testClientAndServer { inbound, _, _ in for try await _ in inbound {} @@ -258,7 +302,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await self.testClientAndServer { inbound, _, _ in for try await _ in inbound {} } shouldUpgrade: { head in - XCTAssertEqual(head.uri, "/ws") + XCTAssertEqual(head.path, "/ws") return [:] } getClient: { port, logger in try WebSocketClient( @@ -273,8 +317,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await self.testClientAndServer { inbound, _, _ in for try await _ in inbound {} } shouldUpgrade: { head in - let httpRequest = try HTTPRequest(head, secure: false, splitCookie: false) - let request = Request(head: httpRequest, body: .init(buffer: ByteBuffer())) + let request = Request(head: head, body: .init(buffer: ByteBuffer())) XCTAssertEqual(request.uri.query, "query=parameters&test=true") return [:] } getClient: { port, logger in @@ -290,8 +333,7 @@ final class HummingbirdWebSocketTests: XCTestCase { try await self.testClientAndServer { inbound, _, _ in for try await _ in inbound {} } shouldUpgrade: { head in - let httpRequest = try HTTPRequest(head, secure: false, splitCookie: false) - let request = Request(head: httpRequest, body: .init(buffer: ByteBuffer())) + let request = Request(head: head, body: .init(buffer: ByteBuffer())) XCTAssertEqual(request.headers[.secWebSocketExtensions], "hb") return [:] } getClient: { port, logger in @@ -317,7 +359,7 @@ final class HummingbirdWebSocketTests: XCTestCase { let serviceGroup: ServiceGroup let app = Application( router: router, - server: .webSocketUpgrade { _, _ in + server: .webSocketUpgrade { _, _, _ in return .upgrade([:]) { _, outbound, _ in try await outbound.write(.text("Hello")) } @@ -347,6 +389,85 @@ final class HummingbirdWebSocketTests: XCTestCase { } } + func testRouteSelection() async throws { + let router = Router(context: BasicWebSocketRequestContext.self) + router.ws("/ws1") { _, _ in + return .upgrade([:]) + } handle: { _, outbound, _ in + try await outbound.write(.text("One")) + } + router.ws("/ws2") { _, _ in + return .upgrade([:]) + } handle: { _, outbound, _ in + try await outbound.write(.text("Two")) + } + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/ws1"), logger: logger) { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() + let msg = await inboundIterator.next() + XCTAssertEqual(msg, .text("One")) + } + } + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/ws2"), logger: logger) { inbound, _, _ in + var inboundIterator = inbound.makeAsyncIterator() + let msg = await inboundIterator.next() + XCTAssertEqual(msg, .text("Two")) + } + } + } + + func testWebSocketMiddleware() async throws { + let router = Router(context: BasicWebSocketRequestContext.self) + router.group("/ws") + .add(middleware: WebSocketUpgradeMiddleware { _, _ in + return .upgrade([:]) + } handle: { _, outbound, _ in + try await outbound.write(.text("One")) + }) + .get { _, _ -> Response in return .init(status: .ok) } + do { + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { _, _, _ in } + } + } + } + + func testRouteSelectionFail() async throws { + let router = Router(context: BasicWebSocketRequestContext.self) + router.ws("/ws") { _, _ in + return .upgrade([:]) + } handle: { _, outbound, _ in + try await outbound.write(.text("One")) + } + do { + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/not-ws"), logger: logger) { _, _, _ in } + } + } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} + } + + func testHTTPRequest() async throws { + let router = Router(context: BasicWebSocketRequestContext.self) + router.ws("/ws") { _, _ in + return .upgrade([:]) + } handle: { _, outbound, _ in + try await outbound.write(.text("Hello")) + } + router.get("/http") { _, _ in + return "Hello" + } + let application = Application( + router: router, + server: .webSocketUpgrade(webSocketRouter: router) + ) + try await application.test(.live) { client in + try await client.execute(uri: "/http", method: .get) { response in + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(String(buffer: response.body), "Hello") + } + } + } /* func testPingPong() throws { let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10))