Skip to content

Commit

Permalink
Persist query parameters in WebSocket URI (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
JakeTiritilli authored Sep 5, 2021
1 parent d7537b7 commit b1c4df8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
10 changes: 8 additions & 2 deletions Sources/WebSocketKit/HTTPInitialRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa

let host: String
let path: String
let query: String?
let headers: HTTPHeaders
let upgradePromise: EventLoopPromise<Void>

init(host: String, path: String, headers: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
init(host: String, path: String, query: String?, headers: HTTPHeaders, upgradePromise: EventLoopPromise<Void>) {
self.host = host
self.path = path
self.query = query
self.headers = headers
self.upgradePromise = upgradePromise
}
Expand All @@ -23,10 +25,14 @@ final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHa
headers.add(name: "Content-Length", value: "\(0)")
headers.add(name: "Host", value: self.host)

var uri = self.path.hasPrefix("/") ? self.path : "/" + self.path
if let query = self.query {
uri += "?\(query)"
}
let requestHead = HTTPRequestHead(
version: HTTPVersion(major: 1, minor: 1),
method: .GET,
uri: self.path.hasPrefix("/") ? self.path : "/" + self.path,
uri: uri,
headers: headers
)
context.write(self.wrapOutboundOut(.head(requestHead)), promise: nil)
Expand Down
3 changes: 3 additions & 0 deletions Sources/WebSocketKit/WebSocket+Connect.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ extension WebSocket {
host: url.host ?? "localhost",
port: url.port ?? (scheme == "wss" ? 443 : 80),
path: url.path,
query: url.query,
headers: headers,
configuration: configuration,
on: eventLoopGroup,
Expand All @@ -43,6 +44,7 @@ extension WebSocket {
host: String,
port: Int = 80,
path: String = "/",
query: String? = nil,
headers: HTTPHeaders = [:],
configuration: WebSocketClient.Configuration = .init(),
on eventLoopGroup: EventLoopGroup,
Expand All @@ -56,6 +58,7 @@ extension WebSocket {
host: host,
port: port,
path: path,
query: query,
headers: headers,
onUpgrade: onUpgrade
)
Expand Down
2 changes: 2 additions & 0 deletions Sources/WebSocketKit/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public final class WebSocketClient {
host: String,
port: Int,
path: String = "/",
query: String? = nil,
headers: HTTPHeaders = [:],
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
Expand All @@ -65,6 +66,7 @@ public final class WebSocketClient {
let httpHandler = HTTPInitialRequestHandler(
host: host,
path: path,
query: query,
headers: headers,
upgradePromise: upgradePromise
)
Expand Down
23 changes: 23 additions & 0 deletions Tests/WebSocketKitTests/WebSocketKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,29 @@ final class WebSocketKitTests: XCTestCase {
try XCTAssertEqual(promise.futureResult.wait(), "supersecretsauce")
try server.close(mode: .all).wait()
}

func testQueryParamsAreSent() throws {
let promise = self.elg.next().makePromise(of: String.self)

let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
promise.succeed(req.uri)
ws.close(promise: nil)
}.bind(host: "localhost", port: 0).wait()

guard let port = server.localAddress?.port else {
XCTFail("couldn't get port from \(server.localAddress.debugDescription)")
return
}

WebSocket.connect(
to: "ws://localhost:\(port)?foo=bar&bar=baz",
on: self.elg) { ws in
_ = ws.close()
}.cascadeFailure(to: promise)

try XCTAssertEqual(promise.futureResult.wait(), "/?foo=bar&bar=baz")
try server.close(mode: .all).wait()
}

func testLocally() throws {
// swap to test websocket server against local client
Expand Down

0 comments on commit b1c4df8

Please sign in to comment.