From a3213b5a11987f33e3d1a71d6c148b2c5f57d946 Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Fri, 5 Jul 2024 12:28:59 +0200 Subject: [PATCH] feat: draft decoding --- .../StatementStateMachine.swift | 172 +++++----- .../Coding/OracleBackendMessageDecoder.swift | 25 +- Sources/OracleNIO/Messages/DescribeInfo.swift | 1 - .../OracleBackendMessage+BitVector.swift | 2 +- .../OracleBackendMessage+Parameter.swift | 2 +- .../OracleBackendMessage+RowData.swift | 301 +++++++++++++++++- .../Messages/OracleBackendMessage.swift | 25 +- Sources/OracleNIO/OracleChannelHandler.swift | 14 +- Tests/IntegrationTests/OracleNIOTests.swift | 1 + .../StatementStateMachineTests.swift | 6 +- 10 files changed, 424 insertions(+), 125 deletions(-) diff --git a/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift b/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift index 4bd5e96..0531b82 100644 --- a/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift +++ b/Sources/OracleNIO/ConnectionStateMachine/StatementStateMachine.swift @@ -200,6 +200,7 @@ struct StatementStateMachine { // This state might occur, if the client cancelled the statement, // but the server did not yet receive/process the cancellation // marker. Due to that it might send more data without knowing yet. + // TODO: check if we need to forward row header here too return .wait case .initialized, .streamingAndWaiting, .error, .commandComplete: @@ -216,82 +217,104 @@ struct StatementStateMachine { ) -> Action { switch self.state { case .initialized(let context): - let outBinds = context.statement.binds.metadata.compactMap(\.outContainer) - guard !outBinds.isEmpty else { preconditionFailure() } - var buffer = rowData.slice - if context.isReturning { - for outBind in outBinds { - outBind.storage.withLockedValue { $0 = nil } - let rowCount = buffer.readUB4() ?? 0 - guard rowCount > 0 else { - continue - } - - do { - for _ in 0.. 0 else { +// continue +// } +// +// do { +// for _ in 0.. = [] try OracleBackendMessage.decodeData( from: &slice, @@ -830,7 +853,8 @@ struct StatementStateMachine { case .rowHeader(let rowHeader): action = self.rowHeaderReceived(rowHeader) case .rowData(let rowData): - buffer = rowData.slice + fatalError("TODO") + buffer = ByteBuffer() action = self.rowDataReceived0( buffer: &buffer, capabilities: capabilities ) diff --git a/Sources/OracleNIO/Messages/Coding/OracleBackendMessageDecoder.swift b/Sources/OracleNIO/Messages/Coding/OracleBackendMessageDecoder.swift index acc6c28..acd38dc 100644 --- a/Sources/OracleNIO/Messages/Coding/OracleBackendMessageDecoder.swift +++ b/Sources/OracleNIO/Messages/Coding/OracleBackendMessageDecoder.swift @@ -35,21 +35,22 @@ struct OracleBackendMessageDecoder: ByteToMessageDecoder { final class Context { var capabilities: Capabilities - var performingChunkedRead = false - var statementOptions: StatementOptions? = nil - var columnsCount: Int? = nil + var performingChunkedRead = false // TODO: remove + + var statementContext: StatementContext? + var bitVector: [UInt8]? + var describeInfo: DescribeInfo? + var lobContext: LOBOperationContext? - init( - capabilities: Capabilities, - performingChunkedRead: Bool = false, - statementOptions: StatementOptions? = nil, - columnsCount: Int? = nil - ) { + init(capabilities: Capabilities) { self.capabilities = capabilities - self.performingChunkedRead = performingChunkedRead - self.statementOptions = statementOptions - self.columnsCount = columnsCount + } + + func clearStatementContext() { + self.statementContext = nil + self.bitVector = nil + self.describeInfo = nil } } diff --git a/Sources/OracleNIO/Messages/DescribeInfo.swift b/Sources/OracleNIO/Messages/DescribeInfo.swift index b0ea5db..c8c1303 100644 --- a/Sources/OracleNIO/Messages/DescribeInfo.swift +++ b/Sources/OracleNIO/Messages/DescribeInfo.swift @@ -34,7 +34,6 @@ struct DescribeInfo: OracleBackendMessage.PayloadDecodable, Sendable, Hashable { ) throws -> DescribeInfo { buffer.skipUB4() // max row size let columnCount = try buffer.throwingReadUB4() - context.columnsCount = Int(columnCount) if columnCount > 0 { buffer.moveReaderIndex(forwardBy: 1) diff --git a/Sources/OracleNIO/Messages/OracleBackendMessage+BitVector.swift b/Sources/OracleNIO/Messages/OracleBackendMessage+BitVector.swift index f0fb496..981f2bf 100644 --- a/Sources/OracleNIO/Messages/OracleBackendMessage+BitVector.swift +++ b/Sources/OracleNIO/Messages/OracleBackendMessage+BitVector.swift @@ -24,7 +24,7 @@ extension OracleBackendMessage { context: OracleBackendMessageDecoder.Context ) throws -> OracleBackendMessage.BitVector { let columnsCountSent = try buffer.throwingReadUB2() - guard let columnsCount = context.columnsCount else { + guard let columnsCount = context.describeInfo?.columns.count else { preconditionFailure( "How can we receive a bit vector without an active statement?" ) diff --git a/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift b/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift index 2e222f8..c2b3104 100644 --- a/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift +++ b/Sources/OracleNIO/Messages/OracleBackendMessage+Parameter.swift @@ -113,7 +113,7 @@ extension OracleBackendMessage { { buffer.moveReaderIndex(forwardBy: bytesCount) } - if context.statementOptions!.arrayDMLRowCounts == true { + if context.statementContext!.options.arrayDMLRowCounts == true { let numberOfRows = buffer.readUB4() ?? 0 rowCounts = [] for _ in 0.. RowData { - let data = RowData(slice: buffer.slice()) - buffer.moveReaderIndex(to: buffer.readerIndex + buffer.readableBytes) - return data + guard let statementContext = context.statementContext else { + preconditionFailure( + "RowData cannot be received without having a \(String(reflecting: StatementContext.self))" + ) + } + + let columns: [ColumnStorage] + if let describeInfo = context.describeInfo { + columns = try self.processRowData( + from: &buffer, + describeInfo: describeInfo, + context: context + ) + } else { + columns = try self.processBindRow( + from: &buffer, + statementContext: statementContext, + capabilities: context.capabilities + ) + } + + return .init(columns: columns) } - } -} -extension OracleBackendMessage.RowData: CustomDebugStringConvertible { - var debugDescription: String { - "RowData(slice: \(self.slice.readableBytes) bytes)" + private static func isDuplicateData( + columnNumber: UInt32, bitVector: [UInt8]? + ) -> Bool { + guard let bitVector else { return false } + let byteNumber = columnNumber / 8 + let bitNumber = columnNumber % 8 + return bitVector[Int(byteNumber)] & (1 << bitNumber) == 0 + } + + private static func processRowData( + from buffer: inout ByteBuffer, + describeInfo: DescribeInfo, + context: OracleBackendMessageDecoder.Context + ) throws -> [ColumnStorage] { + var columns = [ColumnStorage]() + columns.reserveCapacity(describeInfo.columns.count) + for (index, column) in describeInfo.columns.enumerated() { + if self.isDuplicateData( + columnNumber: UInt32(index), + bitVector: context.bitVector + ) { + columns.append(.duplicate(index)) + } else if let data = try self.processColumnData( + from: &buffer, + oracleType: column.dataType._oracleType, + csfrm: column.dataType.csfrm, + bufferSize: column.bufferSize, + capabilities: context.capabilities + ) { + if index == 0 { + var data = data.getSlice(at: 1, length: data.readableBytes - 1)! + print(try Int(from: &data, type: column.dataType, context: .default)) + } else { + var data = data.getSlice(at: 1, length: data.readableBytes - 1)! + print(try String(from: &data, type: column.dataType, context: .default)) + } + columns.append(.data(data)) + } else { + throw MissingDataDecodingError.Trigger() + } + } + + // reset bit vector after usage + context.bitVector = nil + return columns + } + + private static func processColumnData( + from buffer: inout ByteBuffer, + oracleType: _TNSDataType?, + csfrm: UInt8, + bufferSize: UInt32, + capabilities: Capabilities + ) throws -> ByteBuffer? { + var columnValue: ByteBuffer + if bufferSize == 0 && ![.long, .longRAW, .uRowID].contains(oracleType) { + columnValue = ByteBuffer(bytes: [0]) // NULL indicator + return columnValue + } + + if [.varchar, .char, .long].contains(oracleType) { + if csfrm == Constants.TNS_CS_NCHAR { + try capabilities.checkNCharsetID() + } + // if we need capabilities during decoding in the future, we should + // move this to decoding too + } + + switch oracleType { + case .varchar, .char, .long, .raw, .longRAW, .number, .date, .timestamp, + .timestampLTZ, .timestampTZ, .binaryDouble, .binaryFloat, + .binaryInteger, .boolean, .intervalDS: + switch buffer.readOracleSlice() { + case .some(let slice): + columnValue = slice + case .none: + return nil // need more data + } + case .rowID: + // length is not the actual length of row ids + let length = try buffer.throwingReadInteger(as: UInt8.self) + if length == 0 || length == Constants.TNS_NULL_LENGTH_INDICATOR { + columnValue = ByteBuffer(bytes: [0]) // NULL indicator + } else { + columnValue = ByteBuffer() + try columnValue.writeLengthPrefixed(as: UInt8.self) { + let start = buffer.readerIndex + _ = try RowID(from: &buffer, type: .rowID, context: .default) + let end = buffer.readerIndex + buffer.moveReaderIndex(to: start) + return $0.writeImmutableBuffer(buffer.readSlice(length: end - start)!) + } + } + case .cursor: + buffer.moveReaderIndex(forwardBy: 1) // length (fixed value) + + let readerIndex = buffer.readerIndex + _ = try DescribeInfo._decode( + from: &buffer, context: .init(capabilities: capabilities) + ) + buffer.skipUB2() // cursor id + let length = buffer.readerIndex - readerIndex + buffer.moveReaderIndex(to: readerIndex) + columnValue = ByteBuffer(integer: Constants.TNS_LONG_LENGTH_INDICATOR) + try columnValue.writeLengthPrefixed(as: UInt32.self) { base in + let start = base.writerIndex + try capabilities.encode(into: &base) + base.writeImmutableBuffer(buffer.readSlice(length: length)!) + return base.writerIndex - start + } + columnValue.writeInteger(0, as: UInt32.self) // chunk length of zero + case .clob, .blob: + + // LOB has a UB4 length indicator instead of the usual UInt8 + let length = try buffer.throwingReadUB4() + if length > 0 { + let size = try buffer.throwingReadUB8() + let chunkSize = try buffer.throwingReadUB4() + var locator: ByteBuffer + switch buffer.readOracleSlice() { + case .some(let slice): + locator = slice + case .none: + return nil // need more data + } + columnValue = ByteBuffer() + try columnValue.writeLengthPrefixed(as: UInt8.self) { + $0.writeInteger(size) + $0.writeInteger(chunkSize) + $0.writeBuffer(&locator) + } + } else { + columnValue = .init(bytes: [0]) // empty buffer + } + case .json: + // TODO: OSON + // OSON has a UB4 length indicator instead of the usual UInt8 + fatalError("OSON is not yet implemented, will be added in the future") + case .vector: + let length = try buffer.throwingReadUB4() + if length > 0 { + buffer.skipUB8() // size (unused) + buffer.skipUB4() // chunk size (unused) + switch buffer.readOracleSlice() { + case .some(let slice): + columnValue = slice + case .none: + return nil // need more data + } + if !buffer.skipRawBytesChunked() { // LOB locator (unused) + return nil // need more data + } + } else { + columnValue = .init(bytes: [0]) // empty buffer + } + case .intNamed: + let startIndex = buffer.readerIndex + if try buffer.throwingReadUB4() > 0 { + if !buffer.skipRawBytesChunked() { // type oid + return nil // need more data + } + } + if try buffer.throwingReadUB4() > 0 { + if !buffer.skipRawBytesChunked() { // oid + return nil // need more data + } + } + if try buffer.throwingReadUB4() > 0 { + if !buffer.skipRawBytesChunked() { // snapshot + return nil // need more data + } + } + buffer.skipUB2() // version + let dataLength = try buffer.throwingReadUB4() + buffer.skipUB2() // flags + if dataLength > 0 { + if !buffer.skipRawBytesChunked() { // data + return nil // need more data + } + } + let endIndex = buffer.readerIndex + buffer.moveReaderIndex(to: startIndex) + columnValue = ByteBuffer(integer: Constants.TNS_LONG_LENGTH_INDICATOR) + let length = (endIndex - startIndex) + (MemoryLayout.size * 2) + columnValue.reserveCapacity(minimumWritableBytes: length) + try columnValue.writeLengthPrefixed(as: UInt32.self) { + $0.writeImmutableBuffer(buffer.readSlice(length: endIndex - startIndex)!) + } + columnValue.writeInteger(0, as: UInt32.self) // chunk length of zero + default: + fatalError( + "\(String(reflecting: oracleType)) is not implemented, please file a bug report") + } + + if [.long, .longRAW].contains(oracleType) { + buffer.skipSB4() // null indicator + buffer.skipUB4() // return code + } + + return columnValue + } + + private static func processBindRow( + from buffer: inout ByteBuffer, + statementContext: StatementContext, + capabilities: Capabilities + ) throws -> [ColumnStorage] { + let outBinds = statementContext.statement.binds.metadata.compactMap(\.outContainer) + guard !outBinds.isEmpty else { preconditionFailure() } + var columns: [ColumnStorage] = [] + if statementContext.isReturning { + for outBind in outBinds { + let rowCount = buffer.readUB4() ?? 0 + guard rowCount > 0 else { + continue + } + + for _ in 0.. ByteBuffer { + guard let columnData = try self.processColumnData( + from: &buffer, + oracleType: metadata.dataType._oracleType, + csfrm: metadata.dataType.csfrm, + bufferSize: metadata.bufferSize, + capabilities: capabilities + ) else { + throw MissingDataDecodingError.Trigger() + } + + let actualBytesCount = buffer.readSB4() ?? 0 + if actualBytesCount < 0 && metadata.dataType._oracleType == .boolean { + return ByteBuffer(bytes: [0]) // empty buffer + } else if actualBytesCount != 0 && !columnData.oracleColumnIsEmpty { + // TODO: throw this as error? + preconditionFailure("column truncated, length: \(actualBytesCount)") + } + + return columnData + } } } diff --git a/Sources/OracleNIO/Messages/OracleBackendMessage.swift b/Sources/OracleNIO/Messages/OracleBackendMessage.swift index c85cc85..466fdd1 100644 --- a/Sources/OracleNIO/Messages/OracleBackendMessage.swift +++ b/Sources/OracleNIO/Messages/OracleBackendMessage.swift @@ -158,7 +158,7 @@ extension OracleBackendMessage { break readLoop } case .parameter: - switch context.statementOptions { + switch context.statementContext { case .some: messages.append( try .queryParameter(.decode(from: &buffer, context: context)) @@ -187,13 +187,15 @@ extension OracleBackendMessage { try .ioVector(.decode(from: &buffer, context: context)) ) case .describeInfo: - messages.append( - try .describeInfo(.decode(from: &buffer, context: context)) - ) + let describeInfo = try DescribeInfo.decode(from: &buffer, context: context) + context.describeInfo = describeInfo + messages.append(.describeInfo(describeInfo)) + case .rowHeader: - messages.append( - try .rowHeader(.decode(from: &buffer, context: context)) - ) + let rowHeader = try RowHeader.decode(from: &buffer, context: context) + context.bitVector = rowHeader.bitVector + messages.append(.rowHeader(rowHeader)) + case .rowData: messages.append( try .rowData(.decode(from: &buffer, context: context)) @@ -202,11 +204,12 @@ extension OracleBackendMessage { // OracleChannelHandler, we are performing a chunked // read on all upcoming data packets, because we are // "blind" and don't know what we might get until then. - context.performingChunkedRead = true + context.performingChunkedRead = true // TODO: remove this case .bitVector: - messages.append( - try .bitVector(.decode(from: &buffer, context: context)) - ) + let bitVector = try BitVector.decode(from: &buffer, context: context) + context.bitVector = bitVector.bitVector + messages.append(.bitVector(bitVector)) + case .warning: messages.append( try .warning(.decodeWarning(from: &buffer, context: context)) diff --git a/Sources/OracleNIO/OracleChannelHandler.swift b/Sources/OracleNIO/OracleChannelHandler.swift index f571263..bdffe91 100644 --- a/Sources/OracleNIO/OracleChannelHandler.swift +++ b/Sources/OracleNIO/OracleChannelHandler.swift @@ -402,8 +402,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { if let cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } - self.decoderContext.statementOptions = nil - self.decoderContext.columnsCount = nil + self.decoderContext.clearStatementContext() self.run(self.state.readyForStatementReceived(), with: context) case .needMoreData: @@ -423,8 +422,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { } rowStream.receive(completion: .success(())) - self.decoderContext.statementOptions = nil - self.decoderContext.columnsCount = nil + self.decoderContext.clearStatementContext() if cursorID != 0 { self.cleanupContext.cursorsToClose.insert(cursorID) @@ -443,8 +441,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { context.read() } - self.decoderContext.statementOptions = nil - self.decoderContext.columnsCount = nil + self.decoderContext.clearStatementContext() if clientCancelled { self.run(self.state.statementStreamCancelled(), with: context) @@ -670,7 +667,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { describeInfo: describeInfo ) - self.decoderContext.statementOptions = statementContext.options + self.decoderContext.statementContext = statementContext context.writeAndFlush( self.wrapOutboundOut(self.encoder.flush()), promise: nil @@ -728,8 +725,7 @@ final class OracleChannelHandler: ChannelDuplexHandler { logger: result.logger ) promise.succeed(rows) - self.decoderContext.statementOptions = nil - self.decoderContext.columnsCount = nil + self.decoderContext.clearStatementContext() self.run(self.state.readyForStatementReceived(), with: context) } diff --git a/Tests/IntegrationTests/OracleNIOTests.swift b/Tests/IntegrationTests/OracleNIOTests.swift index 16be19e..a97dcc3 100644 --- a/Tests/IntegrationTests/OracleNIOTests.swift +++ b/Tests/IntegrationTests/OracleNIOTests.swift @@ -233,6 +233,7 @@ final class OracleNIOTests: XCTestCase { ) var index = 0 for try await row in rows.decode((Int, String).self) { + print(row) XCTAssertEqual(index + 1, row.0) index = row.0 switch index { diff --git a/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift b/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift index 535275c..b11ff9f 100644 --- a/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift +++ b/Tests/OracleNIOTests/ConnectionStateMachine/StatementStateMachineTests.swift @@ -76,7 +76,7 @@ final class StatementStateMachineTests: XCTestCase { XCTAssertEqual(state.rowHeaderReceived(rowHeader), .succeedStatement(promise, result)) let row1: DataRow = .makeTestDataRow(1) XCTAssertEqual( - state.rowDataReceived(.init(slice: .init(bytes: rowData)), capabilities: .init()), + state.rowDataReceived(.init(columns: []), capabilities: .init()), .forwardStreamComplete([row1], cursorID: 1)) } @@ -121,7 +121,7 @@ final class StatementStateMachineTests: XCTestCase { XCTAssertEqual(state.describeInfoReceived(describeInfo), .wait) XCTAssertEqual(state.rowHeaderReceived(rowHeader), .succeedStatement(promise, result)) XCTAssertEqual( - state.rowDataReceived(.init(slice: .init(bytes: rowData)), capabilities: .init()), + state.rowDataReceived(.init(columns: []), capabilities: .init()), .sendFetch(queryContext)) XCTAssertEqual( state.cancelStatementStream(), @@ -168,7 +168,7 @@ final class StatementStateMachineTests: XCTestCase { XCTAssertEqual(state.describeInfoReceived(describeInfo), .wait) XCTAssertEqual(state.rowHeaderReceived(rowHeader), .succeedStatement(promise, result)) XCTAssertEqual( - state.rowDataReceived(.init(slice: .init(bytes: rowData)), capabilities: .init()), + state.rowDataReceived(.init(columns: []), capabilities: .init()), .sendFetch(queryContext)) XCTAssertEqual( state.cancelStatementStream(),