From 6ffe70a5a0e2c56be2e0a282703e720badc3715d Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:52:44 +0100 Subject: [PATCH] Add support for batch execution (#65) --- Package.resolved | 29 +- .../OracleConnection+BatchExecution.swift | 207 ++++++++++++++ .../Connection/OracleConnection.swift | 9 + .../ConnectionStateMachine.swift | 20 +- .../StatementStateMachine.swift | 107 ++++++- .../OracleNIO/Data/Bool+OracleCodable.swift | 2 +- .../Data/ByteBuffer+OracleCodable.swift | 2 +- Sources/OracleNIO/Data/Cursor.swift | 2 +- .../OracleNIO/Data/Data+OracleCodable.swift | 2 +- .../OracleNIO/Data/Date+OracleCodable.swift | 2 +- .../OracleNIO/Data/Double+OracleCodable.swift | 2 +- .../OracleNIO/Data/Float+OracleCodable.swift | 2 +- .../OracleNIO/Data/Int+OracleCodable.swift | 2 +- .../Data/IntervalDS+OracleCodable.swift | 2 +- .../Data/JSON/OracleJSON+Encoding.swift | 4 +- Sources/OracleNIO/Data/LOB.swift | 1 + Sources/OracleNIO/Data/OracleNumber.swift | 2 +- Sources/OracleNIO/Data/OracleVector.swift | 1 + Sources/OracleNIO/Data/RowID.swift | 2 +- .../OracleNIO/Data/String+OracleCodable.swift | 4 +- .../Documentation.docc/Documentation.md | 2 + .../Coding/OracleFrontendMessageEncoder.swift | 63 ++-- .../OracleBackendMessage+RowData.swift | 2 +- Sources/OracleNIO/OracleChannelHandler.swift | 16 +- Sources/OracleNIO/OracleCodable.swift | 11 + Sources/OracleNIO/OracleRowSequence.swift | 24 +- Sources/OracleNIO/OracleRowStream.swift | 92 +++++- Sources/OracleNIO/OracleSQLError.swift | 33 +++ Sources/OracleNIO/OracleTask.swift | 101 ++++++- .../OracleBindings.swift} | 211 +++++--------- .../OraclePreparedStatement.swift | 0 .../Statements/OracleStatement.swift | 166 +++++++++++ .../BatchExecutionTests.swift | 268 ++++++++++++++++++ Tests/IntegrationTests/BugReportTests.swift | 2 +- .../StatementStateMachineTests.swift | 4 +- .../OracleNIOTests/OracleStatementTests.swift | 2 +- .../ConnectionAction+TestUtils.swift | 6 +- .../TestUtils/OracleRowStream+TestUtils.swift | 10 +- .../TestUtils/QueryResult+TestUtils.swift | 2 +- docker-compose.yaml | 2 +- 40 files changed, 1182 insertions(+), 239 deletions(-) create mode 100644 Sources/OracleNIO/Connection/OracleConnection+BatchExecution.swift rename Sources/OracleNIO/{OracleStatement.swift => Statements/OracleBindings.swift} (71%) rename Sources/OracleNIO/{ => Statements}/OraclePreparedStatement.swift (100%) create mode 100644 Sources/OracleNIO/Statements/OracleStatement.swift create mode 100644 Tests/IntegrationTests/BatchExecutionTests.swift diff --git a/Package.resolved b/Package.resolved index 2d26f11..7110320 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,12 +1,21 @@ { "pins" : [ + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "7faebca1ea4f9aaf0cda1cef7c43aecd2311ddf6", + "version" : "1.3.0" + } + }, { "identity" : "swift-async-algorithms", "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-async-algorithms.git", "state" : { - "revision" : "6ae9a051f76b81cc668305ceed5b0e0a7fd93d20", - "version" : "1.0.1" + "revision" : "5c8bd186f48c16af0775972700626f0b74588278", + "version" : "1.0.2" } }, { @@ -32,8 +41,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-crypto.git", "state" : { - "revision" : "9f95b4d033a4edd3814b48608db3f2ca90c7218b", - "version" : "3.7.0" + "revision" : "21f7878f2b39d46fd8ba2b06459ccb431cdf876c", + "version" : "3.8.1" } }, { @@ -50,8 +59,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio.git", "state" : { - "revision" : "9746cf80e29edfef2a39924a66731249223f42a3", - "version" : "2.72.0" + "revision" : "f7dc3f527576c398709b017584392fb58592e7f5", + "version" : "2.75.0" } }, { @@ -59,8 +68,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-ssl.git", "state" : { - "revision" : "7b84abbdcef69cc3be6573ac12440220789dcd69", - "version" : "2.27.2" + "revision" : "d7ceaf0e4d8001cd35cdc12e42cdd281e9e564e8", + "version" : "2.28.0" } }, { @@ -68,8 +77,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-transport-services.git", "state" : { - "revision" : "38ac8221dd20674682148d6451367f89c2652980", - "version" : "1.21.0" + "revision" : "dbace16f126fdcd80d58dc54526c561ca17327d7", + "version" : "1.22.0" } }, { diff --git a/Sources/OracleNIO/Connection/OracleConnection+BatchExecution.swift b/Sources/OracleNIO/Connection/OracleConnection+BatchExecution.swift new file mode 100644 index 0000000..26baa25 --- /dev/null +++ b/Sources/OracleNIO/Connection/OracleConnection+BatchExecution.swift @@ -0,0 +1,207 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the OracleNIO open source project +// +// Copyright (c) 2024 Timo Zacherl and the OracleNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE for license information +// See CONTRIBUTORS.md for the list of OracleNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging + +extension OracleConnection { + /// Executes the statement multiple times using the specified bind collections without requiring multiple roundtrips to the database. + /// - Parameters: + /// - statement: The raw SQL statement. + /// - binds: A collection of bind parameters to execute the statement with. The statement will execute `binds.count` times. + /// - encodingContext: The ``OracleEncodingContext`` used to encode the binds. A default parameter is provided. + /// - options: A bunch of parameters to optimize the statement in different ways. + /// Normally this can be ignored, but feel free to experiment based on your needs. + /// Every option and its impact is documented. + /// - logger: The `Logger` to log statement related background events into. Defaults to logging disabled. + /// - file: The file, the statement was started in. Used for better error reporting. + /// - line: The line, the statement was started in. Used for better error reporting. + /// - Returns: A ``OracleBatchExecutionResult`` containing the amount of affected rows and other metadata the server sent. + /// + /// Batch execution is useful for inserting or updating multiple rows efficiently when working with large data sets. It significally outperforms + /// repeated calls to ``execute(_:options:logger:file:line:)-9uyvp`` by reducing network transfer costs and database overheads. + /// It can also be used to execute PL/SQL statements multiple times at once. + /// ```swift + /// let binds: [(Int, String, Int)] = [ + /// (1, "John", 20), + /// (2, "Jane", 30), + /// (3, "Jack", 40), + /// (4, "Jill", 50), + /// (5, "Pete", 60), + /// ] + /// try await connection.executeBatch( + /// "INSERT INTO users (id, name, age) VALUES (:1, :2, :3)", + /// binds: binds + /// ) + /// ``` + @discardableResult + public func executeBatch( + _ statement: String, + binds: [(repeat (each Bind)?)], + encodingContext: OracleEncodingContext = .default, + options: StatementOptions = .init(), + logger: Logger? = nil, + file: String = #fileID, line: Int = #line + ) async throws -> OracleBatchExecutionResult { + var logger = logger ?? Self.noopLogger + logger[oracleMetadataKey: .connectionID] = "\(self.id)" + logger[oracleMetadataKey: .sessionID] = "\(self.sessionID)" + + var collection = OracleBindingsCollection() + for row in binds { + try collection.appendRow(repeat each row, context: encodingContext) + } + + return try await _executeBatch( + statement: statement, + collection: collection, + options: options, + logger: logger, + file: file, + line: line + ) + } + + /// Executes the prepared statements without requiring multiple roundtrips to the database. + /// - Parameters: + /// - statements: The prepared statements to execute. + /// - options: A bunch of parameters to optimize the statement in different ways. + /// Normally this can be ignored, but feel free to experiment based on your needs. + /// Every option and its impact is documented. + /// - logger: The `Logger` to log statement related background events into. Defaults to logging disabled. + /// - file: The file, the statement was started in. Used for better error reporting. + /// - line: The line, the statement was started in. Used for better error reporting. + /// - Returns: A ``OracleBatchExecutionResult`` containing the amount of affected rows and other metadata the server sent. + /// + /// Batch execution is useful for inserting or updating multiple rows efficiently when working with large data sets. It significally outperforms + /// repeated calls to ``execute(_:options:logger:file:line:)-9uyvp`` by reducing network transfer costs and database overheads. + /// It can also be used to execute PL/SQL statements multiple times at once. + /// ```swift + /// try await connection.executeBatch([ + /// InsertUserStatement(id: 1, name: "John", age: 20), + /// InsertUserStatement(id: 2, name: "Jane", age: 30), + /// InsertUserStatement(id: 3, name: "Jack", age: 40), + /// InsertUserStatement(id: 4, name: "Jill", age: 50), + /// InsertUserStatement(id: 5, name: "Pete", age: 60), + /// ]) + /// ``` + @discardableResult + public func executeBatch( + _ statements: [Statement], + options: StatementOptions = .init(), + logger: Logger? = nil, + file: String = #fileID, line: Int = #line + ) async throws -> OracleBatchExecutionResult { + if statements.isEmpty { + throw OracleSQLError.missingStatement + } + + var logger = logger ?? Self.noopLogger + logger[oracleMetadataKey: .connectionID] = "\(self.id)" + logger[oracleMetadataKey: .sessionID] = "\(self.sessionID)" + + var collection = OracleBindingsCollection() + for statement in statements { + try collection.appendRow(statement.makeBindings()) + } + return try await _executeBatch( + statement: Statement.sql, + collection: collection, + options: options, + logger: logger, + file: file, + line: line + ) + } + + private func _executeBatch( + statement: String, + collection: OracleBindingsCollection, + options: StatementOptions, + logger: Logger, + file: String, + line: Int + ) async throws -> OracleBatchExecutionResult { + let promise = self.channel.eventLoop.makePromise( + of: OracleRowStream.self + ) + let context = StatementContext( + statement: statement, + bindCollection: collection, + options: options, + logger: logger, + promise: promise + ) + + self.channel.write(OracleTask.statement(context), promise: nil) + + do { + let stream = try await promise.futureResult + .map({ $0.asyncSequence() }) + .get() + let affectedRows = try await stream.affectedRows + let affectedRowsPerStatement = options.arrayDMLRowCounts ? stream.rowCounts : nil + let batchErrors = options.batchErrors ? stream.batchErrors : nil + let result = OracleBatchExecutionResult( + affectedRows: affectedRows, + affectedRowsPerStatement: affectedRowsPerStatement + ) + if let batchErrors { + throw OracleBatchExecutionError( + result: result, + errors: batchErrors, + statement: statement, + file: file, + line: line + ) + } + return result + } catch var error as OracleSQLError { + error.file = file + error.line = line + error.statement = .init(unsafeSQL: statement) + throw error // rethrow with more metadata + } + } +} + +/// The result of a batch execution. +public struct OracleBatchExecutionResult: Sendable { + /// The total amount of affected rows. + public let affectedRows: Int + /// The amount of affected rows per statement. + /// + /// - Note: Only available if ``StatementOptions/arrayDMLRowCounts`` is set to `true`. + /// + /// For example, if five single row `INSERT` statements are executed and the fifth one fails, the following array would be returned. + /// ```swift + /// [1, 1, 1, 1, 0] + /// ``` + public let affectedRowsPerStatement: [Int]? +} + +/// An error that is thrown when a batch execution contains both successful and failed statements. +/// +/// - Note: This error is only thrown when ``StatementOptions/batchErrors`` is set to `true`. +/// Otherwise ``OracleSQLError`` will be thrown as usual. Be aware that all the statements +/// executed before the error is thrown won't be reverted regardless of this setting. +/// They can still be reverted using a ``OracleConnection/rollback()``. +public struct OracleBatchExecutionError: Error, Sendable { + /// The result of the partially finished batch execution. + public let result: OracleBatchExecutionResult + /// A collection of errors thrown by statements in the batch execution. + public let errors: [OracleSQLError.BatchError] + public let statement: String + public let file: String + public let line: Int +} diff --git a/Sources/OracleNIO/Connection/OracleConnection.swift b/Sources/OracleNIO/Connection/OracleConnection.swift index b577808..ea38e7e 100644 --- a/Sources/OracleNIO/Connection/OracleConnection.swift +++ b/Sources/OracleNIO/Connection/OracleConnection.swift @@ -475,6 +475,15 @@ extension OracleConnection { } /// Execute a prepared statement. + /// - Parameters: + /// - statement: The statement to be executed. + /// - options: A bunch of parameters to optimize the statement in different ways. + /// Normally this can be ignored, but feel free to experiment based on your needs. + /// Every option and its impact is documented. + /// - logger: The `Logger` to log statement related background events into. Defaults to logging disabled. + /// - file: The file, the statement was started in. Used for better error reporting. + /// - line: The line, the statement was started in. Used for better error reporting. + /// - Returns: An async sequence of `Row`s. The result sequence can be discarded if the statement has no result. public func execute( _ statement: Statement, options: StatementOptions = .init(), diff --git a/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift b/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift index d03ddee..ece4389 100644 --- a/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift +++ b/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift @@ -137,7 +137,7 @@ struct ConnectionStateMachine { // Statement streaming case forwardRows([DataRow]) - case forwardStreamComplete([DataRow], cursorID: UInt16) + case forwardStreamComplete([DataRow], cursorID: UInt16, affectedRows: Int) case forwardStreamError( OracleSQLError, read: Bool, cursorID: UInt16?, clientCancelled: Bool ) @@ -639,7 +639,17 @@ struct ConnectionStateMachine { mutating func queryParameterReceived( _ parameter: OracleBackendMessage.QueryParameter ) -> ConnectionAction { - return .wait + switch self.state { + case .statement(var statement): + return self.avoidingStateMachineCoW { machine in + let action = statement.queryParameterReceived(parameter) + machine.state = .statement(statement) + return machine.modify(with: action) + } + default: + assertionFailure("Invalid state: \(self.state)") + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.queryParameter(parameter))) + } } mutating func bitVectorReceived( @@ -1016,7 +1026,7 @@ extension ConnectionStateMachine { .uncleanShutdown, .unsupportedDataType: return true - case .statementCancelled, .nationalCharsetNotSupported: + case .statementCancelled, .nationalCharsetNotSupported, .missingStatement: return false case .server: switch error.serverInfo?.number { @@ -1147,8 +1157,8 @@ extension ConnectionStateMachine { return .succeedStatement(promise, columns) case .forwardRows(let rows): return .forwardRows(rows) - case .forwardStreamComplete(let rows, let cursorID): - return .forwardStreamComplete(rows, cursorID: cursorID) + case .forwardStreamComplete(let rows, let cursorID, let affectedRows): + return .forwardStreamComplete(rows, cursorID: cursorID, affectedRows: affectedRows) case .forwardStreamError( let error, let read, let cursorID, let clientCancelled ): diff --git a/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift b/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift index 16f1396..40c70f4 100644 --- a/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift +++ b/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift @@ -18,6 +18,7 @@ struct StatementStateMachine { private enum State { case initialized(StatementContext) + case rowCountsReceived(StatementContext, [Int]) case describeInfoReceived(StatementContext, DescribeInfo) case streaming( StatementContext, @@ -47,7 +48,7 @@ struct StatementStateMachine { case evaluateErrorAtConnectionLevel(OracleSQLError) case forwardRows([DataRow]) - case forwardStreamComplete([DataRow], cursorID: UInt16) + case forwardStreamComplete([DataRow], cursorID: UInt16, affectedRows: Int) /// Error payload and a optional cursor ID, which should be closed in a future roundtrip. case forwardStreamError( OracleSQLError, @@ -90,7 +91,7 @@ struct StatementStateMachine { "Start must be called immediately after the statement was created" ) - case .describeInfoReceived(let context, _): + case .rowCountsReceived(let context, _), .describeInfoReceived(let context, _): guard !self.isCancelled else { return .wait } @@ -162,7 +163,9 @@ struct StatementStateMachine { promise, StatementResult( value: .describeInfo(describeInfo.columns), - logger: context.logger + logger: context.logger, + batchErrors: nil, + rowCounts: nil ) ) } @@ -185,7 +188,7 @@ struct StatementStateMachine { // marker. Due to that it might send more data without knowing yet. return .wait - case .initialized, .error, .commandComplete: + case .initialized, .rowCountsReceived, .error, .commandComplete: preconditionFailure("Invalid state: \(self.state)") case .modifying: @@ -199,7 +202,7 @@ struct StatementStateMachine { ) -> Action { switch self.state { case .initialized(let context): - let outBinds = context.statement.binds.metadata.compactMap(\.outContainer) + let outBinds = context.binds.metadata.compactMap(\.outContainer) precondition(rowData.columns.count == outBinds.count) for (index, column) in rowData.columns.enumerated() { switch column { @@ -270,9 +273,21 @@ struct StatementStateMachine { } } + mutating func queryParameterReceived(_ parameter: OracleBackendMessage.QueryParameter) -> Action { + if let rowCounts = parameter.rowCounts { + guard case .initialized(let statementContext) = state else { + preconditionFailure("Invalid state: \(self.state)") + } + self.state = .modifying + self.state = .rowCountsReceived(statementContext, rowCounts.map(Int.init)) + } + return .wait + } + mutating func errorReceived( _ error: OracleBackendMessage.BackendError ) -> Action { + let batchErrors = error.batchErrors.map(OracleSQLError.BatchError.init) let action: Action if Constants.TNS_ERR_NO_DATA_FOUND == error.number @@ -281,6 +296,7 @@ struct StatementStateMachine { switch self.state { case .commandComplete, .error, .drain: return .wait // stream has already finished + case .initialized(let context), .describeInfoReceived(let context, _): context.cursorID = error.cursorID ?? context.cursorID @@ -297,7 +313,38 @@ struct StatementStateMachine { .cursor(_, let promise), .plain(let promise): action = .succeedStatement( - promise, .init(value: .noRows, logger: context.logger) + promise, + .init( + value: .noRows(affectedRows: Int(error.rowCount ?? 0)), + logger: context.logger, + batchErrors: batchErrors, + rowCounts: nil + ) + ) // empty response + } + + case .rowCountsReceived(let context, let rowCounts): + context.cursorID = error.cursorID ?? context.cursorID + + self.avoidingStateMachineCoWVoid { state in + state = .commandComplete + } + + switch context.type { + case .query(let promise), + .plsql(let promise), + .dml(let promise), + .ddl(let promise), + .cursor(_, let promise), + .plain(let promise): + action = .succeedStatement( + promise, + .init( + value: .noRows(affectedRows: Int(error.rowCount ?? 0)), + logger: context.logger, + batchErrors: batchErrors, + rowCounts: rowCounts + ) ) // empty response } @@ -309,7 +356,8 @@ struct StatementStateMachine { } let rows = demandStateMachine.end() - action = .forwardStreamComplete(rows, cursorID: context.cursorID) + action = .forwardStreamComplete( + rows, cursorID: context.cursorID, affectedRows: Int(error.rowCount ?? 0)) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -399,7 +447,33 @@ struct StatementStateMachine { action = .succeedStatement( promise, StatementResult( - value: .noRows, logger: context.logger + value: .noRows(affectedRows: Int(error.rowCount ?? 0)), + logger: context.logger, + batchErrors: batchErrors, + rowCounts: nil + ) + ) + } + self.state = .commandComplete + + case .rowCountsReceived(let context, let rowCounts): + if let cursorID = error.cursorID { + context.cursorID = cursorID + } + switch context.type { + case .query(let promise), + .plsql(let promise), + .dml(let promise), + .ddl(let promise), + .cursor(_, let promise), + .plain(let promise): + action = .succeedStatement( + promise, + StatementResult( + value: .noRows(affectedRows: Int(error.rowCount ?? 0)), + logger: context.logger, + batchErrors: batchErrors, + rowCounts: rowCounts ) ) } @@ -492,10 +566,10 @@ struct StatementStateMachine { ) -> Action { switch self.state { case .initialized(let context): - guard context.statement.binds.count == vector.bindMetadata.count else { + guard context.binds.count == vector.bindMetadata.count else { preconditionFailure( """ - mismatch in binds - sent: \(context.statement.binds.count), \ + mismatch in binds - sent: \(context.binds.count), \ received: \(vector.bindMetadata.count) """) } @@ -503,7 +577,8 @@ struct StatementStateMachine { // we won't change the state return .wait - case .describeInfoReceived, + case .rowCountsReceived, + .describeInfoReceived, .streaming, .drain, .commandComplete, @@ -526,7 +601,8 @@ struct StatementStateMachine { // round trip. return .sendFlushOutBinds - case .describeInfoReceived, + case .rowCountsReceived, + .describeInfoReceived, .streaming, .drain, .commandComplete, @@ -545,6 +621,7 @@ struct StatementStateMachine { mutating func requestFetch() -> Action { switch self.state { case .initialized(let context), + .rowCountsReceived(let context, _), .describeInfoReceived(let context, _), .streaming(let context, _, _, _): return .sendFetch(context) @@ -580,7 +657,7 @@ struct StatementStateMachine { } } - case .drain, .describeInfoReceived: + case .drain, .rowCountsReceived, .describeInfoReceived: return .wait case .initialized: @@ -605,6 +682,7 @@ struct StatementStateMachine { mutating func channelReadComplete() -> Action { switch self.state { case .initialized, + .rowCountsReceived, .describeInfoReceived, .drain, .commandComplete, @@ -654,6 +732,7 @@ struct StatementStateMachine { .commandComplete, .drain, .error, + .rowCountsReceived, .describeInfoReceived: // we already have the complete stream received, now we are waiting // for a `readyForStatement` package. To receive this we need to read. @@ -669,6 +748,7 @@ struct StatementStateMachine { private mutating func setAndFireError(_ error: OracleSQLError) -> Action { switch self.state { case .initialized(let context), + .rowCountsReceived(let context, _), .describeInfoReceived(let context, _): if self.isCancelled { return .evaluateErrorAtConnectionLevel(error) @@ -714,6 +794,7 @@ struct StatementStateMachine { var isComplete: Bool { switch self.state { case .initialized, + .rowCountsReceived, .describeInfoReceived, .streaming, .drain: diff --git a/Sources/OracleNIO/Data/Bool+OracleCodable.swift b/Sources/OracleNIO/Data/Bool+OracleCodable.swift index 841a461..59d1d5f 100644 --- a/Sources/OracleNIO/Data/Bool+OracleCodable.swift +++ b/Sources/OracleNIO/Data/Bool+OracleCodable.swift @@ -15,7 +15,7 @@ import NIOCore extension Bool: OracleEncodable { - public var oracleType: OracleDataType { .boolean } + public static var defaultOracleType: OracleDataType { .boolean } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/ByteBuffer+OracleCodable.swift b/Sources/OracleNIO/Data/ByteBuffer+OracleCodable.swift index 2957f19..c4ef153 100644 --- a/Sources/OracleNIO/Data/ByteBuffer+OracleCodable.swift +++ b/Sources/OracleNIO/Data/ByteBuffer+OracleCodable.swift @@ -15,7 +15,7 @@ import NIOCore extension ByteBuffer: OracleEncodable { - public var oracleType: OracleDataType { .raw } + public static var defaultOracleType: OracleDataType { .raw } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/Cursor.swift b/Sources/OracleNIO/Data/Cursor.swift index 45a7d74..ac4d093 100644 --- a/Sources/OracleNIO/Data/Cursor.swift +++ b/Sources/OracleNIO/Data/Cursor.swift @@ -37,7 +37,7 @@ public struct Cursor { } extension Cursor: OracleEncodable { - public var oracleType: OracleDataType { .cursor } + public static var defaultOracleType: OracleDataType { .cursor } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/Data+OracleCodable.swift b/Sources/OracleNIO/Data/Data+OracleCodable.swift index a09e4f7..6921c0d 100644 --- a/Sources/OracleNIO/Data/Data+OracleCodable.swift +++ b/Sources/OracleNIO/Data/Data+OracleCodable.swift @@ -17,7 +17,7 @@ import NIOCore import struct Foundation.Data extension Data: OracleEncodable { - public var oracleType: OracleDataType { .raw } + public static var defaultOracleType: OracleDataType { .raw } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/Date+OracleCodable.swift b/Sources/OracleNIO/Data/Date+OracleCodable.swift index e403688..4c1a4f9 100644 --- a/Sources/OracleNIO/Data/Date+OracleCodable.swift +++ b/Sources/OracleNIO/Data/Date+OracleCodable.swift @@ -21,7 +21,7 @@ import struct Foundation.TimeZone import func Foundation.pow extension Date: OracleEncodable { - public var oracleType: OracleDataType { .timestampTZ } + public static var defaultOracleType: OracleDataType { .timestampTZ } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/Double+OracleCodable.swift b/Sources/OracleNIO/Data/Double+OracleCodable.swift index 5e71356..4d28d72 100644 --- a/Sources/OracleNIO/Data/Double+OracleCodable.swift +++ b/Sources/OracleNIO/Data/Double+OracleCodable.swift @@ -15,7 +15,7 @@ import NIOCore extension Double: OracleEncodable { - public var oracleType: OracleDataType { + public static var defaultOracleType: OracleDataType { .binaryDouble } diff --git a/Sources/OracleNIO/Data/Float+OracleCodable.swift b/Sources/OracleNIO/Data/Float+OracleCodable.swift index cd74f17..b1c8755 100644 --- a/Sources/OracleNIO/Data/Float+OracleCodable.swift +++ b/Sources/OracleNIO/Data/Float+OracleCodable.swift @@ -15,7 +15,7 @@ import NIOCore extension Float: OracleEncodable { - public var oracleType: OracleDataType { .binaryFloat } + public static var defaultOracleType: OracleDataType { .binaryFloat } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/Int+OracleCodable.swift b/Sources/OracleNIO/Data/Int+OracleCodable.swift index 1098312..b19a435 100644 --- a/Sources/OracleNIO/Data/Int+OracleCodable.swift +++ b/Sources/OracleNIO/Data/Int+OracleCodable.swift @@ -178,7 +178,7 @@ extension UInt: OracleDecodable { // MARK: Int extension Int: OracleEncodable { - public var oracleType: OracleDataType { .binaryInteger } + public static var defaultOracleType: OracleDataType { .binaryInteger } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/IntervalDS+OracleCodable.swift b/Sources/OracleNIO/Data/IntervalDS+OracleCodable.swift index 9affec8..67dac5b 100644 --- a/Sources/OracleNIO/Data/IntervalDS+OracleCodable.swift +++ b/Sources/OracleNIO/Data/IntervalDS+OracleCodable.swift @@ -76,7 +76,7 @@ extension IntervalDS: Decodable { } extension IntervalDS: OracleEncodable { - public var oracleType: OracleDataType { .intervalDS } + public static var defaultOracleType: OracleDataType { .intervalDS } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/JSON/OracleJSON+Encoding.swift b/Sources/OracleNIO/Data/JSON/OracleJSON+Encoding.swift index d50430e..7e7ff40 100644 --- a/Sources/OracleNIO/Data/JSON/OracleJSON+Encoding.swift +++ b/Sources/OracleNIO/Data/JSON/OracleJSON+Encoding.swift @@ -23,9 +23,7 @@ extension OracleJSON: Encodable where Value: Encodable { } extension OracleJSON: OracleThrowingDynamicTypeEncodable where Value: Encodable { - public var oracleType: OracleDataType { - .json - } + public static var defaultOracleType: OracleDataType { .json } public func _encodeRaw( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/LOB.swift b/Sources/OracleNIO/Data/LOB.swift index 1b521c8..ab44f52 100644 --- a/Sources/OracleNIO/Data/LOB.swift +++ b/Sources/OracleNIO/Data/LOB.swift @@ -87,6 +87,7 @@ public final class LOB: Sendable { let locator: NIOLockedValueBox<[UInt8]> private let hasMetadata: Bool + public static let defaultOracleType: OracleDataType = .blob public let oracleType: OracleDataType init( diff --git a/Sources/OracleNIO/Data/OracleNumber.swift b/Sources/OracleNIO/Data/OracleNumber.swift index c9def0d..1ab1aca 100644 --- a/Sources/OracleNIO/Data/OracleNumber.swift +++ b/Sources/OracleNIO/Data/OracleNumber.swift @@ -117,7 +117,7 @@ extension OracleNumber: OracleDecodable { } extension OracleNumber: OracleEncodable { - public var oracleType: OracleDataType { .number } + public static var defaultOracleType: OracleDataType { .number } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/OracleVector.swift b/Sources/OracleNIO/Data/OracleVector.swift index 85d17b9..d7e982c 100644 --- a/Sources/OracleNIO/Data/OracleVector.swift +++ b/Sources/OracleNIO/Data/OracleVector.swift @@ -365,6 +365,7 @@ where Index == Int { } extension _OracleVectorProtocol { + public static var defaultOracleType: OracleDataType { .vector } public var oracleType: OracleDataType { .vector } @inlinable public var count: Int { base.count } @inlinable public var startIndex: Index { 0 } diff --git a/Sources/OracleNIO/Data/RowID.swift b/Sources/OracleNIO/Data/RowID.swift index 505240b..624208c 100644 --- a/Sources/OracleNIO/Data/RowID.swift +++ b/Sources/OracleNIO/Data/RowID.swift @@ -113,7 +113,7 @@ extension RowID: OracleDecodable { } extension RowID: OracleEncodable { - public var oracleType: OracleDataType { .rowID } + public static var defaultOracleType: OracleDataType { .rowID } public func encode( into buffer: inout ByteBuffer, diff --git a/Sources/OracleNIO/Data/String+OracleCodable.swift b/Sources/OracleNIO/Data/String+OracleCodable.swift index 6c661b5..1c37cfd 100644 --- a/Sources/OracleNIO/Data/String+OracleCodable.swift +++ b/Sources/OracleNIO/Data/String+OracleCodable.swift @@ -31,9 +31,7 @@ extension String: OracleEncodable { ._encodeRaw(into: &buffer, context: context) } - public var oracleType: OracleDataType { - .varchar - } + public static var defaultOracleType: OracleDataType { .varchar } public var size: UInt32 { // empty strings have a length of 1 diff --git a/Sources/OracleNIO/Documentation.docc/Documentation.md b/Sources/OracleNIO/Documentation.docc/Documentation.md index cc3c515..21c7f8d 100644 --- a/Sources/OracleNIO/Documentation.docc/Documentation.md +++ b/Sources/OracleNIO/Documentation.docc/Documentation.md @@ -49,6 +49,7 @@ Oracle Database 12.1 or later. - ``StatementOptions`` - ``OraclePreparedStatement`` - ``Statement(_:)`` +- ``OracleBatchExecutionResult`` ### Encoding and Decoding @@ -79,3 +80,4 @@ Oracle Database 12.1 or later. - ``OracleSQLError`` - ``OracleDecodingError`` +- ``OracleBatchExecutionError`` diff --git a/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift b/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift index 829c908..9a0220b 100644 --- a/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift +++ b/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift @@ -441,7 +441,6 @@ struct OracleFrontendMessageEncoder { ) { self.clearIfNeeded() - let statement = statementContext.statement let statementOptions = statementContext.options // 1. options @@ -450,12 +449,12 @@ struct OracleFrontendMessageEncoder { var parametersCount: UInt32 = 0 var iterationsCount: UInt32 = 1 - if !statementContext.requiresDefine && statement.binds.count != 0 { - parametersCount = .init(statement.binds.count) + if !statementContext.requiresDefine && statementContext.binds.count != 0 { + parametersCount = .init(statementContext.binds.count) } if statementContext.requiresDefine { options |= Constants.TNS_EXEC_OPTION_DEFINE - } else if !statement.sql.isEmpty { + } else if !statementContext.sql.isEmpty { dmlOptions = Constants.TNS_EXEC_OPTION_IMPLICIT_RESULTSET options |= Constants.TNS_EXEC_OPTION_EXECUTE } @@ -484,7 +483,7 @@ struct OracleFrontendMessageEncoder { options |= Constants.TNS_EXEC_OPTION_BATCH_ERRORS } if statementContext.options.arrayDMLRowCounts { - options |= Constants.TNS_EXEC_OPTION_DML_ROWCOUNTS + dmlOptions |= Constants.TNS_EXEC_OPTION_DML_ROWCOUNTS } if statementOptions.autoCommit { options |= Constants.TNS_EXEC_OPTION_COMMIT @@ -550,7 +549,7 @@ struct OracleFrontendMessageEncoder { self.buffer.writeUB4(0) // al8regid_msb if statementOptions.arrayDMLRowCounts { self.buffer.writeInteger(UInt8(1)) // pointer (al8pidmlrc) - self.buffer.writeUB4(1) // al8pidmlrcbl / numberOfExecutions + self.buffer.writeUB4(statementContext.executionCount) // al8pidmlrcbl / numberOfExecutions self.buffer.writeInteger(UInt8(1)) // pointer (al8pidmlrcl) } else { self.buffer.writeInteger(UInt8(0)) // pointer (al8pidmlrc) @@ -573,8 +572,7 @@ struct OracleFrontendMessageEncoder { } } if statementContext.cursorID == 0 || statementContext.type.isDDL { - statementContext.statement.sql - ._encodeRaw(into: &self.buffer, context: .default) + statementContext.sql._encodeRaw(into: &self.buffer, context: .default) self.buffer.writeUB4(1) // al8i4[0] parse } else { self.buffer.writeUB4(0) // al8i4[0] parse @@ -586,7 +584,7 @@ struct OracleFrontendMessageEncoder { self.buffer.writeUB4(iterationsCount) } } else { - self.buffer.writeUB4(1) // al8i4[1] execution count + self.buffer.writeUB4(statementContext.executionCount) // al8i4[1] execution count } self.buffer.writeUB4(0) // al8i4[2] self.buffer.writeUB4(0) // al8i4[3] @@ -608,7 +606,7 @@ struct OracleFrontendMessageEncoder { self.writeColumnMetadata(columns) } else if parametersCount > 0 { - self.writeBindParameters(statementContext.statement.binds) + self.writeBindParameters(statementContext.binds) } self.endRequest() @@ -629,7 +627,7 @@ struct OracleFrontendMessageEncoder { } else { functionCode = .reexecute } - let parameters = statementContext.statement.binds + let parameters = statementContext.binds var executionFlags1: UInt32 = 0 var executionFlags2: UInt32 = 0 var numberOfIterations: UInt32 = 0 @@ -654,8 +652,18 @@ struct OracleFrontendMessageEncoder { self.buffer.writeUB4(executionFlags1) self.buffer.writeUB4(executionFlags2) if !parameters.metadata.isEmpty { - self.buffer.writeOracleMessageID(.rowData) - self.writeBindParameterRow(bindings: parameters) + switch parameters { + case .one(let binds): + self.buffer.writeOracleMessageID(.rowData) + self.writeBindParameterRow(bytes: binds.bytes, longBytes: binds.longBytes) + case .many(let collection): + for binds in collection.bindings { + self.buffer.writeOracleMessageID(.rowData) + self.writeBindParameterRow(bytes: binds.0, longBytes: binds.long) + } + case .none: + preconditionFailure("How can no binds have metadata?") + } } self.endRequest() } @@ -1147,16 +1155,23 @@ extension OracleFrontendMessageEncoder { } } - private mutating func writeBindParameters(_ binds: OracleBindings) { + private mutating func writeBindParameters(_ binds: StatementContext.Binds) { self.writeColumnMetadata(binds.metadata) // write parameter values unless statement contains only return binds - if !binds.metadata.isEmpty - && (binds.bytes.readableBytes > 0 - || binds.longBytes.readableBytes > 0) - { - self.buffer.writeOracleMessageID(.rowData) - self.writeBindParameterRow(bindings: binds) + if binds.hasData { + switch binds { + case .one(let binds): + self.buffer.writeOracleMessageID(.rowData) + self.writeBindParameterRow(bytes: binds.bytes, longBytes: binds.longBytes) + case .many(let collection): + for binds in collection.bindings { + self.buffer.writeOracleMessageID(.rowData) + self.writeBindParameterRow(bytes: binds.0, longBytes: binds.long) + } + case .none: + preconditionFailure("How can no binds have data?") + } } } @@ -1221,10 +1236,10 @@ extension OracleFrontendMessageEncoder { } } - private mutating func writeBindParameterRow(bindings: OracleBindings) { - self.buffer.writeImmutableBuffer(bindings.bytes) - if bindings.longBytes.readableBytes > 0 { - self.buffer.writeImmutableBuffer(bindings.longBytes) + private mutating func writeBindParameterRow(bytes: ByteBuffer, longBytes: ByteBuffer) { + self.buffer.writeImmutableBuffer(bytes) + if longBytes.readableBytes > 0 { + self.buffer.writeImmutableBuffer(longBytes) } } diff --git a/Sources/OracleNIO/Messages/OracleBackendMessage+RowData.swift b/Sources/OracleNIO/Messages/OracleBackendMessage+RowData.swift index 27902ef..04f89ed 100644 --- a/Sources/OracleNIO/Messages/OracleBackendMessage+RowData.swift +++ b/Sources/OracleNIO/Messages/OracleBackendMessage+RowData.swift @@ -261,7 +261,7 @@ extension OracleBackendMessage { statementContext: StatementContext, capabilities: Capabilities ) throws -> [ColumnStorage] { - let outBinds = statementContext.statement.binds.metadata.compactMap(\.outContainer) + let outBinds = statementContext.binds.metadata.compactMap(\.outContainer) guard !outBinds.isEmpty else { preconditionFailure() } var columns: [ColumnStorage] = [] if statementContext.isReturning { diff --git a/Sources/OracleNIO/OracleChannelHandler.swift b/Sources/OracleNIO/OracleChannelHandler.swift index 25ed882..265b1c6 100644 --- a/Sources/OracleNIO/OracleChannelHandler.swift +++ b/Sources/OracleNIO/OracleChannelHandler.swift @@ -400,7 +400,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { case .forwardRows(let rows): self.rowStream!.receive(rows) - case .forwardStreamComplete(let buffer, let cursorID): + case .forwardStreamComplete(let buffer, let cursorID, let affectedRows): guard let rowStream else { // if the stream was cancelled we don't have it here anymore. return @@ -409,7 +409,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { if buffer.count > 0 { rowStream.receive(buffer) } - rowStream.receive(completion: .success(())) + rowStream.receive(completion: .success(affectedRows)) if cursorID != 0 { self.cleanupContext.cursorsToClose.insert(cursorID) @@ -701,16 +701,22 @@ final class OracleChannelHandler: ChannelDuplexHandler { rows = OracleRowStream( source: .stream(describeInfo, self), eventLoop: context.channel.eventLoop, - logger: result.logger + logger: result.logger, + affectedRows: nil, + rowCounts: result.rowCounts, + batchErrors: result.batchErrors ) self.rowStream = rows promise.succeed(rows) - case .noRows: + case .noRows(let affectedRows): rows = OracleRowStream( source: .noRows(.success(())), eventLoop: context.channel.eventLoop, - logger: result.logger + logger: result.logger, + affectedRows: affectedRows, + rowCounts: result.rowCounts, + batchErrors: result.batchErrors ) promise.succeed(rows) self.run(self.state.readyForStatementReceived(), with: context) diff --git a/Sources/OracleNIO/OracleCodable.swift b/Sources/OracleNIO/OracleCodable.swift index 54496ed..d2b00a4 100644 --- a/Sources/OracleNIO/OracleCodable.swift +++ b/Sources/OracleNIO/OracleCodable.swift @@ -23,7 +23,14 @@ import class Foundation.JSONEncoder /// For example, custom types created at runtime, such as enums, or extension types whose OID is not /// stable between databases. public protocol OracleThrowingDynamicTypeEncodable: Sendable { + /// Identifies the default data type that we will encode into `ByteBuffer` in `encode`. + /// + /// It is used to encode `NULL` values to the correct format. + static var defaultOracleType: OracleDataType { get } + /// Identifies the data type that we will encode into `ByteBuffer` in `encode`. + /// + /// A default implementation is provided. var oracleType: OracleDataType { get } /// Identifies the byte size indicator which will be sent to Oracle. @@ -99,6 +106,8 @@ public protocol OracleDynamicTypeEncodable: OracleThrowingDynamicTypeEncodable { public protocol OracleThrowingEncodable: OracleThrowingDynamicTypeEncodable {} extension OracleThrowingDynamicTypeEncodable { + public var oracleType: OracleDataType { Self.defaultOracleType } + public var size: UInt32 { UInt32(self.oracleType.defaultSize) } public static var isArray: Bool { false } @@ -108,6 +117,8 @@ extension OracleThrowingDynamicTypeEncodable { // swift-format-ignore: DontRepeatTypeInStaticProperties extension Array where Element: OracleThrowingDynamicTypeEncodable { + public var oracleType: OracleDataType { Element.defaultOracleType } + public static var isArray: Bool { true } public var arrayCount: Int? { self.count } public var arraySize: Int? { self.capacity } diff --git a/Sources/OracleNIO/OracleRowSequence.swift b/Sources/OracleNIO/OracleRowSequence.swift index 8fc0b13..994ddf2 100644 --- a/Sources/OracleNIO/OracleRowSequence.swift +++ b/Sources/OracleNIO/OracleRowSequence.swift @@ -28,15 +28,18 @@ public struct OracleRowSequence: AsyncSequence, Sendable { let backing: BackingSequence let lookupTable: [String: Int] let columns: [OracleColumn] + let listeners: OracleRowStream.MetadataListeners init( _ backing: BackingSequence, lookupTable: [String: Int], - columns: [OracleColumn] + columns: [OracleColumn], + listeners: OracleRowStream.MetadataListeners ) { self.backing = backing self.lookupTable = lookupTable self.columns = columns + self.listeners = listeners } public func makeAsyncIterator() -> AsyncIterator { @@ -46,6 +49,25 @@ public struct OracleRowSequence: AsyncSequence, Sendable { columns: self.columns ) } + + /// Receive the total number of rows affected by the operation. + /// + /// The metric is only available after the query has completed, e.g. after all rows are retrieved from the server. + public var affectedRows: Int { + get async throws { + try await withCheckedThrowingContinuation { continuation in + listeners.addAffectedRowsListener(continuation) + } + } + } + + internal var rowCounts: [Int] { + listeners.rowCounts ?? [] + } + + internal var batchErrors: [OracleSQLError.BatchError] { + listeners.batchErrors ?? [] + } } extension OracleRowSequence { diff --git a/Sources/OracleNIO/OracleRowStream.swift b/Sources/OracleNIO/OracleRowStream.swift index ab6fc0e..5f40d14 100644 --- a/Sources/OracleNIO/OracleRowStream.swift +++ b/Sources/OracleNIO/OracleRowStream.swift @@ -13,16 +13,19 @@ //===----------------------------------------------------------------------===// import Logging +import NIOConcurrencyHelpers import NIOCore struct StatementResult { enum Value: Equatable { - case noRows + case noRows(affectedRows: Int) case describeInfo([OracleColumn]) } var value: Value var logger: Logger + var batchErrors: Optional<[OracleSQLError.BatchError]> + var rowCounts: Optional<[Int]> } final class OracleRowStream: @unchecked Sendable { @@ -59,14 +62,82 @@ final class OracleRowStream: @unchecked Sendable { case asyncSequence(AsyncSequenceSource, OracleRowsDataSource) } + final class MetadataListeners { + private let lock = NIOLock() + #if swift(>=5.10) + /// This property must only be accessed when ``lock`` is aquired. + private nonisolated(unsafe) var affectedRowsListeners: [CheckedContinuation] = [] + /// This property must only be accessed when ``lock`` is aquired. + private nonisolated(unsafe) var affectedRows: Int? + /// This property must only be accessed when ``lock`` is aquired. + private nonisolated(unsafe) var error: (any Error)? + #else + /// This property must only be accessed when ``lock`` is aquired. + private var affectedRowsListeners: [CheckedContinuation] = [] + /// This property must only be accessed when ``lock`` is aquired. + private var affectedRows: Int? + /// This property must only be accessed when ``lock`` is aquired. + private var error: (any Error)? + #endif + + let rowCounts: [Int]? + let batchErrors: [OracleSQLError.BatchError]? + + init(affectedRows: Int? = nil, rowCounts: [Int]?, batchErrors: [OracleSQLError.BatchError]?) { + self.affectedRows = affectedRows + self.rowCounts = rowCounts + self.batchErrors = batchErrors + } + + func addAffectedRowsListener(_ listener: CheckedContinuation) { + lock.withLock { + if let affectedRows { + listener.resume(returning: affectedRows) + } else if let error { + listener.resume(throwing: error) + } else { + affectedRowsListeners.append(listener) + } + } + } + + func receiveAffectedRows(_ affectedRows: Int) { + let listeners = lock.withLock { + self.affectedRows = affectedRows + let listeners = self.affectedRowsListeners + self.affectedRowsListeners.removeAll() + return listeners + } + for listener in listeners { + listener.resume(returning: affectedRows) + } + } + + func receiveError(_ error: any Error) { + let affectedRowsListeners = lock.withLock { + self.error = error + let listeners = self.affectedRowsListeners + self.affectedRowsListeners.removeAll() + return listeners + } + for listener in affectedRowsListeners { + listener.resume(throwing: error) + } + } + } + private let rowDescription: [OracleColumn] private let lookupTable: [String: Int] + private let listeners: MetadataListeners private var downstreamState: DownstreamState init( source: Source, eventLoop: EventLoop, - logger: Logger + logger: Logger, + affectedRows: Int?, + rowCounts: [Int]?, + batchErrors: [OracleSQLError.BatchError]? ) { let bufferState: BufferState switch source { @@ -92,6 +163,8 @@ final class OracleRowStream: @unchecked Sendable { lookup[column.name] = index } self.lookupTable = lookup + + self.listeners = MetadataListeners(affectedRows: affectedRows, rowCounts: rowCounts, batchErrors: batchErrors) } // MARK: Async Sequence @@ -132,7 +205,8 @@ final class OracleRowStream: @unchecked Sendable { return OracleRowSequence( producer.sequence, lookupTable: self.lookupTable, - columns: self.rowDescription + columns: self.rowDescription, + listeners: self.listeners ) } @@ -377,14 +451,16 @@ final class OracleRowStream: @unchecked Sendable { } } - internal func receive(completion result: Result) { + internal func receive(completion result: Result) { self.eventLoop.preconditionInEventLoop() switch result { - case .success: + case .success(let affectedRows): self.receiveEnd() + self.listeners.receiveAffectedRows(affectedRows) case .failure(let error): self.receiveError(error) + self.listeners.receiveError(error) } } @@ -476,3 +552,9 @@ protocol OracleRowsDataSource { func request(for stream: OracleRowStream) func cancel(for stream: OracleRowStream) } + +#if swift(>=5.10) + extension OracleRowStream.MetadataListeners: Sendable {} +#else + extension OracleRowStream.MetadataListeners: @unchecked Sendable {} +#endif diff --git a/Sources/OracleNIO/OracleSQLError.swift b/Sources/OracleNIO/OracleSQLError.swift index 5f3d580..593619d 100644 --- a/Sources/OracleNIO/OracleSQLError.swift +++ b/Sources/OracleNIO/OracleSQLError.swift @@ -36,6 +36,7 @@ public struct OracleSQLError: Sendable, Error { case sidNotSupported case missingParameter case unsupportedDataType + case missingStatement } internal var base: Base @@ -63,6 +64,7 @@ public struct OracleSQLError: Sendable, Error { public static let sidNotSupported = Self(.sidNotSupported) public static let missingParameter = Self(.missingParameter) public static let unsupportedDataType = Self(.unsupportedDataType) + public static let missingStatement = Self(.missingStatement) public var description: String { switch self.base { @@ -96,6 +98,8 @@ public struct OracleSQLError: Sendable, Error { return "missingParameter" case .unsupportedDataType: return "unsupportedDataType" + case .missingStatement: + return "missingStatement" } } } @@ -224,11 +228,39 @@ public struct OracleSQLError: Sendable, Error { self.underlying.message } + /// The amount of rows affected by the operation. + /// + /// In most cases, this is `0`, although it is posslbe that a statement + /// (e.g. ``OracleConnection/executeBatch(_:binds:encodingContext:options:logger:file:line:)`` + /// executes some if its statements successfully, while others might have failed. In this case, `affectedRows` shows + /// how many operations have been succesful. + /// + /// + /// Defaults to `0`. + public var affectedRows: Int { + Int(self.underlying.rowCount ?? 0) + } + init(_ underlying: OracleBackendMessage.BackendError) { self.underlying = underlying } } + public struct BatchError: Sendable { + /// The index of the statement in which the error occurred. + public let statementIndex: Int + /// The error number/identifier. + public let number: Int + /// The error message, typically prefixed with `ORA-` & ``number``. + public let message: String + + init(_ error: OracleError) { + self.statementIndex = error.offset + self.number = error.code + self.message = error.message ?? "" + } + } + // MARK: - Internal convenience factory methods - static func unexpectedBackendMessage( @@ -310,6 +342,7 @@ public struct OracleSQLError: Sendable, Error { static let unsupportedDataType = OracleSQLError(code: .unsupportedDataType) + static let missingStatement = OracleSQLError(code: .missingStatement) } extension OracleSQLError: CustomStringConvertible { diff --git a/Sources/OracleNIO/OracleTask.swift b/Sources/OracleNIO/OracleTask.swift index 3bf8983..c3c251b 100644 --- a/Sources/OracleNIO/OracleTask.swift +++ b/Sources/OracleNIO/OracleTask.swift @@ -119,8 +119,52 @@ final class StatementContext { } } + enum Binds { + case none + /// Single statement. + case one(OracleBindings) + /// Bulk statement, e.g. multiple rows to insert. + /// + /// Used in ``OracleConnection/execute(_:binds:encodingContext:options:logger:)``. + case many(OracleBindingsCollection) + + var count: Int { + switch self { + case .none: + return 0 + case .one(let binds): + return binds.count + case .many(let collection): + return collection.metadata.count + } + } + + var metadata: [OracleBindings.Metadata] { + switch self { + case .none: + return [] + case .one(let binds): + return binds.metadata + case .many(let collection): + return collection.metadata + } + } + + var hasData: Bool { + switch self { + case .none: + return false + case .one(let binds): + return !binds.metadata.isEmpty && (binds.bytes.readableBytes > 0 || binds.longBytes.readableBytes > 0) + case .many(let collection): + return collection.hasData + } + } + } + let type: StatementType - let statement: OracleStatement + let sql: String + let binds: Binds let options: StatementOptions let logger: Logger @@ -131,6 +175,7 @@ final class StatementContext { var requiresDefine: Bool = false var noPrefetch: Bool = false let isReturning: Bool + let executionCount: UInt32 var sequenceNumber: UInt8 = 2 @@ -141,9 +186,11 @@ final class StatementContext { promise: EventLoopPromise ) { self.logger = logger - self.statement = statement + self.sql = statement.sql + self.binds = .one(statement.binds) self.options = options self.sqlLength = .init(statement.sql.data(using: .utf8)?.count ?? 0) + self.executionCount = 1 // strip single/multiline comments and and strings from the sql var sql = statement.sql @@ -151,12 +198,34 @@ final class StatementContext { sql = sql.replacing(/\--.*(\n|$)/, with: "") sql = sql.replacing(/'[^']*'(?=(?:[^']*[^']*')*[^']*$)/, with: "") + self.isReturning = statement.binds.metadata.first(where: \.isReturnBind) != nil + let type = Self.determineStatementType(minifiedSQL: sql, promise: promise) + self.type = type + } + + init( + statement: String, + bindCollection: OracleBindingsCollection, + options: StatementOptions, + logger: Logger, + promise: EventLoopPromise + ) { + self.logger = logger + self.sql = statement + self.binds = .many(bindCollection) + self.options = options + self.sqlLength = UInt32(statement.utf8.count) + self.executionCount = UInt32(bindCollection.bindings.count) + + // strip single/multiline comments and and strings from the sql + var sql = statement + sql = sql.replacing(/\/\*[\S\n ]+?\*\//, with: "") + sql = sql.replacing(/\--.*(\n|$)/, with: "") + sql = sql.replacing(/'[^']*'(?=(?:[^']*[^']*')*[^']*$)/, with: "") + self.isReturning = - statement.binds.metadata - .first(where: \.isReturnBind) != nil - let type = Self.determineStatementType( - minifiedSQL: sql, promise: promise - ) + bindCollection.metadata.first(where: \.isReturnBind) != nil + let type = Self.determineStatementType(minifiedSQL: sql, promise: promise) self.type = type } @@ -167,11 +236,13 @@ final class StatementContext { promise: EventLoopPromise ) { self.logger = logger - self.statement = "" + self.sql = "" + self.binds = .none self.sqlLength = 0 self.cursorID = cursor.id self.options = options self.isReturning = false + self.executionCount = 1 self.type = .cursor(cursor, promise) } @@ -209,8 +280,15 @@ public struct StatementOptions { /// This happens on the Oracle server side. So it won't cause additional roundtrips to the database. public var autoCommit: Bool = false - internal var arrayDMLRowCounts: Bool = false - internal var batchErrors: Bool = false + public var arrayDMLRowCounts: Bool = false + + /// Indicates how errors will be handled in batch executions. + /// + /// If false, batch executions will discard all remaining data sets after an error occurred. + /// + /// If true, all data sets will be executed. Data sets with errors are skipped and the corresponding errors are + /// returned after the full operation is finished. + public var batchErrors: Bool = false /// Indicates how many rows will be returned with the initial roundtrip. /// @@ -257,6 +335,8 @@ public struct StatementOptions { /// - Parameters: /// - autoCommit: Automatically commit after execution of the statement without needing an /// additional roundtrip. + /// - batchErrors: Indicates how errors are handled in batch executions. Refer to + /// ``batchErrors`` for additional explanation. /// - prefetchRows: Indicates how many rows should be fetched with the initial response from /// the database. Refer to ``prefetchRows`` for additional explanation. /// - arraySize: Indicates how many rows will be returned by any subsequent fetch calls to the @@ -265,6 +345,7 @@ public struct StatementOptions { /// requires another round-trip to the server. public init( autoCommit: Bool = false, + batchErrors: Bool = false, prefetchRows: Int = 2, arraySize: Int = 50, fetchLOBs: Bool = false diff --git a/Sources/OracleNIO/OracleStatement.swift b/Sources/OracleNIO/Statements/OracleBindings.swift similarity index 71% rename from Sources/OracleNIO/OracleStatement.swift rename to Sources/OracleNIO/Statements/OracleBindings.swift index 8fec279..55594b8 100644 --- a/Sources/OracleNIO/OracleStatement.swift +++ b/Sources/OracleNIO/Statements/OracleBindings.swift @@ -15,153 +15,90 @@ import NIOConcurrencyHelpers import NIOCore -/// A Oracle SQL statement, that can be executed on a Oracle server. -/// Contains the raw sql string and bindings. -public struct OracleStatement: Sendable, Hashable { - /// The statement's string. - public var sql: String - /// The statement's binds. - public var binds: OracleBindings - - public init( - unsafeSQL sql: String, - binds: OracleBindings = OracleBindings() - ) { - self.sql = sql - self.binds = binds - } -} - -extension OracleStatement: ExpressibleByStringInterpolation { - public init(stringInterpolation: StringInterpolation) { - self.sql = stringInterpolation.sql - self.binds = stringInterpolation.binds - } - - public init(stringLiteral value: StringLiteralType) { - self.sql = value - self.binds = OracleBindings() +struct OracleBindingsCollection { + /// Metadata is shared by all bind rows. + var metadata: [OracleBindings.Metadata] = [] + var bindings: [(ByteBuffer, long: ByteBuffer)] = [] + var hasData = false + + mutating func appendRow( + _ row: repeat (each Bind)?, + context: OracleEncodingContext + ) throws { + var index = 0 + var bindings: (ByteBuffer, long: ByteBuffer) = (ByteBuffer(), ByteBuffer()) + repeat try appendBind(each row, context: context, into: &bindings, index: &index) + if !hasData { hasData = bindings.0.readableBytes > 0 || bindings.long.readableBytes > 0 } + self.bindings.append(bindings) } -} - -extension OracleStatement { - public struct StringInterpolation: StringInterpolationProtocol { - public typealias StringLiteralType = String - - @usableFromInline - var sql: String - @usableFromInline - var binds: OracleBindings - - public init(literalCapacity: Int, interpolationCount: Int) { - self.sql = "" - self.binds = OracleBindings(capacity: interpolationCount) - } - - public mutating func appendLiteral(_ literal: String) { - self.sql.append(contentsOf: literal) - } - - @inlinable - public mutating func appendInterpolation( - _ value: some OracleThrowingDynamicTypeEncodable, - context: OracleEncodingContext = .default - ) throws { - let bindName = "\(self.binds.count)" - try self.binds.append(value, context: context, bindName: bindName) - self.sql.append(contentsOf: ":\(bindName)") - } - - @inlinable - public mutating func appendInterpolation( - _ value: (some OracleThrowingDynamicTypeEncodable)?, - context: OracleEncodingContext = .default - ) throws { - let bindName = "\(self.binds.count)" - switch value { - case .none: - self.binds.appendNull(value?.oracleType, bindName: bindName) - case .some(let value): - try self.binds - .append(value, context: context, bindName: bindName) - } - - self.sql.append(contentsOf: ":\(bindName)") - } - - @inlinable - public mutating func appendInterpolation( - _ value: some OracleDynamicTypeEncodable, - context: OracleEncodingContext = .default - ) { - let bindName = "\(self.binds.count)" - self.binds.append(value, context: context, bindName: bindName) - self.sql.append(contentsOf: ":\(bindName)") - } - @inlinable - public mutating func appendInterpolation( - _ value: (some OracleDynamicTypeEncodable)?, - context: OracleEncodingContext = .default - ) { - let bindName = "\(self.binds.count)" - switch value { - case .none: - self.binds.appendNull(value?.oracleType, bindName: bindName) - case .some(let value): - self.binds.append(value, context: context, bindName: bindName) + mutating func appendRow(_ row: OracleBindings) throws { + for (index, column) in row.metadata.enumerated() { + if metadata.count <= index { + metadata.append(column) + } else { + let currentMetadata = metadata[index] + if column.size > currentMetadata.size || column.bufferSize > currentMetadata.bufferSize { + metadata[index] = column + } } - - self.sql.append(contentsOf: ":\(bindName)") } + if !hasData { hasData = row.bytes.readableBytes > 0 || row.longBytes.readableBytes > 0 } + self.bindings.append((row.bytes, row.longBytes)) + } - public mutating func appendInterpolation(_ value: some OracleRef) { - if let bindName = self.binds.contains(ref: value) { - self.sql.append(contentsOf: ":\(bindName)") + private mutating func appendBind( + _ bind: T?, + context: OracleEncodingContext, + into buffers: inout (ByteBuffer, long: ByteBuffer), + index: inout Int + ) throws { + let newMetadata = + if let bind { + OracleBindings.Metadata( + value: bind, + protected: true, + isReturnBind: false, + bindName: "\(index)" + ) } else { - let bindName = "\(self.binds.count)" - self.binds.append(value, bindName: bindName) - self.sql.append(contentsOf: ":\(bindName)") + OracleBindings.Metadata( + dataType: T.defaultOracleType, + protected: false, + isReturnBind: false, + size: 1, + isArray: false, + arrayCount: nil, + maxArraySize: nil, + bindName: "\(index)" + ) } + if let bind, newMetadata.bufferSize >= Constants.TNS_MIN_LONG_LENGTH { + try bind._encodeRaw(into: &buffers.long, context: context) + } else if let bind { + try bind._encodeRaw(into: &buffers.0, context: context) + } else if T.defaultOracleType == .boolean { + buffers.0.writeInteger(Constants.TNS_ESCAPE_CHAR) + buffers.0.writeInteger(UInt8(1)) + } else if T.defaultOracleType._oracleType == .intNamed { + buffers.0.writeUB4(0) // TOID + buffers.0.writeUB4(0) // OID + buffers.0.writeUB4(0) // snapshot + buffers.0.writeUB4(0) // version + buffers.0.writeUB4(0) // packed data length + buffers.0.writeUB4(Constants.TNS_OBJ_TOP_LEVEL) // flags + } else { + buffers.0.writeInteger(UInt8(0)) } - - /// Adds a list of values as individual binds. - /// - /// ```swift - /// let values = [15, 24, 33] - /// let statement: OracleStatement = "SELECT id FROM my_table WHERE id IN (\(list: values))" - /// print(statement.sql) - /// // SELECT id FROM my_table WHERE id IN (:1, :2, :3) - /// ``` - @inlinable - public mutating func appendInterpolation( - list: [some OracleDynamicTypeEncodable], - context: OracleEncodingContext = .default - ) { - guard !list.isEmpty else { return } - for value in list { - self.appendInterpolation(value, context: context) - self.sql.append(", ") + if metadata.count <= index { + metadata.append(newMetadata) + } else { + let currentMetadata = metadata[index] + if newMetadata.size > currentMetadata.size || newMetadata.bufferSize > currentMetadata.bufferSize { + metadata[index] = newMetadata } - self.sql.removeLast(2) } - - @inlinable - public mutating func appendInterpolation(unescaped interpolation: String) { - self.sql.append(contentsOf: interpolation) - } - } -} - -extension OracleStatement: CustomStringConvertible { - public var description: String { - "\(self.sql) \(self.binds)" - } -} - -extension OracleStatement: CustomDebugStringConvertible { - public var debugDescription: String { - "OracleStatement(sql: \(String(describing: self.sql)), binds: \(String(reflecting: self.binds))" + index += 1 } } diff --git a/Sources/OracleNIO/OraclePreparedStatement.swift b/Sources/OracleNIO/Statements/OraclePreparedStatement.swift similarity index 100% rename from Sources/OracleNIO/OraclePreparedStatement.swift rename to Sources/OracleNIO/Statements/OraclePreparedStatement.swift diff --git a/Sources/OracleNIO/Statements/OracleStatement.swift b/Sources/OracleNIO/Statements/OracleStatement.swift new file mode 100644 index 0000000..b061404 --- /dev/null +++ b/Sources/OracleNIO/Statements/OracleStatement.swift @@ -0,0 +1,166 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the OracleNIO open source project +// +// Copyright (c) 2024 Timo Zacherl and the OracleNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE for license information +// See CONTRIBUTORS.md for the list of OracleNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOConcurrencyHelpers +import NIOCore + +/// A Oracle SQL statement, that can be executed on a Oracle server. +/// Contains the raw sql string and bindings. +public struct OracleStatement: Sendable, Hashable { + /// The statement's string. + public var sql: String + /// The statement's binds. + public var binds: OracleBindings + + public init( + unsafeSQL sql: String, + binds: OracleBindings = OracleBindings() + ) { + self.sql = sql + self.binds = binds + } +} + +extension OracleStatement: ExpressibleByStringInterpolation { + public init(stringInterpolation: StringInterpolation) { + self.sql = stringInterpolation.sql + self.binds = stringInterpolation.binds + } + + public init(stringLiteral value: StringLiteralType) { + self.sql = value + self.binds = OracleBindings() + } +} + +extension OracleStatement { + public struct StringInterpolation: StringInterpolationProtocol { + public typealias StringLiteralType = String + + @usableFromInline + var sql: String + @usableFromInline + var binds: OracleBindings + + public init(literalCapacity: Int, interpolationCount: Int) { + self.sql = "" + self.binds = OracleBindings(capacity: interpolationCount) + } + + public mutating func appendLiteral(_ literal: String) { + self.sql.append(contentsOf: literal) + } + + @inlinable + public mutating func appendInterpolation( + _ value: some OracleThrowingDynamicTypeEncodable, + context: OracleEncodingContext = .default + ) throws { + let bindName = "\(self.binds.count)" + try self.binds.append(value, context: context, bindName: bindName) + self.sql.append(contentsOf: ":\(bindName)") + } + + @inlinable + public mutating func appendInterpolation( + _ value: T?, + context: OracleEncodingContext = .default + ) throws { + let bindName = "\(self.binds.count)" + switch value { + case .none: + self.binds.appendNull(T.defaultOracleType, bindName: bindName) + case .some(let value): + try self.binds + .append(value, context: context, bindName: bindName) + } + + self.sql.append(contentsOf: ":\(bindName)") + } + + @inlinable + public mutating func appendInterpolation( + _ value: some OracleDynamicTypeEncodable, + context: OracleEncodingContext = .default + ) { + let bindName = "\(self.binds.count)" + self.binds.append(value, context: context, bindName: bindName) + self.sql.append(contentsOf: ":\(bindName)") + } + + @inlinable + public mutating func appendInterpolation( + _ value: (some OracleDynamicTypeEncodable)?, + context: OracleEncodingContext = .default + ) { + let bindName = "\(self.binds.count)" + switch value { + case .none: + self.binds.appendNull(value?.oracleType, bindName: bindName) + case .some(let value): + self.binds.append(value, context: context, bindName: bindName) + } + + self.sql.append(contentsOf: ":\(bindName)") + } + + public mutating func appendInterpolation(_ value: some OracleRef) { + if let bindName = self.binds.contains(ref: value) { + self.sql.append(contentsOf: ":\(bindName)") + } else { + let bindName = "\(self.binds.count)" + self.binds.append(value, bindName: bindName) + self.sql.append(contentsOf: ":\(bindName)") + } + } + + /// Adds a list of values as individual binds. + /// + /// ```swift + /// let values = [15, 24, 33] + /// let statement: OracleStatement = "SELECT id FROM my_table WHERE id IN (\(list: values))" + /// print(statement.sql) + /// // SELECT id FROM my_table WHERE id IN (:1, :2, :3) + /// ``` + @inlinable + public mutating func appendInterpolation( + list: [some OracleDynamicTypeEncodable], + context: OracleEncodingContext = .default + ) { + guard !list.isEmpty else { return } + for value in list { + self.appendInterpolation(value, context: context) + self.sql.append(", ") + } + self.sql.removeLast(2) + } + + @inlinable + public mutating func appendInterpolation(unescaped interpolation: String) { + self.sql.append(contentsOf: interpolation) + } + } +} + +extension OracleStatement: CustomStringConvertible { + public var description: String { + "\(self.sql) \(self.binds)" + } +} + +extension OracleStatement: CustomDebugStringConvertible { + public var debugDescription: String { + "OracleStatement(sql: \(String(describing: self.sql)), binds: \(String(reflecting: self.binds))" + } +} diff --git a/Tests/IntegrationTests/BatchExecutionTests.swift b/Tests/IntegrationTests/BatchExecutionTests.swift new file mode 100644 index 0000000..6015bb6 --- /dev/null +++ b/Tests/IntegrationTests/BatchExecutionTests.swift @@ -0,0 +1,268 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the OracleNIO open source project +// +// Copyright (c) 2024 Timo Zacherl and the OracleNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE for license information +// See CONTRIBUTORS.md for the list of OracleNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if compiler(>=6.0) + import OracleNIO + import Testing + + @Suite + final class BatchExecutionTests { + private let client: OracleClient + private var running: Task! + + init() throws { + self.client = try OracleClient(configuration: OracleConnection.testConfig(), backgroundLogger: .oracleTest) + self.running = Task { await client.run() } + } + + deinit { + running.cancel() + } + + @Test + func simpleBatchExecution() async throws { + try await client.withConnection { connection in + do { + try await connection.execute("DROP TABLE users_simple_batch_exec", logger: .oracleTest) + } catch let error as OracleSQLError { + // "ORA-00942: table or view does not exist" can be ignored + #expect(error.serverInfo?.number == 942) + } + try await connection.execute( + "CREATE TABLE users_simple_batch_exec (id NUMBER, name VARCHAR2(50 byte), age NUMBER)") + let binds: [(Int, String, Int?)] = [ + (1, "John", nil), + (2, "Jane", 30), + (3, "Jack", 40), + (4, "Jill", 50), + (5, "Pete", nil), + ] + let batchResult = try await connection.executeBatch( + "INSERT INTO users_simple_batch_exec (id, name, age) VALUES (:1, :2, :3)", binds: binds) + #expect(batchResult.affectedRows == binds.count) + let stream = try await connection.execute( + "SELECT id, name, age FROM users_simple_batch_exec ORDER BY id ASC") + var index: Int = 0 + for try await (id, name, age) in stream.decode((Int, String, Int?).self) { + guard index < binds.count else { + Issue.record("Too many rows") + return + } + let expected = binds[index] + #expect(id == expected.0) + #expect(name == expected.1) + #expect(age == expected.2) + index += 1 + } + #expect(index == binds.count) + } + } + + @Test + func preparedStatementBatchExecution() async throws { + struct InsertUserStatement: OraclePreparedStatement { + static let sql: String = + "INSERT INTO users_prepared_statement_batch_exec (id, name, age) VALUES (:1, :2, :3)" + + typealias Row = Void + + var id: Int + var name: String + var age: Int + + func makeBindings() throws -> OracleBindings { + var bindings = OracleBindings(capacity: 3) + bindings.append(id) + bindings.append(name) + bindings.append(age) + return bindings + } + + func decodeRow(_ row: OracleRow) throws -> Row {} + } + + try await client.withConnection { connection in + do { + try await connection.execute("DROP TABLE users_prepared_statement_batch_exec", logger: .oracleTest) + } catch let error as OracleSQLError { + // "ORA-00942: table or view does not exist" can be ignored + #expect(error.serverInfo?.number == 942) + } + try await connection.execute( + "CREATE TABLE users_prepared_statement_batch_exec (id NUMBER, name VARCHAR2(50 byte), age NUMBER)") + let binds: [InsertUserStatement] = [ + InsertUserStatement(id: 1, name: "John", age: 20), + InsertUserStatement(id: 2, name: "Jane", age: 30), + InsertUserStatement(id: 3, name: "Jack", age: 40), + InsertUserStatement(id: 4, name: "Jill", age: 50), + InsertUserStatement(id: 5, name: "Pete", age: 60), + ] + let batchResult = try await connection.executeBatch(binds) + #expect(batchResult.affectedRows == binds.count) + let stream = try await connection.execute( + "SELECT id, name, age FROM users_prepared_statement_batch_exec ORDER BY id ASC") + var index: Int = 0 + for try await (id, name, age) in stream.decode((Int, String, Int).self) { + guard index < binds.count else { + Issue.record("Too many rows") + return + } + let expected = binds[index] + #expect(id == expected.id) + #expect(name == expected.name) + #expect(age == expected.age) + index += 1 + } + #expect(index == binds.count) + } + } + + @Test + func batchExecutionWithErrorDiscardsRemaining() async throws { + struct InsertUserStatement: OraclePreparedStatement { + static let sql: String = + "INSERT INTO users_error_discards_batch_exec (id, name, age) VALUES (:1, :2, :3)" + + typealias Row = Void + + var id: Int + var name: String + var age: Int + + func makeBindings() throws -> OracleBindings { + var bindings = OracleBindings(capacity: 3) + bindings.append(id) + bindings.append(name) + bindings.append(age) + return bindings + } + + func decodeRow(_ row: OracleRow) throws -> Row {} + } + + try await client.withConnection { connection in + do { + try await connection.execute("DROP TABLE users_error_discards_batch_exec", logger: .oracleTest) + } catch let error as OracleSQLError { + // "ORA-00942: table or view does not exist" can be ignored + #expect(error.serverInfo?.number == 942) + } + try await connection.execute( + "CREATE TABLE users_error_discards_batch_exec (id NUMBER, name VARCHAR2(50 byte), age NUMBER)") + let binds: [InsertUserStatement] = [ + InsertUserStatement(id: 1, name: "John", age: 20), + InsertUserStatement(id: 2, name: "Jane", age: 30), + InsertUserStatement( + id: 3, name: "Jack's name is too long to fit into this column, so we fail here", age: 40), + InsertUserStatement(id: 4, name: "Jill", age: 50), + InsertUserStatement(id: 5, name: "Pete", age: 60), + ] + do { + try await connection.executeBatch(binds) + } catch let error as OracleSQLError { + // expect a value too long for column error here + guard error.serverInfo?.number == 12899 else { throw error } + #expect(error.serverInfo?.affectedRows == 2) + } + let stream = try await connection.execute( + "SELECT id, name, age FROM users_error_discards_batch_exec ORDER BY id ASC") + var index: Int = 0 + for try await (id, name, age) in stream.decode((Int, String, Int).self) { + guard index < 2 else { + Issue.record("Too many rows") + return + } + let expected = binds[index] + #expect(id == expected.id) + #expect(name == expected.name) + #expect(age == expected.age) + index += 1 + } + #expect(index == 2) + } + } + + @Test + func batchExecutionWithBatchErrorsDoesNotDiscardSuccess() async throws { + struct InsertUserStatement: OraclePreparedStatement { + static let sql: String = + "INSERT INTO users_batch_error_does_not_discard_batch_exec (id, name, age) VALUES (:1, :2, :3)" + + typealias Row = Void + + var id: Int + var name: String + var age: Int + + func makeBindings() throws -> OracleBindings { + var bindings = OracleBindings(capacity: 3) + bindings.append(id) + bindings.append(name) + bindings.append(age) + return bindings + } + + func decodeRow(_ row: OracleRow) throws -> Row {} + } + + try await client.withConnection { connection in + do { + try await connection.execute( + "DROP TABLE users_batch_error_does_not_discard_batch_exec", logger: .oracleTest) + } catch let error as OracleSQLError { + // "ORA-00942: table or view does not exist" can be ignored + #expect(error.serverInfo?.number == 942) + } + try await connection.execute( + "CREATE TABLE users_batch_error_does_not_discard_batch_exec (id NUMBER, name VARCHAR2(50 byte), age NUMBER)" + ) + var binds: [InsertUserStatement] = [ + InsertUserStatement(id: 1, name: "John", age: 20), + InsertUserStatement(id: 2, name: "Jane", age: 30), + InsertUserStatement( + id: 3, name: "Jack's name is too long to fit into this column, so we fail here", age: 40), + InsertUserStatement(id: 4, name: "Jill", age: 50), + InsertUserStatement(id: 5, name: "Pete", age: 60), + ] + var options = StatementOptions() + options.batchErrors = true + options.arrayDMLRowCounts = true + do { + try await connection.executeBatch(binds, options: options) + } catch let error as OracleBatchExecutionError { + #expect(error.result.affectedRows == 4) + #expect(error.result.affectedRowsPerStatement == [1, 1, 0, 1, 1]) + #expect(error.errors.first?.statementIndex == 2) + #expect(error.errors.first?.number == 12899) + } + let stream = try await connection.execute( + "SELECT id, name, age FROM users_batch_error_does_not_discard_batch_exec ORDER BY id ASC") + var index: Int = 0 + binds.remove(at: 2) // malformed data + for try await (id, name, age) in stream.decode((Int, String, Int).self) { + guard index < binds.count else { + Issue.record("Too many rows") + return + } + let expected = binds[index] + #expect(id == expected.id) + #expect(name == expected.name) + #expect(age == expected.age) + index += 1 + } + #expect(index == binds.count) + } + } + } +#endif diff --git a/Tests/IntegrationTests/BugReportTests.swift b/Tests/IntegrationTests/BugReportTests.swift index 53e7fc0..a4f9ed8 100644 --- a/Tests/IntegrationTests/BugReportTests.swift +++ b/Tests/IntegrationTests/BugReportTests.swift @@ -266,5 +266,5 @@ private struct Timestamp: Sendable, OracleCodable { } } - var oracleType: OracleDataType { .timestamp } + static var defaultOracleType: OracleDataType { .timestamp } } diff --git a/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift b/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift index ab13c11..e2d9453 100644 --- a/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift +++ b/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift @@ -25,7 +25,7 @@ final class StatementStateMachineTests: XCTestCase { let query: OracleStatement = "DELETE FROM table" let queryContext = StatementContext(statement: query, promise: promise) - let result = StatementResult(value: .noRows) + let result = StatementResult(value: .noRows(affectedRows: 0)) let backendError = OracleBackendMessage.BackendError( number: 0, cursorID: 6, position: 0, rowCount: 0, isWarning: false, message: nil, rowID: nil, batchErrors: []) @@ -74,7 +74,7 @@ final class StatementStateMachineTests: XCTestCase { XCTAssertEqual(state.rowDataReceived(.init(1), capabilities: .init()), .wait) XCTAssertEqual(state.queryParameterReceived(.init()), .wait) XCTAssertEqual( - state.backendErrorReceived(.noData), .forwardStreamComplete([row1], cursorID: 1)) + state.backendErrorReceived(.noData), .forwardStreamComplete([row1], cursorID: 1, affectedRows: 1)) } func testCancellationCompletesQueryOnlyOnce() throws { diff --git a/Tests/OracleNIOTests/OracleStatementTests.swift b/Tests/OracleNIOTests/OracleStatementTests.swift index 82f9aa6..c59f35e 100644 --- a/Tests/OracleNIOTests/OracleStatementTests.swift +++ b/Tests/OracleNIOTests/OracleStatementTests.swift @@ -135,7 +135,7 @@ final class OracleStatementTests: XCTestCase { // Testing utility, because we do not have a throwing encodable, luckily :) struct ThrowingByteBuffer: OracleThrowingDynamicTypeEncodable { - let oracleType: OracleNIO.OracleDataType = .raw + static let defaultOracleType: OracleNIO.OracleDataType = .raw var size: UInt32 { UInt32(self.base.readableBytes) } diff --git a/Tests/OracleNIOTests/TestUtils/ConnectionAction+TestUtils.swift b/Tests/OracleNIOTests/TestUtils/ConnectionAction+TestUtils.swift index 66c7073..e7ebbb8 100644 --- a/Tests/OracleNIOTests/TestUtils/ConnectionAction+TestUtils.swift +++ b/Tests/OracleNIOTests/TestUtils/ConnectionAction+TestUtils.swift @@ -126,10 +126,10 @@ extension ConnectionStateMachine.ConnectionAction: Equatable { case (.forwardRows(let lhs), .forwardRows(let rhs)): return lhs == rhs case ( - .forwardStreamComplete(let lhsRows, let lhsCursorID), - .forwardStreamComplete(let rhsRows, let rhsCursorID) + .forwardStreamComplete(let lhsRows, let lhsCursorID, let lhsAffectedRows), + .forwardStreamComplete(let rhsRows, let rhsCursorID, let rhsAffectedRows) ): - return lhsRows == rhsRows && lhsCursorID == rhsCursorID + return lhsRows == rhsRows && lhsCursorID == rhsCursorID && lhsAffectedRows == rhsAffectedRows case ( .forwardStreamError(let lhsError, let lhsRead, let lhsCursorID, let lhsClientCancelled), .forwardStreamError(let rhsError, let rhsRead, let rhsCursorID, let rhsClientCancelled) diff --git a/Tests/OracleNIOTests/TestUtils/OracleRowStream+TestUtils.swift b/Tests/OracleNIOTests/TestUtils/OracleRowStream+TestUtils.swift index 64f129e..9df0c46 100644 --- a/Tests/OracleNIOTests/TestUtils/OracleRowStream+TestUtils.swift +++ b/Tests/OracleNIOTests/TestUtils/OracleRowStream+TestUtils.swift @@ -21,12 +21,18 @@ extension OracleRowStream { convenience init( source: Source, - eventLoop: any EventLoop = EmbeddedEventLoop() + eventLoop: any EventLoop = EmbeddedEventLoop(), + affectedRows: Int? = nil, + rowCounts: [Int]? = nil, + batchErrors: [OracleSQLError.BatchError]? = nil ) { self.init( source: source, eventLoop: eventLoop, - logger: OracleConnection.noopLogger + logger: OracleConnection.noopLogger, + affectedRows: affectedRows, + rowCounts: rowCounts, + batchErrors: batchErrors ) } diff --git a/Tests/OracleNIOTests/TestUtils/QueryResult+TestUtils.swift b/Tests/OracleNIOTests/TestUtils/QueryResult+TestUtils.swift index 36522ec..ab336c5 100644 --- a/Tests/OracleNIOTests/TestUtils/QueryResult+TestUtils.swift +++ b/Tests/OracleNIOTests/TestUtils/QueryResult+TestUtils.swift @@ -19,6 +19,6 @@ import NIOEmbedded extension StatementResult { init(value: Value) { - self.init(value: value, logger: OracleConnection.noopLogger) + self.init(value: value, logger: OracleConnection.noopLogger, batchErrors: nil, rowCounts: nil) } } diff --git a/docker-compose.yaml b/docker-compose.yaml index 20501dc..933c5c6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -10,7 +10,7 @@ services: ports: - 1521:1521 oracle-23: - image: gvenzl/oracle-free:23.4-faststart + image: gvenzl/oracle-free:23.5-faststart environment: ORACLE_PASSWORD: my_very_secure_password APP_USER: my_user