Skip to content

Commit

Permalink
add tests to ensure oob checks work as expected
Browse files Browse the repository at this point in the history
  • Loading branch information
lovetodream committed Oct 29, 2024
1 parent 4c87cc2 commit 1899abb
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 20 deletions.
13 changes: 11 additions & 2 deletions Benchmarks/Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@
"version" : "1.4.0"
}
},
{
"identity" : "swift-asn1",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-asn1.git",
"state" : {
"revision" : "7faebca1ea4f9aaf0cda1cef7c43aecd2311ddf6",
"version" : "1.3.0"
}
},
{
"identity" : "swift-async-algorithms",
"kind" : "remoteSourceControl",
Expand Down Expand Up @@ -69,8 +78,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-crypto.git",
"state" : {
"revision" : "46072478ca365fe48370993833cb22de9b41567f",
"version" : "3.5.2"
"revision" : "8fa345c2081cfbd4851dffff5dd5bed48efe6081",
"version" : "3.9.0"
}
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct ConnectionStateMachine {
enum State {
case initialized
case connectMessageSent
case oobCheckInProgress
case oobCheckInProgress(fastAuth: Bool)
case protocolMessageSent
case dataTypesMessageSent
case waitingToStartAuthentication
Expand Down Expand Up @@ -391,7 +391,7 @@ struct ConnectionStateMachine {
if capabilities.supportsOOB
&& capabilities.protocolVersion >= Constants.TNS_VERSION_MIN_OOB_CHECK
{
self.state = .oobCheckInProgress
self.state = .oobCheckInProgress(fastAuth: capabilities.supportsOOB)
return .sendOOBCheck
}

Expand All @@ -406,13 +406,13 @@ struct ConnectionStateMachine {
return .sendProtocol
}

mutating func oobCheckComplete(capabilities: Capabilities) -> ConnectionAction {
guard case .oobCheckInProgress = self.state else {
mutating func oobCheckComplete() -> ConnectionAction {
guard case .oobCheckInProgress(let fastAuth) = self.state else {
assertionFailure("Why are we completing an OOB check when there isn't one in progress?")
return self.errorHappened(.unexpectedBackendMessage(.resetOOB))

Check warning on line 412 in Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift

View check run for this annotation

Codecov / codecov/patch

Sources/OracleNIO/ConnectionStateMachine/ConnectionStateMachine.swift#L411-L412

Added lines #L411 - L412 were not covered by tests
}

if capabilities.supportsFastAuth {
if fastAuth {
self.state = .waitingToStartAuthentication
return .provideAuthenticationContext(.allowed)
}
Expand Down Expand Up @@ -528,13 +528,16 @@ struct ConnectionStateMachine {
mutating func markerReceived() -> ConnectionAction {
switch self.state {
case .initialized,
.oobCheckInProgress,
.waitingToStartAuthentication,
.readyForStatement,
.readyToLogOff,
.closed,
.renegotiatingTLS:
preconditionFailure("Invalid state: \(self.state)")

case .oobCheckInProgress:
return self.oobCheckComplete()

case .connectMessageSent,
.protocolMessageSent,
.dataTypesMessageSent,
Expand Down
2 changes: 1 addition & 1 deletion Sources/OracleNIO/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ enum Constants {

// MARK: Control packet types
static let TNS_CONTROL_TYPE_INBAND_NOTIFICATION = 8
static let TNS_CONTROL_TYPE_RESET_OOB = 9
static let TNS_CONTROL_TYPE_RESET_OOB: UInt16 = 9

// MARK: Connect flags
static let TNS_GSO_DONT_CARE: UInt16 = 0x0001
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ struct OraclePartialDecodingError: Error {
) -> Self {
OraclePartialDecodingError(
description: """
Received a control packet with control type '\(controlType)'.
This is unhandled and should be reported, please file an issue.
""",
Received a control packet with control type '\(controlType)'.
This is unhandled and should be reported, please file an issue.
""",
file: file,
line: line
)
Expand Down
3 changes: 2 additions & 1 deletion Sources/OracleNIO/Messages/OracleBackendMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ extension OracleBackendMessage {
if controlType == Constants.TNS_CONTROL_TYPE_RESET_OOB {
return (.init(element: .resetOOB), true)
} else {
throw OraclePartialDecodingError
throw
OraclePartialDecodingError
.unknownControlTypeReceived(controlType: controlType)
}
case .data:
Expand Down
2 changes: 1 addition & 1 deletion Sources/OracleNIO/OracleChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ final class OracleChannelHandler: ChannelDuplexHandler {
action = self.state.lobParameterReceived(parameter: parameter)
case .resetOOB:
self.capabilities.supportsOOB = false
action = self.state.oobCheckComplete(capabilities: self.capabilities)
action = self.state.oobCheckComplete()
}

self.run(action, flags: flags, with: context)
Expand Down
11 changes: 10 additions & 1 deletion Tests/IntegrationTests/OracleNIOTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,16 @@ final class OracleNIOTests: XCTestCase {
// MARK: Tests

func testConnectionAndClose() async throws {
let conn = try await OracleConnection.test(on: self.eventLoop)
var config = OracleConnection.Configuration(
host: "database-1.cbko6yqymxz2.eu-central-1.rds.amazonaws.com",
port: 1521,
service: .sid("ORCL"),
username: "admin",
password: "myRand0mAW3pW"
)
// config.disableOOB = true
let conn = try await OracleConnection.connect(
on: self.eventLoop, configuration: config, id: 1, logger: .oracleTest)

Check warning on line 55 in Tests/IntegrationTests/OracleNIOTests.swift

View check run for this annotation

Codecov / codecov/patch

Tests/IntegrationTests/OracleNIOTests.swift#L46-L55

Added lines #L46 - L55 were not covered by tests
print(conn.serverVersion)
XCTAssertNoThrow(try conn.syncClose())
}
Expand Down
41 changes: 41 additions & 0 deletions Tests/OracleNIOTests/Messages/ControlTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the OracleNIO open source project
//
// Copyright (c) 2024 Timo Zacherl and the OracleNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE for license information
// See CONTRIBUTORS.md for the list of OracleNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore
import NIOTestUtils
import XCTest

@testable import OracleNIO

final class ControlTests: XCTestCase {
func testResetOOB() throws {
var message = try ByteBuffer(plainHexEncodedBytes: "00 09")
let result = try OracleBackendMessage.decode(
from: &message,
of: .control,
context: .init(capabilities: .desired())
)
XCTAssertEqual(result.0, [.resetOOB])
}

func testUnknown() throws {
var message = try ByteBuffer(plainHexEncodedBytes: "01 09")
try XCTAssertThrowsError(
OracleBackendMessage.decode(
from: &message,
of: .control,
context: .init(capabilities: .desired())
), expected: OraclePartialDecodingError.unknownControlTypeReceived(controlType: 0x0109))
}
}
102 changes: 101 additions & 1 deletion Tests/OracleNIOTests/OracleConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,104 @@ final class OracleConnectionTests: XCTestCase {
}
}

func testOOBCheckWorks() async throws {
func runTest(supportsOOB: Bool) async throws {
let eventLoop = NIOAsyncTestingEventLoop()
let protocolVersion =
OracleBackendMessageEncoder
.ProtocolVersion(Int(Constants.TNS_VERSION_MINIMUM))
let channel = await NIOAsyncTestingChannel(
handlers: [
ReverseByteToMessageHandler(OracleFrontendMessageDecoder()),
ReverseMessageToByteHandler(
OracleBackendMessageEncoder(protocolVersion: protocolVersion)),
],
loop: eventLoop
)
try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 1521))

let configuration = OracleConnection.Configuration(
establishedChannel: channel,
service: .serviceName("oracle"),
username: "username",
password: "password"
)

async let connectionPromise = OracleConnection.connect(
on: eventLoop,
configuration: configuration,
id: 1,
logger: Logger(label: "OracleConnectionTests")
)

let connect = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(connect, .connect)
try await channel.writeInbound(
C(messages: [
OracleBackendMessage.accept(.init(newCapabilities: .desired(supportsOOB: true)))
]))
protocolVersion.value = Int(Constants.TNS_VERSION_DESIRED)

let oob = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(oob, .oob)
let marker = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(marker, .marker)
if supportsOOB {
try await channel.writeInbound(C(messages: [.marker]))
} else {
try await channel.writeInbound(C(messages: [.resetOOB]))
}

let fastAuth = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(fastAuth, .fastAuth)
try await channel.writeInbound(
C(messages: [
.parameter([
"AUTH_PBKDF2_CSK_SALT": .init(
value: "CA4861BD9A1BF3CC8DA26D236F7534E3", flags: 0),
"AUTH_SESSKEY": .init(
value: "9F9176A81D9B16F47685024821D6D80064C51B80CD70596C273A99C528599B8E",
flags: 0),
"AUTH_VFR_DATA": .init(
value: "48EE55C6694386C5D6DCCC51343193E0",
flags: Constants.TNS_VERIFIER_TYPE_12C),
"AUTH_PBKDF2_VGEN_COUNT": .init(value: "4096", flags: 0),
"AUTH_PBKDF2_SDER_COUNT": .init(value: "3", flags: 0),
"AUTH_GLOBALLY_UNIQUE_DBID\0": .init(
value: "5D7C6DF1436ADB3A97ED9E44F4C830F7", flags: 0),
])
]))
let authPhase2 = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(authPhase2, .authPhaseTwo)
try await channel.writeInbound(
C(messages: [
.parameter([
"AUTH_VERSION_NO": .init(value: "386138501", flags: 0),
"AUTH_SESSION_ID": .init(value: "52", flags: 0),
"AUTH_SERIAL_NUM": .init(value: "11865", flags: 0),
])
]))

let connection = try await connectionPromise
XCTAssertEqual("\(connection.serverVersion)", "23.4.0.24.5")

self.addTeardownBlock {
async let closePromise: Void = connection.close()
let logoff = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(logoff, .logoff)
try await channel.writeInbound(
C(messages: [
.status(.init(callStatus: 0, endToEndSequenceNumber: 0))
]))
let close = try await channel.waitForOutboundWrite(as: OracleFrontendMessage.self)
XCTAssertEqual(close, .close)
try await closePromise
}
}
try await runTest(supportsOOB: true)
try await runTest(supportsOOB: false)
}


// MARK: Utility

Expand Down Expand Up @@ -194,10 +292,12 @@ final class OracleConnectionTests: XCTestCase {
}

extension Capabilities {
static func desired() -> Capabilities {
static func desired(supportsOOB: Bool = false) -> Capabilities {
var caps = Capabilities()
caps.protocolVersion = Constants.TNS_VERSION_DESIRED
caps.protocolOptions = supportsOOB ? Constants.TNS_GSO_CAN_RECV_ATTENTION : caps.protocolOptions
caps.supportsFastAuth = true
caps.supportsOOB = supportsOOB
return caps
}
}
Expand Down
12 changes: 12 additions & 0 deletions Tests/OracleNIOTests/TestUtils/OracleBackendMessageEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ struct OracleBackendMessageEncoder: MessageToByteEncoder {
payload: status,
out: &out
)
case .resetOOB:
struct ResetOOB: OracleMessagePayloadEncodable {
func encode(into buffer: inout ByteBuffer) {
buffer.writeInteger(Constants.TNS_CONTROL_TYPE_RESET_OOB)
}
}
self.encode(id: .control, flags: 0, payload: ResetOOB(), out: &out)
case .marker:
struct Marker: OracleMessagePayloadEncodable {
func encode(into buffer: inout ByteBuffer) {}
}
self.encode(id: .marker, flags: 0, payload: Marker(), out: &out)
default:
fatalError("Not implemented")
}
Expand Down
2 changes: 2 additions & 0 deletions Tests/OracleNIOTests/TestUtils/OracleFrontendMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ enum OracleFrontendMessage {
case authPhaseTwo
case logoff
case close
case oob
case marker
}
20 changes: 16 additions & 4 deletions Tests/OracleNIOTests/TestUtils/OracleFrontendMessageDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ struct OracleFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
return nil
}

if buffer.readableBytes == 1
&& buffer.getInteger(at: buffer.readerIndex, as: UInt8.self) == Character("!").asciiValue!
{
buffer.moveReaderIndex(forwardBy: 1)
return .oob
}

let startReaderIndex = buffer.readerIndex
let length =
buffer
Expand All @@ -42,11 +49,13 @@ struct OracleFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
as: UInt8.self
) ?? 0 // packet flags

let typeByte = buffer.getInteger(
at: startReaderIndex + MemoryLayout<UInt32>.size,
as: UInt8.self
)

guard
let typeByte = buffer.getInteger(
at: startReaderIndex + MemoryLayout<UInt32>.size,
as: UInt8.self
),
let typeByte,
let type = PacketType(rawValue: typeByte)
else {
preconditionFailure("invalid packet")
Expand Down Expand Up @@ -86,6 +95,9 @@ struct OracleFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
default:
preconditionFailure("TODO: Unimplemented")
}
case .marker:
buffer.moveReaderIndex(forwardBy: buffer.readableBytes)
return .marker
default:
preconditionFailure("TODO: Unimplemented")
}
Expand Down

0 comments on commit 1899abb

Please sign in to comment.