Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Convert OOB enums from panics to runtime errors #76

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions src/protobuf.zig
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ fn VarintDecoderIterator(comptime T: type, comptime varint_type: VarintType) typ
if (self.current_index < self.input.len) {
const raw_value = try decode_varint(u64, self.input[self.current_index..]);
defer self.current_index += raw_value.size;
return decode_varint_value(T, varint_type, raw_value.value);
return try decode_varint_value(T, varint_type, raw_value.value);
}
return null;
}
Expand Down Expand Up @@ -808,7 +808,7 @@ pub const WireDecoderIterator = struct {
};

/// Get a real varint of type T from a raw u64 data.
fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) T {
fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw: u64) !T {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish: Can we make the error type explicit instead of inferred? I'd like to make a pass on all the method to make them explicits in the near future.

return switch (varint_type) {
.ZigZagOptimized => switch (@typeInfo(T)) {
.Int => {
Expand All @@ -825,7 +825,10 @@ fn decode_varint_value(comptime T: type, comptime varint_type: VarintType, raw:
else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"),
},
.Bool => raw != 0,
.Enum => @as(T, @enumFromInt(@as(i32, @intCast(raw)))),
.Enum => if ((raw >> 32) != 0)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: wouldn't it be clearer to use std.math.maxInt(i32)?

.Enum => if(raw > std.math.maxInt(i32))

error.InvalidInput
else
@as(T, @enumFromInt(@as(i32, @bitCast(@as(u32, @intCast(raw)))))),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a new bitCast?

else => @compileError("Invalid type " ++ @typeName(T) ++ " passed"),
},
};
Expand Down Expand Up @@ -877,7 +880,7 @@ fn decode_packed_list(slice: []const u8, comptime list_type: ListType, comptime
fn decode_value(comptime decoded_type: type, comptime ftype: FieldType, extracted_data: Extracted, allocator: Allocator) !decoded_type {
return switch (ftype) {
.Varint => |varint_type| switch (extracted_data.data) {
.RawValue => |value| decode_varint_value(decoded_type, varint_type, value),
.RawValue => |value| try decode_varint_value(decoded_type, varint_type, value),
else => error.InvalidInput,
},
.FixedInt => switch (extracted_data.data) {
Expand Down Expand Up @@ -917,7 +920,7 @@ fn decode_data(comptime T: type, comptime field_desc: FieldDescriptor, comptime
switch (list_type) {
.Varint => |varint_type| {
switch (extracted_data.data) {
.RawValue => |value| try @field(result, field.name).append(decode_varint_value(child_type, varint_type, value)),
.RawValue => |value| try @field(result, field.name).append(try decode_varint_value(child_type, varint_type, value)),
.Slice => |slice| try decode_packed_list(slice, list_type, child_type, &@field(result, field.name), allocator),
}
},
Expand Down Expand Up @@ -1724,12 +1727,12 @@ test "zigzag i32 - encode" {
}

test "zigzag i32/i64 - decode" {
try testing.expectEqual(@as(i32, 1), decode_varint_value(i32, .ZigZagOptimized, 2));
try testing.expectEqual(@as(i32, -2), decode_varint_value(i32, .ZigZagOptimized, 3));
try testing.expectEqual(@as(i32, -500), decode_varint_value(i32, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -0x80000000), decode_varint_value(i64, .ZigZagOptimized, 0xffffffff));
try testing.expectEqual(@as(i32, 1), try decode_varint_value(i32, .ZigZagOptimized, 2));
try testing.expectEqual(@as(i32, -2), try decode_varint_value(i32, .ZigZagOptimized, 3));
try testing.expectEqual(@as(i32, -500), try decode_varint_value(i32, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -500), try decode_varint_value(i64, .ZigZagOptimized, 999));
try testing.expectEqual(@as(i64, -0x80000000), try decode_varint_value(i64, .ZigZagOptimized, 0xffffffff));
}

test "zigzag i64 - encode" {
Expand All @@ -1750,3 +1753,23 @@ test "incorrect data - decode" {

try testing.expectError(error.InvalidInput, value);
}

test "incorrect data - simple varint" {
const max_u64 = decode_varint_value(enum(i32) { a, b, c }, .Simple, (1 << 64) - 1);
const barely_too_big = decode_varint_value(enum(i32) { a, b, c }, .Simple, 1 << 32);

try std.testing.expectError(error.InvalidInput, max_u64);
try std.testing.expectError(error.InvalidInput, barely_too_big);
}

test "correct data - simple varint" {
const enum_a = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, (1 << 32) - 1);
const enum_b = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 0);
const enum_c = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 1);
const enum_d = try decode_varint_value(enum(i32) { a = -1, b = 0, c = 1, d = 2 }, .Simple, 2);

try std.testing.expectEqual(.a, enum_a);
try std.testing.expectEqual(.b, enum_b);
try std.testing.expectEqual(.c, enum_c);
try std.testing.expectEqual(.d, enum_d);
}
Loading