Skip to content

Commit

Permalink
Use a state machine for WebSocket connection (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored Nov 11, 2024
1 parent a1cd286 commit 19bbff0
Show file tree
Hide file tree
Showing 10 changed files with 402 additions and 132 deletions.
39 changes: 22 additions & 17 deletions Snippets/AutobahnClientTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,24 @@ import HummingbirdWSClient
import HummingbirdWSCompression
import Logging

// Autobahn tests
// 1. Framing
// 2. Pings/Pongs
// 3. Reserved bits
// 4. Opcodes
// 5. Fragmentation
// 6. UTF8 handling
// 7. Close handling
// 9. Limits/performance
// 10. Misc
// 12. WebSocket compression (different payloads)
// 13. WebSocket compression (different parameters)
// Autobahn tests (https://github.com/crossbario/autobahn-testsuite)
// run
// ```
// .scripts/autobahn.sh
// ```
// 1. Framing (passed)
// 2. Pings/Pongs (passed)
// 3. Reserved bits (reserved bit checking not supported)
// 4. Opcodes (passed)
// 5. Fragmentation (passed. 5.1/5.2 non-strict)
// 6. UTF8 handling (utf8 validation not supported)
// 7. Close handling (passed, except 7.5.1)
// 9. Limits/performance (passed)
// 10. Misc (passed)
// 12. WebSocket compression (different payloads) (passed)
// 13. WebSocket compression (different parameters) (passed)

let cases = 1...102
let cases = 1...1000

var logger = Logger(label: "TestClient")
logger.logLevel = .trace
Expand All @@ -24,7 +28,7 @@ do {
logger.info("Case \(c)")
try await WebSocketClient.connect(
url: .init("ws://127.0.0.1:9001/runCase?case=\(c)&agent=HB"),
configuration: .init(maxFrameSize: 1 << 16, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 65536)]),
configuration: .init(maxFrameSize: 16_777_216, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)]),
logger: logger
) { inbound, outbound, _ in
for try await msg in inbound.messages(maxSize: .max) {
Expand All @@ -37,9 +41,10 @@ do {
}
}
}
try await WebSocketClient.connect(url: .init("ws://127.0.0.1:9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in
for try await _ in inbound {}
}
} catch {
logger.error("Error: \(error)")
}

try await WebSocketClient.connect(url: .init("ws://127.0.0.1:9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in
for try await _ in inbound {}
}
45 changes: 45 additions & 0 deletions Snippets/AutobahnServerTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import Hummingbird
import HummingbirdWebSocket
import HummingbirdWSCompression
import Logging

// Autobahn tests (https://github.com/crossbario/autobahn-testsuite)
// run
// ```
// .scripts/autobahn.sh
// ```
// 1. Framing (passed)
// 2. Pings/Pongs (passed)
// 3. Reserved bits (not supported) (reserved bit checking not supported)
// 4. Opcodes (passed)
// 5. Fragmentation (passed)
// 6. UTF8 handling (not supported) (utf8 validation not supported)
// 7. Close handling (passed, except 7.5.1)
// 9. Limits/performance (passed)
// 10. Misc (passed)
// 12. WebSocket compression (different payloads) (not run)
// 13. WebSocket compression (different parameters) (not run)

var logger = Logger(label: "TestClient")
logger.logLevel = .trace

// let router = Router().get("report") {}

let app = Application(
router: Router(),
server: .http1WebSocketUpgrade(configuration: .init(maxFrameSize: 16_777_216, extensions: [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)])) { _, _, _ in
return .upgrade([:]) { inbound, outbound, _ in
for try await msg in inbound.messages(maxSize: .max) {
switch msg {
case .binary(let buffer):
try await outbound.write(.binary(buffer))
case .text(let string):
try await outbound.write(.text(string))
}
}
}
},
configuration: .init(address: .hostname("127.0.0.1", port: 9001)),
logger: logger
)
try await app.runService()
183 changes: 70 additions & 113 deletions Sources/HummingbirdWSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,26 @@ package actor WebSocketHandler {
}
}

