diff --git a/Sources/OracleNIO/Data/OracleJSON.swift b/Sources/OracleNIO/Data/OracleJSON.swift index de879de..f59890d 100644 --- a/Sources/OracleNIO/Data/OracleJSON.swift +++ b/Sources/OracleNIO/Data/OracleJSON.swift @@ -175,28 +175,16 @@ struct OracleKeyedDecodingContainer: KeyedDecodingContainerProto return value } - func decode(_ type: Double.Type, forKey key: Key) throws -> Double { - let value = try self.getValue(forKey: key) - guard case .double(let value) = value else { - throw self.createTypeMismatchError(type: type, forKey: key, value: value) - } - return value + func decode(_: Double.Type, forKey key: Key) throws -> Double { + try self.decodeFloatingPointNumber(forKey: key) } - func decode(_ type: Float.Type, forKey key: Key) throws -> Float { - let value = try self.getValue(forKey: key) - guard case .float(let value) = value else { - throw self.createTypeMismatchError(type: type, forKey: key, value: value) - } - return value + func decode(_: Float.Type, forKey key: Key) throws -> Float { + try self.decodeFloatingPointNumber(forKey: key) } - func decode(_ type: Int.Type, forKey key: Key) throws -> Int { - let value = try self.getValue(forKey: key) - guard case .int(let value) = value else { - throw self.createTypeMismatchError(type: type, forKey: key, value: value) - } - return value + func decode(_: Int.Type, forKey key: Key) throws -> Int { + try self.decodeBinaryInteger(forKey: key) } func decode(_: Int8.Type, forKey key: Key) throws -> Int8 { @@ -318,12 +306,56 @@ extension OracleKeyedDecodingContainer { } @inline(__always) - private func decodeBinaryInteger(of type: T.Type = T.self, forKey key: Key) throws -> T { + private func decodeBinaryInteger(forKey key: Key) throws -> T { let value = try self.getValue(forKey: key) - guard case .int(let value) = value else { - throw self.createTypeMismatchError(type: type, forKey: key, value: value) + switch value { + case .int(let value): + return T(value) + case .float(let value): + guard let value = T(exactly: value) else { + throw DecodingError.dataCorruptedError( + forKey: key, + in: self, + debugDescription: "Number \(value) does not fit in \(T.self)." + ) + } + return value + case .double(let value): + guard let value = T(exactly: value) else { + throw DecodingError.dataCorruptedError( + forKey: key, + in: self, + debugDescription: "Number \(value) does not fit in \(T.self)." + ) + } + return value + default: + throw self.createTypeMismatchError(type: T.self, forKey: key, value: value) + } + } + + @inline(__always) + private func decodeFloatingPointNumber(forKey key: Key) throws -> T { + let value = try self.getValue(forKey: key) + let (float, original): (T?, any Numeric) = switch value { + case .int(let value): + (T(exactly: value), value) + case .double(let value): + (T(value), value) + case .float(let value): + (T(value), value) + default: + throw self.createTypeMismatchError(type: T.self, forKey: key, value: value) + } + + guard let float else { + throw DecodingError.dataCorruptedError( + forKey: key, + in: self, + debugDescription: "Number \(original) does not fit in \(T.self)." + ) } - return T(value) + return float } } @@ -360,25 +392,16 @@ struct OracleSingleValueDecodingContainer: SingleValueDecodingContainer { return value } - func decode(_ type: Double.Type) throws -> Double { - guard case .double(let value) = self.value else { - throw self.createTypeMismatchError(type: type, value: value) - } - return value + func decode(_: Double.Type) throws -> Double { + try self.decodeFloatingPointNumber() } - func decode(_ type: Float.Type) throws -> Float { - guard case .float(let value) = self.value else { - throw self.createTypeMismatchError(type: type, value: value) - } - return value + func decode(_: Float.Type) throws -> Float { + try self.decodeFloatingPointNumber() } - func decode(_ type: Int.Type) throws -> Int { - guard case .int(let value) = self.value else { - throw self.createTypeMismatchError(type: type, value: value) - } - return value + func decode(_: Int.Type) throws -> Int { + try self.decodeBinaryInteger() } func decode(_: Int8.Type) throws -> Int8 { @@ -452,11 +475,51 @@ extension OracleSingleValueDecodingContainer { } @inline(__always) - private func decodeBinaryInteger(of type: T.Type = T.self) throws -> T { - guard case .int(let value) = self.value else { - throw self.createTypeMismatchError(type: type, value: value) + private func decodeBinaryInteger() throws -> T { + switch self.value { + case .int(let value): + return T(value) + case .float(let value): + guard let value = T(exactly: value) else { + throw DecodingError.dataCorruptedError( + in: self, + debugDescription: "Number \(value) does not fit in \(T.self)." + ) + } + return value + case .double(let value): + guard let value = T(exactly: value) else { + throw DecodingError.dataCorruptedError( + in: self, + debugDescription: "Number \(value) does not fit in \(T.self)." + ) + } + return value + default: + throw self.createTypeMismatchError(type: T.self, value: self.value) + } + } + + @inline(__always) + private func decodeFloatingPointNumber() throws -> T { + let (float, original): (T?, any Numeric) = switch self.value { + case .int(let value): + (T(exactly: value), value) + case .double(let value): + (T(value), value) + case .float(let value): + (T(value), value) + default: + throw self.createTypeMismatchError(type: T.self, value: value) + } + + guard let float else { + throw DecodingError.dataCorruptedError( + in: self, + debugDescription: "Number \(original) does not fit in \(T.self)." + ) } - return T(value) + return float } } @@ -505,34 +568,16 @@ struct OracleUnkeyedDecodingContainer: UnkeyedDecodingContainer { return value } - mutating func decode(_ type: Double.Type) throws -> Double { - let value = try self.getNextValue(ofType: type) - guard case .double(let value) = value else { - throw self.createTypeMismatchError(type: type, value: value) - } - - self.currentIndex += 1 - return value + mutating func decode(_: Double.Type) throws -> Double { + try self.decodeFloatingPointNumber() } - mutating func decode(_ type: Float.Type) throws -> Float { - let value = try self.getNextValue(ofType: type) - guard case .float(let value) = value else { - throw self.createTypeMismatchError(type: type, value: value) - } - - self.currentIndex += 1 - return value + mutating func decode(_: Float.Type) throws -> Float { + try self.decodeFloatingPointNumber() } - mutating func decode(_ type: Int.Type) throws -> Int { - let value = try self.getNextValue(ofType: type) - guard case .int(let value) = value else { - throw self.createTypeMismatchError(type: type, value: value) - } - - self.currentIndex += 1 - return value + mutating func decode(_: Int.Type) throws -> Int { + try self.decodeBinaryInteger() } mutating func decode(_: Int8.Type) throws -> Int8 { @@ -691,14 +736,63 @@ extension OracleUnkeyedDecodingContainer { } @inline(__always) - private mutating func decodeBinaryInteger(of type: T.Type = T.self) throws -> T { + private mutating func decodeBinaryInteger() throws -> T { let value = try self.getNextValue(ofType: Int.self) - guard case .int(let value) = value else { - throw self.createTypeMismatchError(type: type, value: value) + switch value { + case .int(let value): + self.currentIndex += 1 + return T(value) + + case .float(let value): + guard let value = T(exactly: value) else { + throw DecodingError.dataCorruptedError( + in: self, + debugDescription: "Number \(value) does not fit in \(T.self)." + ) + } + + self.currentIndex += 1 + return value + + case .double(let value): + guard let value = T(exactly: value) else { + throw DecodingError.dataCorruptedError( + in: self, + debugDescription: "Number \(value) does not fit in \(T.self)." + ) + } + + self.currentIndex += 1 + return value + + default: + throw self.createTypeMismatchError(type: T.self, value: value) + } + } + + @inline(__always) + private mutating func decodeFloatingPointNumber() throws -> T { + let value = try self.getNextValue(ofType: T.self) + let (float, original): (T?, any Numeric) = switch value { + case .int(let value): + (T(exactly: value), value) + case .double(let value): + (T(value), value) + case .float(let value): + (T(value), value) + default: + throw self.createTypeMismatchError(type: T.self, value: value) + } + + guard let float else { + throw DecodingError.dataCorruptedError( + in: self, + debugDescription: "Number \(original) does not fit in \(T.self)." + ) } self.currentIndex += 1 - return T(value) + return float } } diff --git a/Sources/OracleNIO/OSON/OracleJSONParser.swift b/Sources/OracleNIO/OSON/OracleJSONParser.swift index b33097b..fa93f96 100644 --- a/Sources/OracleNIO/OSON/OracleJSONParser.swift +++ b/Sources/OracleNIO/OSON/OracleJSONParser.swift @@ -147,7 +147,7 @@ struct OracleJSONParser { // skip the field name offsets array for now let offsetsPosition = buffer.readerIndex buffer.moveReaderIndex(forwardBy: fieldsCount * offsetsSize) - var slice = try buffer.throwingReadSlice(length: segmentSize) + let slice = try buffer.throwingReadSlice(length: segmentSize) let finalPosition = buffer.readerIndex // determine the names of the fields diff --git a/Tests/IntegrationTests/JSONTests.swift b/Tests/IntegrationTests/JSONTests.swift index 40acb29..47562e9 100644 --- a/Tests/IntegrationTests/JSONTests.swift +++ b/Tests/IntegrationTests/JSONTests.swift @@ -48,7 +48,10 @@ final class JSONTests: XCTIntegrationTest { """) try await connection.execute( """ - INSERT INTO TestCompressedJson VALUES (1, '{"key": "value", "int": 8, "array": [1, 2, 3]}') + INSERT INTO TestCompressedJson VALUES ( + 1, + '{"key": "value", "int": 8, "array": [1, 2, 3], "bool1": true, "bool2": false, "nested": {"float": 1.2, "double": 1.23, "null": null}}' + ) """) } } @@ -71,13 +74,31 @@ final class JSONTests: XCTIntegrationTest { for try await (id, json) in stream.decode((Int, OracleJSON).self) { XCTAssertEqual(id, 1) let value = try json.decode(as: MyJSON.self) - XCTAssertEqual(value, MyJSON(key: "value", int: 8, array: [1, 2, 3])) + XCTAssertEqual( + value, MyJSON( + key: "value", + int: 8, + array: [1, 2, 3], + bool1: true, + bool2: false, + nested: .init(float: 1.2, double: 1.23, null: nil) + ) + ) } struct MyJSON: Decodable, Equatable { var key: String - var int: Double - var array: [Double] + var int: Int + var array: [Int] + var bool1: Bool + var bool2: Bool + var nested: Nested + + struct Nested: Decodable, Equatable { + var float: Float + var double: Double + var null: String? + } } } }