Skip to content

Commit

Permalink
Allow bindings with optional values in PostgresBindings (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
rgcottrell authored Oct 21, 2024
1 parent 225c5c4 commit d4c2f38
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
46 changes: 46 additions & 0 deletions Sources/PostgresNIO/New/PostgresQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable {
try self.append(value, context: .default)
}

@inlinable
public mutating func append<Value: PostgresThrowingDynamicTypeEncodable>(_ value: Optional<Value>) throws {
switch value {
case .none:
self.appendNull()
case let .some(value):
try self.append(value)
}
}

@inlinable
public mutating func append<Value: PostgresThrowingDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Value,
Expand All @@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable {
self.metadata.append(.init(value: value, protected: true))
}

@inlinable
public mutating func append<Value: PostgresThrowingDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Optional<Value>,
context: PostgresEncodingContext<JSONEncoder>
) throws {
switch value {
case .none:
self.appendNull()
case let .some(value):
try self.append(value, context: context)
}
}

@inlinable
public mutating func append<Value: PostgresDynamicTypeEncodable>(_ value: Value) {
self.append(value, context: .default)
}

@inlinable
public mutating func append<Value: PostgresDynamicTypeEncodable>(_ value: Optional<Value>) {
switch value {
case .none:
self.appendNull()
case let .some(value):
self.append(value)
}
}

@inlinable
public mutating func append<Value: PostgresDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Value,
Expand All @@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable {
self.metadata.append(.init(value: value, protected: true))
}

@inlinable
public mutating func append<Value: PostgresDynamicTypeEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Optional<Value>,
context: PostgresEncodingContext<JSONEncoder>
) {
switch value {
case .none:
self.appendNull()
case let .some(value):
self.append(value, context: context)
}
}

@inlinable
mutating func appendUnprotected<Value: PostgresEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Value,
Expand Down
81 changes: 81 additions & 0 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,87 @@ final class AsyncPostgresConnectionTests: XCTestCase {
XCTFail("Unexpected error: \(String(describing: error))")
}
}

static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable"
func testPreparedStatementWithOptionalBinding() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

struct InsertPreparedStatement: PostgresPreparedStatement {
static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable"

static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"#
typealias Row = ()

var uuid: UUID?

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.uuid)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
()
}
}

struct SelectPreparedStatement: PostgresPreparedStatement {
static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable"

static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"#
typealias Row = (Int, UUID?)

var id: Int

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.id)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
try row.decode((Int, UUID?).self)
}
}

do {
try await withTestConnection(on: eventLoop) { connection in
try await connection.query("""
CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" (
id SERIAL PRIMARY KEY,
uuid UUID
)
""",
logger: .psqlTest
)

_ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest)

let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest)
var counter = 0
for try await (id, uuid) in rows {
Logger.psqlTest.info("Received row", metadata: [
"id": "\(id)", "uuid": "\(String(describing: uuid))"
])
counter += 1
}

try await connection.query("""
DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)";
""",
logger: .psqlTest
)
}
} catch {
XCTFail("Unexpected error: \(String(describing: error))")
}
}
}

extension XCTestCase {
Expand Down

0 comments on commit d4c2f38

Please sign in to comment.