Skip to content

Commit

Permalink
WebSocket Router (#41)
Browse files Browse the repository at this point in the history
* WebSocket Router

* reorder parameters

* Changed comment in testNotWebSocket

* Catch websocket upgrade fail and write methodNotAllowed

* Cleaner websocket fail

* Move maxFrameSize into configuration type

* Added WebSocketUpgradeMiddleware
  • Loading branch information
adam-fowler authored Mar 20, 2024
1 parent de55102 commit 60f1f10
Show file tree
Hide file tree
Showing 13 changed files with 522 additions and 112 deletions.
3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.21.0"),
.package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"),
.package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"),
.package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"),
Expand All @@ -41,6 +41,7 @@ let package = Package(
// .byName(name: "HummingbirdWSCompression"),
.product(name: "Atomics", package: "swift-atomics"),
.product(name: "Hummingbird", package: "hummingbird"),
.product(name: "HummingbirdTesting", package: "hummingbird"),
.product(name: "HummingbirdTLS", package: "hummingbird"),
]),
]
Expand Down
30 changes: 13 additions & 17 deletions Snippets/WebsocketTest.swift
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
import HTTPTypes
import Hummingbird
import HummingbirdWebSocket
import NIOHTTP1

let router = Router()
let router = Router(context: BasicWebSocketRequestContext.self)
router.middlewares.add(FileMiddleware("Snippets/public"))
router.get { _, _ in
"Hello"
}

