Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Convert OOB enums from panics to runtime errors #76

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions src/protobuf.zig
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ fn VarintDecoderIterator(comptime T: type, comptime varint_type: VarintType) typ
if (self.current_index < self.input.len) {
const raw_value = try decode_varint(u64, self.input[self.current_index..]);
defer self.current_index += raw_value.size;
return decode_varint_value(T, varint_type, raw_value.value);
return try decode_varint_value(T, varint_type, raw_value.value);
}
return null;
}
Expand Down Expand Up @@ -808,7 +808,7 @@ pub const WireDecoderIterator = struct {
};

/// Get a real varint of type T from a raw u64 data.
fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) T {
fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) DecodingError!T {
return switch (varint_type) {
.ZigZagOptimized => switch (@typeInfo(T)) {
.Int => {
Expand All @@ -825,7 +825,10 @@ fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw:
else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"),
},
.Bool => raw != 0,
.Enum => @as(T, @enumFromInt(@as(i32, @intCast(raw)))),
.Enum => if (raw > std.math.maxInt(u32))
error.InvalidInput
else
@as(T, @enumFromInt(@as(i32, @bitCast(@as(u32, @intCast(raw)))))),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a new bitCast?

else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"),
},
};
Expand Down Expand Up @@ -877,7 +880,7 @@ fn decode_packed_list(slice: []const u8, comptime list_type: ListType, comptime
fn decode_value(comptime decoded_type: type, comptime ftype: FieldType, extracted_data: Extracted, allocator: Allocator) !decoded_type {
return switch (ftype) {
.Varint => |varint_type| switch (extracted_data.data) {
.RawValue => |value| decode_varint_value(decoded_type, varint_type, value),
.RawValue => |value| try decode_varint_value(decoded_type, varint_type, value),
else => error.InvalidInput,
},
.FixedInt => switch (extracted_data.data) {
Expand Down Expand Up @@ -917,7 +920,7 @@ fn decode_data(comptime T: type, comptime field_desc: FieldDescriptor, comptime
switch (list_type) {
.Varint => |varint_type| {
switch (extracted_data.data) {
.RawValue => |value| try @field(result, field.name).append(decode_varint_value(child_type, varint_type, value)),
.RawValue => |value| try @field(result, field.name).append(try decode_varint_value(child_type, varint_type, value)),
.Slice => |slice| try decode_packed_list(slice, list_type, child_type, &@field(result, field.name), allocator),
}
},
Expand Down Expand Up @@ -1724,12 +1727,12 @@ test "zigzag i32 - encode" {
}

test "zigzag i32/i64 - decode" {
try testing.expectEqual(@as(i32, 1), decode_varint_value(i32, .ZigZagOptimized, 2));
try testing.expectEqual(@as(i32, -2), decode_varint_value(i32, .ZigZagOptimized, 3));
try testing.expectEqual(@as(i32, -500), decode_varint_value(i32, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -0x80000000), decode_varint_value(i64, .ZigZagOptimized, 0xffffffff));
try testing.expectEqual(@as(i32, 1), try decode_varint_value(i32, .ZigZagOptimized, 2));
try testing.expectEqual(@as(i32, -2), try decode_varint_value(i32, .ZigZagOptimized, 3));
try testing.expectEqual(@as(i32, -500), try decode_varint_value(i32, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -0x80000000), try decode_varint_value(i64, .ZigZagOptimized, 0xffffffff));
}

test "zigzag i64 - encode" {
Expand All @@ -1750,3 +1753,29 @@ test "incorrect data - decode" {

try testing.expectError(error.InvalidInput, value);
}

test "incorrect data - simple varint" {
// Incorrectly serialized protobufs can place a varint with a decoded value
// greater than std.math.maxInt(u32) into the slot an enum is supposed to
// fill. Since this library represents a decoded varint as a u64 -- the max
// possible valid varint width -- that data can make its way deep into the
// decode_varint_value routine. This test checks that we handle such failures
// gracefully rather than panicking.
const max_u64 = decode_varint_value(enum(i32) { a, b, c }, .Simple, (1 << 64) - 1);
const barely_too_big = decode_varint_value(enum(i32) { a, b, c }, .Simple, 1 << 32);

try std.testing.expectError(error.InvalidInput, max_u64);
try std.testing.expectError(error.InvalidInput, barely_too_big);
}

test "correct data - simple varint" {
const enum_a = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, (1 << 32) - 1);
const enum_b = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 0);
const enum_c = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 1);
const enum_d = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 2);

try std.testing.expectEqual(.a, enum_a);
try std.testing.expectEqual(.b, enum_b);
try std.testing.expectEqual(.c, enum_c);
try std.testing.expectEqual(.d, enum_d);
}
Loading