static let pingDataSize = 16
let channel: Channel
var outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>
let type: WebSocketType
let configuration: Configuration
let logger: Logger
var pingData: ByteBuffer
var pingTime: ContinuousClock.Instant = .now
var closeState: CloseState
var stateMachine: WebSocketStateMachine

private init(
channel: Channel,
outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>,
type: WebSocketType,
configuration: Configuration,
context: some WebSocketContext
) {
self.channel = channel
self.outbound = outbound
self.type = type
self.configuration = configuration
self.logger = context.logger
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize)
self.closeState = .open
self.stateMachine = .init(autoPingSetup: configuration.autoPing)
}

package static func handle<Context: WebSocketContext>(
Expand All @@ -111,27 +110,12 @@ package actor WebSocketHandler {
let rt = try await asyncChannel.executeThenClose { inbound, outbound in
try await withTaskCancellationHandler {
try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in
let webSocketHandler = Self(outbound: outbound, type: type, configuration: configuration, context: context)
if case .enabled(let period) = configuration.autoPing.value {
let webSocketHandler = Self(channel: asyncChannel.channel, outbound: outbound, type: type, configuration: configuration, context: context)
if case .enabled = configuration.autoPing.value {
/// Add task sending ping frames every so often and verifying a pong frame was sent back
group.addTask {
var waitTime = period
while true {
try await Task.sleep(for: waitTime)
if let timeSinceLastPing = await webSocketHandler.getTimeSinceLastWaitingPing() {
// if time is less than timeout value, set wait time to when it would timeout
// and re-run loop
if timeSinceLastPing < period {
waitTime = period - timeSinceLastPing
continue
} else {
try await asyncChannel.channel.close(mode: .input)
return .init(closeCode: .goingAway, reason: "No response to ping")
}
}
try await webSocketHandler.ping()
waitTime = period
}
try await webSocketHandler.runAutoPingLoop()
return .init(closeCode: .goingAway, reason: "Ping timeout")
}
}
let rt = try await webSocketHandler.handle(inbound: inbound, outbound: outbound, handler: handler, context: context)
Expand Down Expand Up @@ -177,7 +161,7 @@ package actor WebSocketHandler {
}
do {
try await self.close(code: closeCode)
if case .closing = self.closeState {
if case .closing = self.stateMachine.state {
// Close handshake. Wait for responding close or until inbound ends
while let frame = try await inboundIterator.next() {
if case .connectionClose = frame.opcode {
Expand All @@ -193,12 +177,34 @@ package actor WebSocketHandler {
try? await self.close(code: .normalClosure)
}
}
return switch self.closeState {
return switch self.stateMachine.state {
case .closed(let code): code
default: nil
}
}

func runAutoPingLoop() async throws {
let period = self.stateMachine.pingTimePeriod
try await Task.sleep(for: period)
while true {
switch self.stateMachine.sendPing() {
case .sendPing(let buffer):
try await self.write(frame: .init(fin: true, opcode: .ping, data: buffer))

case .wait(let time):
try await Task.sleep(for: time)

case .closeConnection(let errorCode):
try await self.sendClose(code: errorCode, reason: "Ping timeout")
try await self.channel.close(mode: .input)
return

case .stop:
return
}
}
}

/// Send WebSocket frame
func write(frame: WebSocketFrame) async throws {
var frame = frame
Expand Down Expand Up @@ -227,122 +233,73 @@ package actor WebSocketHandler {
}

/// Respond to ping
func onPing(
_ frame: WebSocketFrame
) async throws {
if frame.fin {
try await self.pong(data: frame.unmaskedData)
} else {
try await self.close(code: .protocolError)
func onPing(_ frame: WebSocketFrame) async throws {
guard frame.fin else {
self.channel.close(promise: nil)
return
}
}
switch self.stateMachine.receivedPing(frameData: frame.unmaskedData) {
case .pong(let frameData):
try await self.write(frame: .init(fin: true, opcode: .pong, data: frameData))

/// Respond to pong
func onPong(
_ frame: WebSocketFrame
) throws {
let frameData = frame.unmaskedData
// ignore pong frames with frame data not the same as the last ping
guard frameData == self.pingData else { return }
// clear ping data
self.pingData.clear()
}
case .protocolError:
try await self.close(code: .protocolError)

/// Send ping
func ping() async throws {
guard case .open = self.closeState else { return }
if self.pingData.readableBytes == 0 {
// creating random payload
let random = (0..<Self.pingDataSize).map { _ in UInt8.random(in: 0...255) }
self.pingData.writeBytes(random)
case .doNothing:
break
}
self.pingTime = .now
try await self.write(frame: .init(fin: true, opcode: .ping, data: self.pingData))
}

/// Send pong
func pong(data: ByteBuffer?) async throws {
guard case .open = self.closeState else { return }
try await self.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init()))
}

/// Return time ping occurred if it is still waiting for a pong
func getTimeSinceLastWaitingPing() -> Duration? {
guard self.pingData.readableBytes > 0 else { return nil }
return .now - self.pingTime
/// Respond to pong
func onPong(_ frame: WebSocketFrame) async throws {
guard frame.fin else {
self.channel.close(promise: nil)
return
}
self.stateMachine.receivedPong(frameData: frame.unmaskedData)
}

/// Send close
func close(
code: WebSocketErrorCode = .normalClosure,
reason: String? = nil
) async throws {
switch self.closeState {
case .open:
var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0))
buffer.write(webSocketErrorCode: code)
if let reason {
buffer.writeString(reason)
}

try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
func close(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) async throws {
switch self.stateMachine.close() {
case .sendClose:
try await self.sendClose(code: code, reason: reason)
// Only server should initiate a connection close. Clients should wait for the
// server to close the connection when it receives the WebSocket close packet
// See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1
if self.type == .server {
self.outbound.finish()
}
self.closeState = .closing
default:
case .doNothing:
break
}
}

func receivedClose(_ frame: WebSocketFrame) async throws {
// we received a connection close.
// send a close back if it hasn't already been send and exit
var data = frame.unmaskedData
let dataSize = data.readableBytes
// read close code and close reason
let closeCode = data.readWebSocketErrorCode()
let reason = data.readableBytes > 0
? data.readString(length: data.readableBytes)
: nil

switch self.closeState {
case .open:
self.closeState = .closed(closeCode.map { .init(closeCode: $0, reason: reason) })
let code: WebSocketErrorCode = if dataSize == 0 || closeCode != nil {
// codes 3000 - 3999 are reserved for use by libraries, frameworks, and applications
// so are considered valid
if case .unknown(let code) = closeCode, code < 3000 || code > 3999 {
.protocolError
} else {
.normalClosure
}
} else {
.protocolError
}

var buffer = ByteBufferAllocator().buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)

try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
switch self.stateMachine.receivedClose(frameData: frame.unmaskedData) {
case .sendClose(let errorCode):
try await self.sendClose(code: errorCode, reason: nil)
// Only server should initiate a connection close. Clients should wait for the
// server to close the connection when it receives the WebSocket close packet
// See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1
if self.type == .server {
self.outbound.finish()
}

case .closing:
self.closeState = .closed(closeCode.map { .init(closeCode: $0, reason: reason) })

default:
case .doNothing:
break
}
}

private func sendClose(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) async throws {
var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0))
buffer.write(webSocketErrorCode: code)
if let reason {
buffer.writeString(reason)
}

try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
}

/// Make mask key to be used in WebSocket frame
private func makeMaskKey() -> WebSocketMaskingKey? {
guard self.type == .client else { return nil }
Expand Down
Loading

0 comments on commit 19bbff0

Please sign in to comment.