router.middlewares.add(FileMiddleware("Snippets/public"))
let app = Application(
router: router,
server: .webSocketUpgrade { _, head in
if head.uri == "/ws" {
return .upgrade(HTTPHeaders()) { inbound, outbound, _ in
for try await packet in inbound {
if case .text("disconnect") = packet {
break
}
try await outbound.write(.custom(packet.webSocketFrame))
}
}
} else {
return .dontUpgrade
router.ws("/ws") { inbound, outbound, _ in
for try await packet in inbound {
if case .text("disconnect") = packet {
break
}
try await outbound.write(.custom(packet.webSocketFrame))
}
}

let app = Application(
router: router,
server: .webSocketUpgrade(webSocketRouter: router)
)
try await app.runService()
39 changes: 9 additions & 30 deletions Sources/HummingbirdWebSocket/Client/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,6 @@ import ServiceLifecycle
/// }
/// ```
public struct WebSocketClient {
public struct Configuration: Sendable {
/// Max websocket frame size that can be sent/received
public var maxFrameSize: Int
/// Additional headers to be sent with the initial HTTP request
public var additionalHeaders: HTTPFields

/// Initialize WebSocketClient configuration
/// - Paramters
/// - maxFrameSize: Max websocket frame size that can be sent/received
/// - additionalHeaders: Additional headers to be sent with the initial HTTP request
public init(
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init()
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
}
}

enum MultiPlatformTLSConfiguration: Sendable {
case niossl(TLSConfiguration)
#if canImport(Network)
Expand All @@ -71,7 +52,7 @@ public struct WebSocketClient {
/// WebSocket data handler
let handler: WebSocketDataCallbackHandler
/// configuration
let configuration: Configuration
let configuration: WebSocketClientConfiguration
/// EventLoopGroup to use
let eventLoopGroup: EventLoopGroup
/// Logger
Expand All @@ -90,7 +71,7 @@ public struct WebSocketClient {
/// - logger: Logger
public init(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
Expand All @@ -116,7 +97,7 @@ public struct WebSocketClient {
/// - logger: Logger
public init(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
Expand All @@ -143,7 +124,7 @@ public struct WebSocketClient {
case .niossl(let tlsConfiguration):
let client = try ClientConnection(
TLSClientChannel(
WebSocketClientChannel(handler: handler, url: urlPath, maxFrameSize: self.configuration.maxFrameSize),
WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration),
tlsConfiguration: tlsConfiguration
),
address: .hostname(host, port: port),
Expand All @@ -155,7 +136,7 @@ public struct WebSocketClient {
#if canImport(Network)
case .ts(let tlsOptions):
let client = try ClientConnection(
WebSocketClientChannel(handler: handler, url: urlPath, maxFrameSize: self.configuration.maxFrameSize),
WebSocketClientChannel(handler: handler, url: urlPath, configuration: self.configuration),
address: .hostname(host, port: port),
transportServicesTLSOptions: tlsOptions,
eventLoopGroup: self.eventLoopGroup,
Expand All @@ -170,8 +151,7 @@ public struct WebSocketClient {
WebSocketClientChannel(
handler: handler,
url: urlPath,
maxFrameSize: self.configuration.maxFrameSize,
additionalHeaders: self.configuration.additionalHeaders
configuration: self.configuration
),
tlsConfiguration: TLSConfiguration.makeClientConfiguration()
),
Expand All @@ -186,8 +166,7 @@ public struct WebSocketClient {
WebSocketClientChannel(
handler: handler,
url: urlPath,
maxFrameSize: self.configuration.maxFrameSize,
additionalHeaders: self.configuration.additionalHeaders
configuration: self.configuration
),
address: .hostname(host, port: port),
eventLoopGroup: self.eventLoopGroup,
Expand All @@ -210,7 +189,7 @@ extension WebSocketClient {
/// - process: Closure handling webSocket
public static func connect(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
tlsConfiguration: TLSConfiguration? = nil,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
logger: Logger,
Expand Down Expand Up @@ -239,7 +218,7 @@ extension WebSocketClient {
/// - process: WebSocket data handler
public static func connect(
url: URI,
configuration: Configuration = .init(),
configuration: WebSocketClientConfiguration = .init(),
transportServicesTLSOptions: TSTLSOptions,
eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton,
logger: Logger,
Expand Down
18 changes: 8 additions & 10 deletions Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,18 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne

let url: String
let handler: Handler
let maxFrameSize: Int
let additionalHeaders: HTTPFields
let configuration: WebSocketClientConfiguration

init(handler: Handler, url: String, maxFrameSize: Int = 1 << 14, additionalHeaders: HTTPFields = .init()) {
init(handler: Handler, url: String, configuration: WebSocketClientConfiguration) {
self.url = url
self.handler = handler
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
self.configuration = configuration
}

public func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture<Value> {
channel.eventLoop.makeCompletedFuture {
let upgrader = NIOTypedWebSocketClientUpgrader<UpgradeResult>(
maxFrameSize: maxFrameSize,
maxFrameSize: self.configuration.maxFrameSize,
upgradePipelineHandler: { channel, _ in
channel.eventLoop.makeCompletedFuture {
let asyncChannel = try NIOAsyncChannel<WebSocketFrame, WebSocketFrame>(wrappingChannelSynchronously: channel)
Expand All @@ -55,7 +53,7 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne
var headers = HTTPHeaders()
headers.add(name: "Content-Type", value: "text/plain; charset=utf-8")
headers.add(name: "Content-Length", value: "0")
let additionalHeaders = HTTPHeaders(self.additionalHeaders)
let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders)
headers.add(contentsOf: additionalHeaders)

let requestHead = HTTPRequestHead(
Expand Down Expand Up @@ -85,9 +83,9 @@ public struct WebSocketClientChannel<Handler: WebSocketDataHandler>: ClientConne

public func handle(value: Value, logger: Logger) async throws {
switch try await value.get() {
case .websocket(let websocketChannel):
let webSocket = WebSocketHandler(asyncChannel: websocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(logger: logger, allocator: websocketChannel.channel.allocator)
case .websocket(let webSocketChannel):
let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client)
let context = self.handler.alreadySetupContext ?? .init(channel: webSocketChannel.channel, logger: logger)
await webSocket.handle(handler: self.handler, context: context)
case .notUpgraded:
// The upgrade to websocket did not succeed.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes

public struct WebSocketClientConfiguration: Sendable {
/// Max websocket frame size that can be sent/received
public var maxFrameSize: Int
/// Additional headers to be sent with the initial HTTP request
public var additionalHeaders: HTTPFields

/// Initialize WebSocketClient configuration
/// - Paramters
/// - maxFrameSize: Max websocket frame size that can be sent/received
/// - additionalHeaders: Additional headers to be sent with the initial HTTP request
public init(
maxFrameSize: Int = (1 << 14),
additionalHeaders: HTTPFields = .init()
) {
self.maxFrameSize = maxFrameSize
self.additionalHeaders = additionalHeaders
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import NIOConcurrencyHelpers
import NIOCore
import NIOHTTP1
import NIOHTTPTypesHTTP1
import NIOWebSocket

/// Should HTTP channel upgrade to WebSocket
public enum ShouldUpgradeResult<Value: Sendable>: Sendable {
case dontUpgrade
case upgrade(HTTPHeaders, Value)
case upgrade(HTTPFields, Value)
}

extension NIOTypedWebSocketServerUpgrader {
Expand Down Expand Up @@ -47,21 +50,27 @@ extension NIOTypedWebSocketServerUpgrader {
public convenience init<Value>(
maxFrameSize: Int = 1 << 14,
enableAutomaticErrorHandling: Bool = true,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
shouldUpgrade: @escaping @Sendable (Channel, HTTPRequest) -> EventLoopFuture<ShouldUpgradeResult<Value>>,
upgradePipelineHandler: @escaping @Sendable (Channel, Value) -> EventLoopFuture<UpgradeResult>
) {
let shouldUpgradeResult = NIOLockedValueBox<Value?>(nil)
self.init(
maxFrameSize: maxFrameSize,
enableAutomaticErrorHandling: enableAutomaticErrorHandling,
shouldUpgrade: { channel, head in
shouldUpgrade(channel, head).map { result in
shouldUpgrade: { (channel, head: HTTPRequestHead) in
let request: HTTPRequest
do {
request = try HTTPRequest(head, secure: false, splitCookie: false)
} catch {
return channel.eventLoop.makeFailedFuture(error)
}
return shouldUpgrade(channel, request).map { result in
switch result {
case .dontUpgrade:
return nil
case .upgrade(let headers, let value):
shouldUpgradeResult.withLockedValue { $0 = value }
return headers
return .init(headers)
}
}
},
Expand Down
Loading

0 comments on commit 60f1f10

Please sign in to comment.