diff --git a/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift b/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift index 9638298..1beb0fc 100644 --- a/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift +++ b/Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift @@ -828,6 +828,7 @@ struct ConnectionStateMachine { guard case .lobOperation(let context) = self.state else { preconditionFailure("How can we receive LOB data in \(self.state)") } + context.fetchedAmount = parameter.amount context.boolFlag = parameter.boolFlag return .wait // waiting for error } @@ -1039,7 +1040,8 @@ extension ConnectionStateMachine { .unexpectedBackendMessage, .serverVersionNotSupported, .sidNotSupported, - .uncleanShutdown: + .uncleanShutdown, + .unsupportedDataType: return true case .statementCancelled, .nationalCharsetNotSupported: return false diff --git a/Sources/OracleNIO/Constants.swift b/Sources/OracleNIO/Constants.swift index 6e79f6b..e68c89f 100644 --- a/Sources/OracleNIO/Constants.swift +++ b/Sources/OracleNIO/Constants.swift @@ -398,7 +398,7 @@ enum Constants { static let TNS_MAX_ROWID_LENGTH = 18 static let TNS_DURATION_MID: UInt32 = 0x8000_0000 static let TNS_DURATION_OFFSET: UInt8 = 60 - static let TNS_DURATION_SESSION: Int64 = 10 + static let TNS_DURATION_SESSION: UInt64 = 10 @usableFromInline static let TNS_MIN_LONG_LENGTH = 0x8000 static let TNS_MAX_LONG_LENGTH: UInt32 = 0x7fff_ffff diff --git a/Sources/OracleNIO/Data/LOB.swift b/Sources/OracleNIO/Data/LOB.swift index 0b02436..2340d5d 100644 --- a/Sources/OracleNIO/Data/LOB.swift +++ b/Sources/OracleNIO/Data/LOB.swift @@ -69,7 +69,7 @@ import NIOCore /// options: .init(fetchLOBs: true) /// ) /// let lob = try lobRef.decode(of: LOB.self) -/// var offset: UInt64 = 1 +/// var offset = 1 /// let chunkSize = 65536 /// while /// buffer.readableBytes > 0, @@ -81,71 +81,31 @@ import NIOCore /// } /// ``` public final class LOB: Sendable { - /// The total size of the data in the LOB. - /// - /// Bytes for BLOBs and USC-2 code points for CLOBs. - /// USC-2 code points are equivalent to characters for all but supplemental characters. - public let size: UInt64 - /// Reading and writing to the LOB in chunks of multiples of this size will improve performance. - public let chunkSize: UInt32 + private let _size: UInt64 + private let _chunkSize: UInt32 let locator: NIOLockedValueBox<[UInt8]> private let hasMetadata: Bool - public let dbType: OracleDataType + public let oracleType: OracleDataType init( size: UInt64, chunkSize: UInt32, locator: [UInt8], hasMetadata: Bool, - dbType: OracleDataType + oracleType: OracleDataType ) { - self.size = size - self.chunkSize = chunkSize + self._size = size + self._chunkSize = chunkSize self.locator = .init(locator) self.hasMetadata = hasMetadata - self.dbType = dbType - } - - static func create(dbType: OracleDataType, locator: [UInt8]?) -> Self { - if let locator { - return self.init( - size: 0, - chunkSize: 0, - locator: locator, - hasMetadata: false, - dbType: dbType - ) - } else { - let locator = [UInt8](repeating: 0, count: 40) - let lob = self.init( - size: 0, - chunkSize: 0, - locator: locator, - hasMetadata: false, - dbType: dbType - ) - // TODO: create temp lob on db - return lob - } - } - - func encoding() -> String { - let locator = self.locator.withLockedValue { $0 } - if dbType.csfrm == Constants.TNS_CS_NCHAR - || (locator.count >= Constants.TNS_LOB_LOCATOR_OFFSET_FLAG_3 - && (locator[Constants.TNS_LOB_LOCATOR_OFFSET_FLAG_3] - & Constants.TNS_LOB_LOCATOR_VAR_LENGTH_CHARSET) != 0) - { - return Constants.TNS_ENCODING_UTF16 - } - return Constants.TNS_ENCODING_UTF8 + self.oracleType = oracleType } func _read( - offset: UInt64 = 1, - amount: UInt64? = nil, + offset: UInt64, + amount: UInt64, on connection: OracleConnection ) async throws -> ByteBuffer? { let promise = connection.eventLoop.makePromise(of: ByteBuffer?.self) @@ -158,7 +118,7 @@ public final class LOB: Sendable { destinationOffset: 0, operation: .read, sendAmount: true, - amount: amount ?? .init(self.chunkSize), + amount: amount, promise: promise )), promise: nil) return try await promise.futureResult.get() @@ -195,13 +155,13 @@ extension LOB { /// - Returns: An async sequence used to iterate over /// the chunks of data read from the connection. public func readChunks( - ofSize chunkSize: UInt64? = nil, + ofSize chunkSize: Int? = nil, on connection: OracleConnection ) -> ReadSequence { ReadSequence( self, connection: connection, - chunkSize: chunkSize ?? .init(self.chunkSize) + chunkSize: UInt64(chunkSize ?? .init(self._chunkSize)) ) } @@ -236,7 +196,6 @@ extension LOB { var chunkSize: UInt64 public mutating func next() async throws -> ByteBuffer? { - if self.offset >= self.base.size { return nil } guard let chunk = try await self.base._read( offset: self.offset, @@ -325,7 +284,7 @@ extension LOB { /// This has to be the same one the LOB reference was created on. public func write( _ buffer: ByteBuffer, - at offset: UInt64 = 1, + at offset: Int = 1, on connection: OracleConnection ) async throws { let promise = connection.eventLoop.makePromise(of: ByteBuffer?.self) @@ -333,7 +292,7 @@ extension LOB { OracleTask.lobOperation( .init( sourceLOB: self, - sourceOffset: offset, + sourceOffset: UInt64(offset), destinationLOB: nil, destinationOffset: 0, operation: .write, @@ -351,7 +310,7 @@ extension LOB { /// - connection: The connection used to trim the LOB. /// This has to be the same one the LOB reference was created on. public func trim( - to newSize: UInt64, + to newSize: Int, on connection: OracleConnection ) async throws { let promise = connection.eventLoop.makePromise(of: ByteBuffer?.self) @@ -364,31 +323,135 @@ extension LOB { destinationOffset: 0, operation: .trim, sendAmount: true, - amount: newSize, + amount: UInt64(newSize), promise: promise )), promise: nil) _ = try await promise.futureResult.get() } + + /// Create a temporary LOB on the given connection. + /// + /// The temporary LOB lives until the connection is closed or explicitly freed by calling + /// ``free(on:)``. + /// + /// It can be inserted in a table at a later point as long as the connection lives. + public static func create( + _ oracleType: OracleDataType, + on connection: OracleConnection + ) async throws -> LOB { + switch oracleType { + case .blob, .clob, .nCLOB: + let locator = [UInt8](repeating: 0, count: 40) + let lob = self.init( + size: 0, + chunkSize: 0, + locator: locator, + hasMetadata: false, + oracleType: oracleType + ) + let promise = connection.eventLoop.makePromise(of: ByteBuffer?.self) + connection.channel.write( + OracleTask.lobOperation( + .init( + sourceLOB: lob, + sourceOffset: UInt64(oracleType.csfrm), + destinationLOB: nil, + destinationOffset: UInt64(oracleType._oracleType?.rawValue ?? 0), + operation: .createTemp, + sendAmount: true, + amount: Constants.TNS_DURATION_SESSION, + promise: promise + )), promise: nil) + _ = try await promise.futureResult.get() + return lob + default: + throw OracleSQLError.unsupportedDataType + } + } + + /// Frees/removes a temporary LOB from the given connection + /// with the next round trip to the database. + public func free(on connection: OracleConnection) async throws { + let handler = try await connection.channel.pipeline + .handler(type: OracleChannelHandler.self).get() + self.free(from: handler.cleanupContext) + } + + /// Retrieve the total size of the data in the LOB. + /// + /// Bytes for BLOBs and USC-2 code points for CLOBs. + /// USC-2 code points are equivalent to characters for all but supplemental characters. + public func size(on connection: OracleConnection) async throws -> Int { + let promise = connection.eventLoop.makePromise(of: ByteBuffer?.self) + let context = LOBOperationContext( + sourceLOB: self, + sourceOffset: 0, + destinationLOB: nil, + destinationOffset: 0, + operation: .getLength, + sendAmount: true, + amount: 0, + promise: promise + ) + connection.channel.write(OracleTask.lobOperation(context), promise: nil) + _ = try await promise.futureResult.get() + return Int(context.fetchedAmount ?? 0) + } + + /// The total size of the LOB data when it was first received from the database. + /// + /// It might have changed already. To get the up-to-date size use ``size(on:)``. + public var estimatedSize: Int { Int(self._size) } + + + /// Reading and writing to the LOB in chunks of multiples of this size will improve performance. + public func chunkSize(on connection: OracleConnection) async throws -> Int { + let promise = connection.eventLoop.makePromise(of: ByteBuffer?.self) + let context = LOBOperationContext( + sourceLOB: self, + sourceOffset: 0, + destinationLOB: nil, + destinationOffset: 0, + operation: .getChunkSize, + sendAmount: true, + amount: 0, + promise: promise + ) + connection.channel.write(OracleTask.lobOperation(context), promise: nil) + _ = try await promise.futureResult.get() + return Int(context.fetchedAmount ?? Int64(self._chunkSize)) + } + + /// Reading and writing to the LOB in chunks of multiples of this size will improve performance. + /// + /// This is the ideal chunk size at the time of fetching the LOB initially, + /// it falls back to a sensible default if the underlying value is `0`. + /// It might have changed in the meantime, to get the current chunk size use ``chunkSize(on:)``. + public var estimatedChunkSize: Int { + if self._chunkSize == 0 { + 8060 + } else { + Int(self._chunkSize) + } + } } extension LOB: OracleEncodable { - public var oracleType: OracleDataType { .blob } - public func encode( into buffer: inout ByteBuffer, context: OracleEncodingContext ) { - preconditionFailure("This should not be called") + let locator = self.locator.withLockedValue { $0 } + let length = locator.count + buffer.writeUB4(UInt32(length)) + ByteBuffer(bytes: locator)._encodeRaw(into: &buffer, context: context) } public func _encodeRaw( into buffer: inout ByteBuffer, context: OracleEncodingContext ) { - let locator = self.locator.withLockedValue { $0 } - let length = locator.count - buffer.writeUB4(UInt32(length)) - ByteBuffer(bytes: locator)._encodeRaw(into: &buffer, context: context) + self.encode(into: &buffer, context: context) } } @@ -408,7 +471,7 @@ extension LOB: OracleDecodable { chunkSize: chunkSize, locator: Array(buffer: locator), hasMetadata: true, - dbType: type + oracleType: type ) default: throw OracleDecodingError.Code.typeMismatch diff --git a/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift b/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift index 8472703..829c908 100644 --- a/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift +++ b/Sources/OracleNIO/Messages/Coding/OracleFrontendMessageEncoder.swift @@ -718,7 +718,9 @@ struct OracleFrontendMessageEncoder { self.buffer.writeBytes(destinationLOB.locator.withLockedValue({ $0 })) } if context.operation == .createTemp { - if let sourceLOB = context.sourceLOB, sourceLOB.dbType.csfrm == Constants.TNS_CS_NCHAR { + if let sourceLOB = context.sourceLOB, + sourceLOB.oracleType.csfrm == Constants.TNS_CS_NCHAR + { try self.capabilities.checkNCharsetID() self.buffer.writeUB4(UInt32(Constants.TNS_CHARSET_UTF16)) } else { diff --git a/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift b/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift index f8f21a0..2e222f8 100644 --- a/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift +++ b/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift @@ -149,19 +149,19 @@ extension OracleBackendMessage { $0 = newLocator } } + let amount: Int64? if context.lobContext?.operation == .createTemp { buffer.skipUB2() // skip character set - } - let amount: Int64? - if context.lobContext?.sendAmount == true { + // skip trailing flags, amount + buffer.moveReaderIndex(forwardBy: 3) + amount = nil + } else if context.lobContext?.sendAmount == true { amount = try buffer.throwingReadSB8() } else { amount = nil } let boolFlag: Bool? - if context.lobContext?.operation == .createTemp - || context.lobContext?.operation == .isOpen - { + if context.lobContext?.operation == .isOpen { // flag boolFlag = try buffer.throwingReadInteger(as: UInt8.self) > 0 } else { diff --git a/Sources/OracleNIO/OracleChannelHandler.swift b/Sources/OracleNIO/OracleChannelHandler.swift index a819766..f571263 100644 --- a/Sources/OracleNIO/OracleChannelHandler.swift +++ b/Sources/OracleNIO/OracleChannelHandler.swift @@ -514,10 +514,13 @@ final class OracleChannelHandler: ChannelDuplexHandler { switch lobContext.operation { case .read: lobContext.promise.succeed(lobContext.data) - case .open, .isOpen, .close, .write, .trim: + case .open, .isOpen, .close, .write, .trim, .createTemp, .getLength, + .getChunkSize: lobContext.promise.succeed(nil) - case .getLength, .getChunkSize, .createTemp, .freeTemp, .array: - fatalError("not yet supported") + case .freeTemp, .array: + preconditionFailure( + "Invalid lob operation: \(lobContext.operation)" + ) } self.run(self.state.readyForStatementReceived(), with: context) case .failLOBOperation(let promise, let error): diff --git a/Sources/OracleNIO/OracleSQLError.swift b/Sources/OracleNIO/OracleSQLError.swift index 2c9b0dc..5f3d580 100644 --- a/Sources/OracleNIO/OracleSQLError.swift +++ b/Sources/OracleNIO/OracleSQLError.swift @@ -35,6 +35,7 @@ public struct OracleSQLError: Sendable, Error { case serverVersionNotSupported case sidNotSupported case missingParameter + case unsupportedDataType } internal var base: Base @@ -61,6 +62,7 @@ public struct OracleSQLError: Sendable, Error { Self(.serverVersionNotSupported) public static let sidNotSupported = Self(.sidNotSupported) public static let missingParameter = Self(.missingParameter) + public static let unsupportedDataType = Self(.unsupportedDataType) public var description: String { switch self.base { @@ -92,6 +94,8 @@ public struct OracleSQLError: Sendable, Error { return "sidNotSupported" case .missingParameter: return "missingParameter" + case .unsupportedDataType: + return "unsupportedDataType" } } } @@ -304,6 +308,8 @@ public struct OracleSQLError: Sendable, Error { return error } + static let unsupportedDataType = OracleSQLError(code: .unsupportedDataType) + } extension OracleSQLError: CustomStringConvertible { diff --git a/Sources/OracleNIO/OracleTask.swift b/Sources/OracleNIO/OracleTask.swift index 1a558bd..3bf8983 100644 --- a/Sources/OracleNIO/OracleTask.swift +++ b/Sources/OracleNIO/OracleTask.swift @@ -53,6 +53,7 @@ final class LOBOperationContext { let amount: UInt64 let promise: EventLoopPromise + var fetchedAmount: Int64? var boolFlag: Bool? var data: ByteBuffer? diff --git a/Tests/IntegrationTests/LOBTests.swift b/Tests/IntegrationTests/LOBTests.swift index 2e9b177..092b309 100644 --- a/Tests/IntegrationTests/LOBTests.swift +++ b/Tests/IntegrationTests/LOBTests.swift @@ -85,7 +85,7 @@ final class LOBTests: XCTIntegrationTest { "INSERT INTO test_simple_blob (id, content) VALUES (1, \(buffer))", logger: .oracleTest ) - func fetchLOB(chunkSize: UInt64?) async throws { + func fetchLOB(chunkSize: Int?) async throws { var queryOptions = StatementOptions() queryOptions.fetchLOBs = true let rows = try await connection.execute( @@ -95,6 +95,8 @@ final class LOBTests: XCTIntegrationTest { ) var index = 0 for try await (id, lob) in rows.decode((Int, LOB).self) { + XCTAssertEqual(lob.estimatedSize, buffer.readableBytes) + XCTAssertGreaterThan(lob.estimatedChunkSize, 0) index += 1 XCTAssertEqual(index, id) var out = ByteBuffer() @@ -121,7 +123,7 @@ final class LOBTests: XCTIntegrationTest { options: .init(fetchLOBs: true) ) let lob = try lobRef.decode(of: LOB.self) - var offset: UInt64 = 1 + var offset = 1 let chunkSize = 65536 while buffer.readableBytes > 0, let slice = @@ -129,8 +131,12 @@ final class LOBTests: XCTIntegrationTest { .readSlice(length: min(chunkSize, buffer.readableBytes)) { try await lob.write(slice, at: offset, on: connection) - offset += UInt64(slice.readableBytes) + let newSize = try await lob.size(on: connection) + offset += slice.readableBytes + XCTAssertEqual(newSize, offset - 1) } + // fast size does not update + XCTAssertEqual(lob.estimatedSize, 0) buffer.moveReaderIndex(to: 0) try await validateLOB(expected: buffer, on: connection) } @@ -144,7 +150,7 @@ final class LOBTests: XCTIntegrationTest { options: .init(fetchLOBs: true) ) let lob = try lobRef.decode(of: LOB.self) - var offset: UInt64 = 1 + var offset = 1 let chunkSize = 65536 try await lob.open(on: connection) while buffer.readableBytes > 0, @@ -153,7 +159,7 @@ final class LOBTests: XCTIntegrationTest { .readSlice(length: min(chunkSize, buffer.readableBytes)) { try await lob.write(slice, at: offset, on: connection) - offset += UInt64(slice.readableBytes) + offset += slice.readableBytes } let isOpen = try await lob.isOpen(on: connection) XCTAssertTrue(isOpen) @@ -164,6 +170,33 @@ final class LOBTests: XCTIntegrationTest { try await validateLOB(expected: buffer, on: connection) } + func testTemporaryLOB() async throws { + let lob = try await LOB.create(.blob, on: connection) + XCTAssertEqual(lob.estimatedChunkSize, 8060) // the default + let chunkSize = try await lob.chunkSize(on: connection) + let buffer = ByteBuffer(bytes: [0x1, 0x2, 0x3, 0x4]) + try await lob.write(buffer, on: connection) + try await connection.execute( + "INSERT INTO test_simple_blob (id, content) VALUES (1, \(lob))" + ) + XCTAssertGreaterThan(chunkSize, 0) + try await lob.free(on: connection) + let optionalBuffer = try await connection.execute( + "SELECT content FROM test_simple_blob WHERE id = 1" + ).collect().first?.decode(ByteBuffer.self) + XCTAssertEqual(buffer, optionalBuffer) + } + + func testCreateLOBFromUnsupportedDataType() async throws { + var thrown: OracleSQLError? + do { + _ = try await LOB.create(.varchar, on: connection) + } catch let error as OracleSQLError { + thrown = error + } + XCTAssertEqual(thrown?.code, OracleSQLError.Code.unsupportedDataType) + } + func testTrimLOB() async throws { let data = try Data(contentsOf: fileURL) let buffer = ByteBuffer(data: data) @@ -181,7 +214,7 @@ final class LOBTests: XCTIntegrationTest { let lob = try XCTUnwrap(optionalLOB) // shrink to half size - try await lob.trim(to: UInt64(buffer.readableBytes / 2), on: connection) + try await lob.trim(to: buffer.readableBytes / 2, on: connection) let trimmed = buffer.getSlice(at: 0, length: buffer.readableBytes / 2)! try await validateLOB(expected: trimmed, on: connection)