Skip to content

Commit

Permalink
add connection tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Aug 26, 2024
1 parent e904c86 commit 1ed7e7d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ extension PostgresFrontendMessage {
case .saslResponse:
preconditionFailure("TODO: Unimplemented")
case .query:
return .query
guard let query = buffer.readNullTerminatedString() else {
throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self)
}
return .query(.init(query: query))
case .sync:
return .sync
case .terminate:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ enum PostgresFrontendMessage: Equatable {
}
}

struct Query: Hashable {
/// The query string.
let query: String
}

struct Parse: Hashable {
/// The name of the destination prepared statement (an empty string selects the unnamed prepared statement).
let preparedStatementName: String
Expand Down Expand Up @@ -179,7 +184,7 @@ enum PostgresFrontendMessage: Equatable {
case saslInitialResponse(SASLInitialResponse)
case saslResponse(SASLResponse)
case sslRequest
case query
case query(Query)
case sync
case startup(Startup)
case terminate
Expand Down
71 changes: 71 additions & 0 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,40 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testCloseImmediatelyWithSimpleQuery() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
try await connection.__simpleQuery("SELECT 1;", logger: logger)
}
}

let query = try await channel.waitForSimpleQueryRequest()
XCTAssertEqual(query.query, "SELECT 1;")

async let close: () = connection.close()

try await channel.closeFuture.get()
XCTAssertEqual(channel.isActive, false)

try await close

while let taskResult = await taskGroup.nextResult() {
switch taskResult {
case .success:
XCTFail("Expected queries to fail")
case .failure(let failure):
guard let error = failure as? PSQLError else {
return XCTFail("Unexpected error type: \(failure)")
}
XCTAssertEqual(error.code, .clientClosedConnection)
}
}
}
}

func testIfServerJustClosesTheErrorReflectsThat() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
let logger = self.logger
Expand Down Expand Up @@ -346,6 +380,35 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testIfServerJustClosesTheErrorReflectsThatInSimpleQuery() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
let logger = self.logger

async let response = try await connection.__simpleQuery("SELECT 1;", logger: logger)

let query = try await channel.waitForSimpleQueryRequest()
XCTAssertEqual(query.query, "SELECT 1;")

try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() }
try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() }

do {
_ = try await response
XCTFail("Expected to throw")
} catch {
XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection)
}

// retry on same connection

do {
_ = try await connection.__simpleQuery("SELECT 1;", logger: self.logger)
XCTFail("Expected to throw")
} catch {
XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection)
}
}

struct TestPrepareStatement: PostgresPreparedStatement {
static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1"
typealias Row = String
Expand Down Expand Up @@ -692,6 +755,14 @@ extension NIOAsyncTestingChannel {
return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute)
}

func waitForSimpleQueryRequest() async throws -> PostgresFrontendMessage.Query {
let query = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
guard case .query(let query) = query else {
fatalError()
}
return query
}

func waitForPrepareRequest() async throws -> PrepareRequest {
let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)
Expand Down

0 comments on commit 1ed7e7d

Please sign in to comment.