diff --git a/src/protobuf.zig b/src/protobuf.zig index 093b449..7c70298 100644 --- a/src/protobuf.zig +++ b/src/protobuf.zig @@ -723,7 +723,7 @@ fn VarintDecoderIterator(comptime T: type, comptime varint_type: VarintType) typ if (self.current_index < self.input.len) { const raw_value = try decode_varint(u64, self.input[self.current_index..]); defer self.current_index += raw_value.size; - return decode_varint_value(T, varint_type, raw_value.value); + return try decode_varint_value(T, varint_type, raw_value.value); } return null; } @@ -808,7 +808,7 @@ pub const WireDecoderIterator = struct { }; /// Get a real varint of type T from a raw u64 data. -fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) T { +fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) DecodingError!T { return switch (varint_type) { .ZigZagOptimized => switch (@typeInfo(T)) { .Int => { @@ -825,7 +825,10 @@ fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"), }, .Bool => raw != 0, - .Enum => @as(T, @enumFromInt(@as(i32, @intCast(raw)))), + .Enum => if (raw > std.math.maxInt(u32)) + error.InvalidInput + else + @as(T, @enumFromInt(@as(i32, @bitCast(@as(u32, @intCast(raw)))))), else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"), }, }; @@ -877,7 +880,7 @@ fn decode_packed_list(slice: []const u8, comptime list_type: ListType, comptime fn decode_value(comptime decoded_type: type, comptime ftype: FieldType, extracted_data: Extracted, allocator: Allocator) !decoded_type { return switch (ftype) { .Varint => |varint_type| switch (extracted_data.data) { - .RawValue => |value| decode_varint_value(decoded_type, varint_type, value), + .RawValue => |value| try decode_varint_value(decoded_type, varint_type, value), else => error.InvalidInput, }, .FixedInt => switch (extracted_data.data) { @@ -917,7 +920,7 @@ fn decode_data(comptime T: type, comptime field_desc: FieldDescriptor, comptime switch (list_type) { .Varint => |varint_type| { switch (extracted_data.data) { - .RawValue => |value| try @field(result, field.name).append(decode_varint_value(child_type, varint_type, value)), + .RawValue => |value| try @field(result, field.name).append(try decode_varint_value(child_type, varint_type, value)), .Slice => |slice| try decode_packed_list(slice, list_type, child_type, &@field(result, field.name), allocator), } }, @@ -1724,12 +1727,12 @@ test "zigzag i32 - encode" { } test "zigzag i32/i64 - decode" { - try testing.expectEqual(@as(i32, 1), decode_varint_value(i32, .ZigZagOptimized, 2)); - try testing.expectEqual(@as(i32, -2), decode_varint_value(i32, .ZigZagOptimized, 3)); - try testing.expectEqual(@as(i32, -500), decode_varint_value(i32, .ZigZagOptimized, 999)); - try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999)); - try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999)); - try testing.expectEqual(@as(i64, -0x80000000), decode_varint_value(i64, .ZigZagOptimized, 0xffffffff)); + try testing.expectEqual(@as(i32, 1), try decode_varint_value(i32, .ZigZagOptimized, 2)); + try testing.expectEqual(@as(i32, -2), try decode_varint_value(i32, .ZigZagOptimized, 3)); + try testing.expectEqual(@as(i32, -500), try decode_varint_value(i32, .ZigZagOptimized, 999)); + try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999)); + try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999)); + try testing.expectEqual(@as(i64, -0x80000000), try decode_varint_value(i64, .ZigZagOptimized, 0xffffffff)); } test "zigzag i64 - encode" { @@ -1750,3 +1753,29 @@ test "incorrect data - decode" { try testing.expectError(error.InvalidInput, value); } + +test "incorrect data - simple varint" { + // Incorrectly serialized protobufs can place a varint with a decoded value + // greater than std.math.maxInt(u32) into the slot an enum is supposed to + // fill. Since this library represents a decoded varint as a u64 -- the max + // possible valid varint width -- that data can make its way deep into the + // decode_varint_value routine. This test checks that we handle such failures + // gracefully rather than panicking. + const max_u64 = decode_varint_value(enum(i32) { a, b, c }, .Simple, (1 << 64) - 1); + const barely_too_big = decode_varint_value(enum(i32) { a, b, c }, .Simple, 1 << 32); + + try std.testing.expectError(error.InvalidInput, max_u64); + try std.testing.expectError(error.InvalidInput, barely_too_big); +} + +test "correct data - simple varint" { + const enum_a = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, (1 << 32) - 1); + const enum_b = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 0); + const enum_c = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 1); + const enum_d = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 2); + + try std.testing.expectEqual(.a, enum_a); + try std.testing.expectEqual(.b, enum_b); + try std.testing.expectEqual(.c, enum_c); + try std.testing.expectEqual(.d, enum_d); +}