Skip to content

Commit

Permalink
fix: Convert OOB enums from panics to runtime errors (#76)
Browse files Browse the repository at this point in the history
* Convert OOB enums from panics to runtime errors

* run zig fmt

* PR comments
  • Loading branch information
hans-tvs authored Oct 29, 2024
1 parent 9619cfe commit eed464e
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions src/protobuf.zig
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ fn VarintDecoderIterator(comptime T: type, comptime varint_type: VarintType) typ
if (self.current_index < self.input.len) {
const raw_value = try decode_varint(u64, self.input[self.current_index..]);
defer self.current_index += raw_value.size;
return decode_varint_value(T, varint_type, raw_value.value);
return try decode_varint_value(T, varint_type, raw_value.value);
}
return null;
}
Expand Down Expand Up @@ -808,7 +808,7 @@ pub const WireDecoderIterator = struct {
};

/// Get a real varint of type T from a raw u64 data.
fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) T {
fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) DecodingError!T {
return switch (varint_type) {
.ZigZagOptimized => switch (@typeInfo(T)) {
.Int => {
Expand All @@ -825,7 +825,10 @@ fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw:
else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"),
},
.Bool => raw != 0,
.Enum => @as(T, @enumFromInt(@as(i32, @intCast(raw)))),
.Enum => if (raw > std.math.maxInt(u32))
error.InvalidInput
else
@as(T, @enumFromInt(@as(i32, @bitCast(@as(u32, @intCast(raw)))))),
else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"),
},
};
Expand Down Expand Up @@ -877,7 +880,7 @@ fn decode_packed_list(slice: []const u8, comptime list_type: ListType, comptime
fn decode_value(comptime decoded_type: type, comptime ftype: FieldType, extracted_data: Extracted, allocator: Allocator) !decoded_type {
return switch (ftype) {
.Varint => |varint_type| switch (extracted_data.data) {
.RawValue => |value| decode_varint_value(decoded_type, varint_type, value),
.RawValue => |value| try decode_varint_value(decoded_type, varint_type, value),
else => error.InvalidInput,
},
.FixedInt => switch (extracted_data.data) {
Expand Down Expand Up @@ -917,7 +920,7 @@ fn decode_data(comptime T: type, comptime field_desc: FieldDescriptor, comptime
switch (list_type) {
.Varint => |varint_type| {
switch (extracted_data.data) {
.RawValue => |value| try @field(result, field.name).append(decode_varint_value(child_type, varint_type, value)),
.RawValue => |value| try @field(result, field.name).append(try decode_varint_value(child_type, varint_type, value)),
.Slice => |slice| try decode_packed_list(slice, list_type, child_type, &@field(result, field.name), allocator),
}
},
Expand Down Expand Up @@ -1724,12 +1727,12 @@ test "zigzag i32 - encode" {
}

test "zigzag i32/i64 - decode" {
try testing.expectEqual(@as(i32, 1), decode_varint_value(i32, .ZigZagOptimized, 2));
try testing.expectEqual(@as(i32, -2), decode_varint_value(i32, .ZigZagOptimized, 3));
try testing.expectEqual(@as(i32, -500), decode_varint_value(i32, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -0x80000000), decode_varint_value(i64, .ZigZagOptimized, 0xffffffff));
try testing.expectEqual(@as(i32, 1), try decode_varint_value(i32, .ZigZagOptimized, 2));
try testing.expectEqual(@as(i32, -2), try decode_varint_value(i32, .ZigZagOptimized, 3));
try testing.expectEqual(@as(i32, -500), try decode_varint_value(i32, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -0x80000000), try decode_varint_value(i64, .ZigZagOptimized, 0xffffffff));
}

test "zigzag i64 - encode" {
Expand All @@ -1750,3 +1753,29 @@ test "incorrect data - decode" {

try testing.expectError(error.InvalidInput, value);
}

test "incorrect data - simple varint" {
// Incorrectly serialized protobufs can place a varint with a decoded value
// greater than std.math.maxInt(u32) into the slot an enum is supposed to
// fill. Since this library represents a decoded varint as a u64 -- the max
// possible valid varint width -- that data can make its way deep into the
// decode_varint_value routine. This test checks that we handle such failures
// gracefully rather than panicking.
const max_u64 = decode_varint_value(enum(i32) { a, b, c }, .Simple, (1 << 64) - 1);
const barely_too_big = decode_varint_value(enum(i32) { a, b, c }, .Simple, 1 << 32);

try std.testing.expectError(error.InvalidInput, max_u64);
try std.testing.expectError(error.InvalidInput, barely_too_big);
}

test "correct data - simple varint" {
const enum_a = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, (1 << 32) - 1);
const enum_b = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 0);
const enum_c = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 1);
const enum_d = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 2);

try std.testing.expectEqual(.a, enum_a);
try std.testing.expectEqual(.b, enum_b);
try std.testing.expectEqual(.c, enum_c);
try std.testing.expectEqual(.d, enum_d);
}

0 comments on commit eed464e

Please sign in to comment.