diff --git a/Snippets/AutobahnClientTest.swift b/Snippets/AutobahnClientTest.swift index ed2e026..e1f1314 100644 --- a/Snippets/AutobahnClientTest.swift +++ b/Snippets/AutobahnClientTest.swift @@ -2,20 +2,24 @@ import HummingbirdWSClient import HummingbirdWSCompression import Logging -// Autobahn tests -// 1. Framing -// 2. Pings/Pongs -// 3. Reserved bits -// 4. Opcodes -// 5. Fragmentation -// 6. UTF8 handling -// 7. Close handling -// 9. Limits/performance -// 10. Misc -// 12. WebSocket compression (different payloads) -// 13. WebSocket compression (different parameters) +// Autobahn tests (https://github.com/crossbario/autobahn-testsuite) +// run +// ``` +// .scripts/autobahn.sh +// ``` +// 1. Framing (passed) +// 2. Pings/Pongs (passed) +// 3. Reserved bits (reserved bit checking not supported) +// 4. Opcodes (passed) +// 5. Fragmentation (passed. 5.1/5.2 non-strict) +// 6. UTF8 handling (utf8 validation not supported) +// 7. Close handling (passed, except 7.5.1) +// 9. Limits/performance (passed) +// 10. Misc (passed) +// 12. WebSocket compression (different payloads) (passed) +// 13. WebSocket compression (different parameters) (passed) -let cases = 1...102 +let cases = 1...1000 var logger = Logger(label: "TestClient") logger.logLevel = .trace @@ -24,7 +28,7 @@ do { logger.info("Case \(c)") try await WebSocketClient.connect( url: .init("ws://127.0.0.1:9001/runCase?case=\(c)&agent=HB"), - configuration: .init(maxFrameSize: 1 << 16, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 65536)]), + configuration: .init(maxFrameSize: 16_777_216, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)]), logger: logger ) { inbound, outbound, _ in for try await msg in inbound.messages(maxSize: .max) { @@ -37,9 +41,10 @@ do { } } } - try await WebSocketClient.connect(url: .init("ws://127.0.0.1:9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in - for try await _ in inbound {} - } } catch { logger.error("Error: \(error)") } + +try await WebSocketClient.connect(url: .init("ws://127.0.0.1:9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in + for try await _ in inbound {} +} diff --git a/Snippets/AutobahnServerTest.swift b/Snippets/AutobahnServerTest.swift new file mode 100644 index 0000000..3d7fb2a --- /dev/null +++ b/Snippets/AutobahnServerTest.swift @@ -0,0 +1,45 @@ +import Hummingbird +import HummingbirdWebSocket +import HummingbirdWSCompression +import Logging + +// Autobahn tests (https://github.com/crossbario/autobahn-testsuite) +// run +// ``` +// .scripts/autobahn.sh +// ``` +// 1. Framing (passed) +// 2. Pings/Pongs (passed) +// 3. Reserved bits (not supported) (reserved bit checking not supported) +// 4. Opcodes (passed) +// 5. Fragmentation (passed) +// 6. UTF8 handling (not supported) (utf8 validation not supported) +// 7. Close handling (passed, except 7.5.1) +// 9. Limits/performance (passed) +// 10. Misc (passed) +// 12. WebSocket compression (different payloads) (not run) +// 13. WebSocket compression (different parameters) (not run) + +var logger = Logger(label: "TestClient") +logger.logLevel = .trace + +// let router = Router().get("report") {} + +let app = Application( + router: Router(), + server: .http1WebSocketUpgrade(configuration: .init(maxFrameSize: 16_777_216, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)])) { _, _, _ in + return .upgrade([:]) { inbound, outbound, _ in + for try await msg in inbound.messages(maxSize: .max) { + switch msg { + case .binary(let buffer): + try await outbound.write(.binary(buffer)) + case .text(let string): + try await outbound.write(.text(string)) + } + } + } + }, + configuration: .init(address: .hostname("127.0.0.1", port: 9001)), + logger: logger +) +try await app.runService() diff --git a/Sources/HummingbirdWSCore/WebSocketHandler.swift b/Sources/HummingbirdWSCore/WebSocketHandler.swift index 9ca029f..f8659a9 100644 --- a/Sources/HummingbirdWSCore/WebSocketHandler.swift +++ b/Sources/HummingbirdWSCore/WebSocketHandler.swift @@ -74,27 +74,26 @@ package actor WebSocketHandler { } } - static let pingDataSize = 16 + let channel: Channel var outbound: NIOAsyncChannelOutboundWriter let type: WebSocketType let configuration: Configuration let logger: Logger - var pingData: ByteBuffer - var pingTime: ContinuousClock.Instant = .now - var closeState: CloseState + var stateMachine: WebSocketStateMachine private init( + channel: Channel, outbound: NIOAsyncChannelOutboundWriter, type: WebSocketType, configuration: Configuration, context: some WebSocketContext ) { + self.channel = channel self.outbound = outbound self.type = type self.configuration = configuration self.logger = context.logger - self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) - self.closeState = .open + self.stateMachine = .init(autoPingSetup: configuration.autoPing) } package static func handle( @@ -111,27 +110,12 @@ package actor WebSocketHandler { let rt = try await asyncChannel.executeThenClose { inbound, outbound in try await withTaskCancellationHandler { try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in - let webSocketHandler = Self(outbound: outbound, type: type, configuration: configuration, context: context) - if case .enabled(let period) = configuration.autoPing.value { + let webSocketHandler = Self(channel: asyncChannel.channel, outbound: outbound, type: type, configuration: configuration, context: context) + if case .enabled = configuration.autoPing.value { /// Add task sending ping frames every so often and verifying a pong frame was sent back group.addTask { - var waitTime = period - while true { - try await Task.sleep(for: waitTime) - if let timeSinceLastPing = await webSocketHandler.getTimeSinceLastWaitingPing() { - // if time is less than timeout value, set wait time to when it would timeout - // and re-run loop - if timeSinceLastPing < period { - waitTime = period - timeSinceLastPing - continue - } else { - try await asyncChannel.channel.close(mode: .input) - return .init(closeCode: .goingAway, reason: "No response to ping") - } - } - try await webSocketHandler.ping() - waitTime = period - } + try await webSocketHandler.runAutoPingLoop() + return .init(closeCode: .goingAway, reason: "Ping timeout") } } let rt = try await webSocketHandler.handle(inbound: inbound, outbound: outbound, handler: handler, context: context) @@ -177,7 +161,7 @@ package actor WebSocketHandler { } do { try await self.close(code: closeCode) - if case .closing = self.closeState { + if case .closing = self.stateMachine.state { // Close handshake. Wait for responding close or until inbound ends while let frame = try await inboundIterator.next() { if case .connectionClose = frame.opcode { @@ -193,12 +177,34 @@ package actor WebSocketHandler { try? await self.close(code: .normalClosure) } } - return switch self.closeState { + return switch self.stateMachine.state { case .closed(let code): code default: nil } } + func runAutoPingLoop() async throws { + let period = self.stateMachine.pingTimePeriod + try await Task.sleep(for: period) + while true { + switch self.stateMachine.sendPing() { + case .sendPing(let buffer): + try await self.write(frame: .init(fin: true, opcode: .ping, data: buffer)) + + case .wait(let time): + try await Task.sleep(for: time) + + case .closeConnection(let errorCode): + try await self.sendClose(code: errorCode, reason: "Ping timeout") + try await self.channel.close(mode: .input) + return + + case .stop: + return + } + } + } + /// Send WebSocket frame func write(frame: WebSocketFrame) async throws { var frame = frame @@ -227,122 +233,73 @@ package actor WebSocketHandler { } /// Respond to ping - func onPing( - _ frame: WebSocketFrame - ) async throws { - if frame.fin { - try await self.pong(data: frame.unmaskedData) - } else { - try await self.close(code: .protocolError) + func onPing(_ frame: WebSocketFrame) async throws { + guard frame.fin else { + self.channel.close(promise: nil) + return } - } + switch self.stateMachine.receivedPing(frameData: frame.unmaskedData) { + case .pong(let frameData): + try await self.write(frame: .init(fin: true, opcode: .pong, data: frameData)) - /// Respond to pong - func onPong( - _ frame: WebSocketFrame - ) throws { - let frameData = frame.unmaskedData - // ignore pong frames with frame data not the same as the last ping - guard frameData == self.pingData else { return } - // clear ping data - self.pingData.clear() - } + case .protocolError: + try await self.close(code: .protocolError) - /// Send ping - func ping() async throws { - guard case .open = self.closeState else { return } - if self.pingData.readableBytes == 0 { - // creating random payload - let random = (0.. Duration? { - guard self.pingData.readableBytes > 0 else { return nil } - return .now - self.pingTime + /// Respond to pong + func onPong(_ frame: WebSocketFrame) async throws { + guard frame.fin else { + self.channel.close(promise: nil) + return + } + self.stateMachine.receivedPong(frameData: frame.unmaskedData) } /// Send close - func close( - code: WebSocketErrorCode = .normalClosure, - reason: String? = nil - ) async throws { - switch self.closeState { - case .open: - var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0)) - buffer.write(webSocketErrorCode: code) - if let reason { - buffer.writeString(reason) - } - - try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) + func close(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) async throws { + switch self.stateMachine.close() { + case .sendClose: + try await self.sendClose(code: code, reason: reason) // Only server should initiate a connection close. Clients should wait for the // server to close the connection when it receives the WebSocket close packet // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 if self.type == .server { self.outbound.finish() } - self.closeState = .closing - default: + case .doNothing: break } } func receivedClose(_ frame: WebSocketFrame) async throws { - // we received a connection close. - // send a close back if it hasn't already been send and exit - var data = frame.unmaskedData - let dataSize = data.readableBytes - // read close code and close reason - let closeCode = data.readWebSocketErrorCode() - let reason = data.readableBytes > 0 - ? data.readString(length: data.readableBytes) - : nil - - switch self.closeState { - case .open: - self.closeState = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) - let code: WebSocketErrorCode = if dataSize == 0 || closeCode != nil { - // codes 3000 - 3999 are reserved for use by libraries, frameworks, and applications - // so are considered valid - if case .unknown(let code) = closeCode, code < 3000 || code > 3999 { - .protocolError - } else { - .normalClosure - } - } else { - .protocolError - } - - var buffer = ByteBufferAllocator().buffer(capacity: 2) - buffer.write(webSocketErrorCode: code) - - try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) + switch self.stateMachine.receivedClose(frameData: frame.unmaskedData) { + case .sendClose(let errorCode): + try await self.sendClose(code: errorCode, reason: nil) // Only server should initiate a connection close. Clients should wait for the // server to close the connection when it receives the WebSocket close packet // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 if self.type == .server { self.outbound.finish() } - - case .closing: - self.closeState = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) - - default: + case .doNothing: break } } + private func sendClose(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) async throws { + var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0)) + buffer.write(webSocketErrorCode: code) + if let reason { + buffer.writeString(reason) + } + + try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) + } + /// Make mask key to be used in WebSocket frame private func makeMaskKey() -> WebSocketMaskingKey? { guard self.type == .client else { return nil } diff --git a/Sources/HummingbirdWSCore/WebSocketStateMachine.swift b/Sources/HummingbirdWSCore/WebSocketStateMachine.swift new file mode 100644 index 0000000..60307cd --- /dev/null +++ b/Sources/HummingbirdWSCore/WebSocketStateMachine.swift @@ -0,0 +1,180 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOWebSocket + +struct WebSocketStateMachine { + static let pingDataSize = 16 + let pingTimePeriod: Duration + var state: State + + init(autoPingSetup: AutoPingSetup) { + switch autoPingSetup.value { + case .enabled(let timePeriod): + self.pingTimePeriod = timePeriod + case .disabled: + self.pingTimePeriod = .nanoseconds(0) + } + self.state = .open(.init()) + } + + enum CloseResult { + case sendClose + case doNothing + } + + mutating func close() -> CloseResult { + switch self.state { + case .open: + self.state = .closing + return .sendClose + case .closing: + return .doNothing + case .closed: + return .doNothing + } + } + + enum ReceivedCloseResult { + case sendClose(WebSocketErrorCode) + case doNothing + } + + // we received a connection close. + // send a close back if it hasn't already been send and exit + mutating func receivedClose(frameData: ByteBuffer) -> ReceivedCloseResult { + var frameData = frameData + let dataSize = frameData.readableBytes + // read close code and close reason + let closeCode = frameData.readWebSocketErrorCode() + let reason = frameData.readableBytes > 0 + ? frameData.readString(length: frameData.readableBytes) + : nil + + switch self.state { + case .open: + self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) + let code: WebSocketErrorCode = if dataSize == 0 || closeCode != nil { + // codes 3000 - 3999 are reserved for use by libraries, frameworks + // codes 4000 - 4999 are reserved for private use + // both of these are considered valid. + if case .unknown(let code) = closeCode, code < 3000 || code > 4999 { + .protocolError + } else { + .normalClosure + } + } else { + .protocolError + } + return .sendClose(code) + case .closing: + self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) + return .doNothing + case .closed: + return .doNothing + } + } + + enum SendPingResult { + case sendPing(ByteBuffer) + case wait(Duration) + case closeConnection(WebSocketErrorCode) + case stop + } + + mutating func sendPing() -> SendPingResult { + switch self.state { + case .open(var state): + if let lastPingTime = state.lastPingTime { + let timeSinceLastPing = .now - lastPingTime + // if time is less than timeout value, set wait time to when it would timeout + // and re-run loop + if timeSinceLastPing < self.pingTimePeriod { + return .wait(self.pingTimePeriod - timeSinceLastPing) + } else { + return .closeConnection(.goingAway) + } + } + // creating random payload + let random = (0.. ReceivedPingResult { + switch self.state { + case .open: + guard frameData.readableBytes < 126 else { return .protocolError } + return .pong(frameData) + + case .closing: + return .pong(frameData) + + case .closed: + return .doNothing + } + } + + mutating func receivedPong(frameData: ByteBuffer) { + switch self.state { + case .open(var state): + let frameData = frameData + // ignore pong frames with frame data not the same as the last ping + guard frameData == state.pingData else { return } + // clear ping data + state.lastPingTime = nil + self.state = .open(state) + + case .closing: + break + + case .closed: + break + } + } +} + +extension WebSocketStateMachine { + struct OpenState { + var pingData: ByteBuffer + var lastPingTime: ContinuousClock.Instant? + + init() { + self.pingData = ByteBufferAllocator().buffer(capacity: WebSocketStateMachine.pingDataSize) + self.lastPingTime = nil + } + } + + enum State { + case open(OpenState) + case closing + case closed(WebSocketCloseFrame?) + } +} diff --git a/Tests/HummingbirdWebSocketTests/WebSocketStateMachineTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketStateMachineTests.swift new file mode 100644 index 0000000..e7fd708 --- /dev/null +++ b/Tests/HummingbirdWebSocketTests/WebSocketStateMachineTests.swift @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2024 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import HummingbirdWSCore +import NIOCore +import NIOWebSocket +import XCTest + +final class WebSocketStateMachineTests: XCTestCase { + private func closeFrameData(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) -> ByteBuffer { + var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0)) + buffer.write(webSocketErrorCode: code) + if let reason { + buffer.writeString(reason) + } + return buffer + } + + func testClose() { + var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled) + guard case .sendClose = stateMachine.close() else { XCTFail(); return } + guard case .doNothing = stateMachine.close() else { XCTFail(); return } + guard case .doNothing = stateMachine.receivedClose(frameData: self.closeFrameData()) else { XCTFail(); return } + guard case .closed(let frame) = stateMachine.state else { XCTFail(); return } + XCTAssertEqual(frame?.closeCode, .normalClosure) + } + + func testReceivedClose() { + var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled) + guard case .sendClose(let error) = stateMachine.receivedClose(frameData: closeFrameData(code: .goingAway)) else { XCTFail(); return } + XCTAssertEqual(error, .normalClosure) + guard case .closed(let frame) = stateMachine.state else { XCTFail(); return } + XCTAssertEqual(frame?.closeCode, .goingAway) + } + + func testPingLoopNoPong() { + var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .seconds(15))) + guard case .sendPing = stateMachine.sendPing() else { XCTFail(); return } + guard case .wait = stateMachine.sendPing() else { XCTFail(); return } + } + + func testPingLoop() { + var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .seconds(15))) + guard case .sendPing(let buffer) = stateMachine.sendPing() else { XCTFail(); return } + guard case .wait = stateMachine.sendPing() else { XCTFail(); return } + stateMachine.receivedPong(frameData: buffer) + guard case .open(let openState) = stateMachine.state else { XCTFail(); return } + XCTAssertEqual(openState.lastPingTime, nil) + guard case .sendPing = stateMachine.sendPing() else { XCTFail(); return } + } +} diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 203490d..10ded90 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -524,11 +524,13 @@ final class HummingbirdWebSocketTests: XCTestCase { configuration: .init(address: .hostname("127.0.0.1", port: 0)) ) _ = try await application.test(.live) { client in - try await client.ws("/ws") { inbound, _, _ in + let frame = try await client.ws("/ws") { inbound, _, _ in // don't handle any inbound data for a period much longer than the auto ping period try await Task.sleep(for: .milliseconds(500)) for try await _ in inbound {} } + XCTAssertEqual(frame?.closeCode, .goingAway) + XCTAssertEqual(frame?.reason, "Ping timeout") } } diff --git a/scripts/autobahn.sh b/scripts/autobahn-client.sh similarity index 100% rename from scripts/autobahn.sh rename to scripts/autobahn-client.sh diff --git a/scripts/autobahn-config/fuzzingclient.json b/scripts/autobahn-config/fuzzingclient.json new file mode 100644 index 0000000..3c1cee5 --- /dev/null +++ b/scripts/autobahn-config/fuzzingclient.json @@ -0,0 +1,13 @@ +{ + "servers": [ + {"url": "ws://host.docker.internal:9001", "agent": "HB"} + ], + "outdir": "./reports/server", + "cases": ["*"], + "exclude-cases": [ + "9.*", + "12.*", + "13.*" + ], + "exclude-agent-cases": {} +} diff --git a/scripts/autobahn-config/fuzzingserver.json b/scripts/autobahn-config/fuzzingserver.json index 902cc5c..4bc4fc5 100644 --- a/scripts/autobahn-config/fuzzingserver.json +++ b/scripts/autobahn-config/fuzzingserver.json @@ -3,7 +3,6 @@ "outdir": "./reports/clients", "cases": ["*"], "exclude-cases": [ - "6.*", "9.*", "12.*", "13.*" diff --git a/scripts/autobahn-server.sh b/scripts/autobahn-server.sh new file mode 100755 index 0000000..702e9d8 --- /dev/null +++ b/scripts/autobahn-server.sh @@ -0,0 +1,7 @@ +docker run -it --rm \ + -v "${PWD}/scripts/autobahn-config:/config" \ + -v "${PWD}/.build/reports:/reports" \ + -p 9001:9001 \ + --network=host \ + --name fuzzingclient \ + crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json