diff --git a/Cargo.lock b/Cargo.lock index 18853c1e..97abb846 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,6 +130,28 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6736e2428df2ca2848d846c43e88745121a6654696e349ce0054a420815a7409" +[[package]] +name = "binread" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f9def047620f016f4d5eb5a4c9b925f03175870772d354c696ec9062e942886" +dependencies = [ + "binread_derive", + "lazy_static", +] + +[[package]] +name = "binread_derive" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c575d9a28eb4c2d61747b23d50271c6699b941448c35a738af35267f90b3d4d2" +dependencies = [ + "either", + "proc-macro2 1.0.26", + "quote 1.0.9", + "syn 1.0.68", +] + [[package]] name = "bit-set" version = "0.5.2" @@ -232,7 +254,9 @@ checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" name = "candid" version = "0.6.21" dependencies = [ + "anyhow", "arbitrary", + "binread", "byteorder", "candid_derive", "codespan-reporting", diff --git a/rust/candid/Cargo.toml b/rust/candid/Cargo.toml index 56dcc78b..883a2aa1 100644 --- a/rust/candid/Cargo.toml +++ b/rust/candid/Cargo.toml @@ -35,6 +35,8 @@ pretty = "0.10.0" serde = { version = "1.0.118", features = ["derive"] } serde_bytes = "0.11" thiserror = "1.0.20" +anyhow = "1.0" +binread = { version = "2.0", features = ["debug_template"] } arbitrary = { version = "0.4.7", optional = true } serde_dhall = { version = "0.10.0", optional = true } diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs new file mode 100644 index 00000000..79d876f0 --- /dev/null +++ b/rust/candid/src/binary_parser.rs @@ -0,0 +1,262 @@ +use crate::parser::types::FuncMode; +use crate::parser::typing::TypeEnv; +use crate::types::internal::{Field, Function, Label, Type}; +use anyhow::{anyhow, Context, Result}; +use binread::io::{Read, Seek, SeekFrom}; +use binread::{BinRead, BinResult, Error as BError, ReadOptions}; +use std::convert::TryInto; + +fn read_leb(reader: &mut R, ro: &ReadOptions, _: ()) -> BinResult { + let pos = reader.seek(SeekFrom::Current(0))?; + leb128::read::unsigned(reader).map_err(|_| BError::Custom { + pos, + err: Box::new(ro.variable_name.unwrap_or("Invalid leb128")), + }) +} +fn read_sleb(reader: &mut R, ro: &ReadOptions, _: ()) -> BinResult { + let pos = reader.seek(SeekFrom::Current(0))?; + leb128::read::signed(reader).map_err(|_| BError::Custom { + pos, + err: Box::new(ro.variable_name.unwrap_or("Invalid sleb128")), + }) +} + +#[derive(BinRead, Debug)] +#[br(magic = b"DIDL")] +pub struct Header { + table: Table, + #[br(parse_with = read_leb)] + len: u64, + #[br(count = len)] + args: Vec, +} +#[derive(BinRead, Debug)] +struct Table { + #[br(parse_with = read_leb, assert(len <= i64::MAX as u64, "type table size out of range"))] + len: u64, + #[br(count = len)] + table: Vec, +} +#[derive(BinRead, Debug)] +enum ConsType { + #[br(magic = 0x6eu8)] + Opt(Box), + #[br(magic = 0x6du8)] + Vec(Box), + #[br(magic = 0x6cu8)] + Record(Fields), + #[br(magic = 0x6bu8)] + Variant(Fields), + #[br(magic = 0x6au8)] + Func(FuncType), + #[br(magic = 0x69u8)] + Service(ServType), +} +#[derive(BinRead, Debug)] +struct IndexType { + #[br(parse_with = read_sleb, assert(index >= -17 || index == -24, "unknown opcode {}", index))] + index: i64, +} +#[derive(BinRead, Debug)] +struct Fields { + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into().map_err(|_| "field length out of 32-bit range"))] + len: u32, + #[br(count = len)] + inner: Vec, +} +#[derive(BinRead, Debug)] +struct FieldType { + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into().map_err(|_| "field id out of 32-bit range"))] + id: u32, + index: IndexType, +} +#[derive(BinRead, Debug)] +struct FuncType { + #[br(parse_with = read_leb)] + arg_len: u64, + #[br(count = arg_len)] + args: Vec, + #[br(parse_with = read_leb)] + ret_len: u64, + #[br(count = ret_len)] + rets: Vec, + #[br(assert(ann_len <= 1u8, "function annotation length should be at most 1"))] + ann_len: u8, + #[br(count = ann_len)] + ann: Vec, +} +#[derive(BinRead, Debug)] +struct ServType { + #[br(parse_with = read_leb)] + len: u64, + #[br(count = len)] + meths: Vec, +} +#[derive(BinRead, Debug)] +struct Meths { + #[br(parse_with = read_leb)] + len: u64, + #[br(count = len, try_map = |x:Vec| String::from_utf8(x).map_err(|_| "invalid utf8"))] + name: String, + ty: IndexType, +} +#[derive(BinRead, Debug)] +struct Mode { + #[br(try_map = |x:u8| match x { 1u8 => Ok(FuncMode::Query), | 2u8 => Ok(FuncMode::Oneway), | _ => Err("Unknown annotation") })] + inner: FuncMode, +} + +#[derive(BinRead)] +pub struct BoolValue( + #[br(try_map = |x:u8| match x { 0u8 => Ok(false), | 1u8 => Ok(true), | _ => Err("Expect 00 or 01") } )] + pub bool, +); +#[derive(BinRead)] +pub struct Len( + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into().map_err(|_| "length out of usize range"))] + pub usize, +); +#[derive(BinRead)] +pub struct PrincipalBytes { + #[br(assert(flag == 1u8, "Opaque reference not supported"))] + pub flag: u8, + #[br(parse_with = read_leb)] + pub len: u64, + #[br(count = len, parse_with = binread::helpers::read_bytes)] + pub inner: Vec, +} + +fn index_to_var(ind: i64) -> String { + format!("table{}", ind) +} +impl IndexType { + fn to_type(&self, len: u64) -> Result { + Ok(match self.index { + v if v >= 0 => { + if v >= len as i64 { + return Err(anyhow!("type index {} out of range", v)); + } + Type::Var(index_to_var(v)) + } + -1 => Type::Null, + -2 => Type::Bool, + -3 => Type::Nat, + -4 => Type::Int, + -5 => Type::Nat8, + -6 => Type::Nat16, + -7 => Type::Nat32, + -8 => Type::Nat64, + -9 => Type::Int8, + -10 => Type::Int16, + -11 => Type::Int32, + -12 => Type::Int64, + -13 => Type::Float32, + -14 => Type::Float64, + -15 => Type::Text, + -16 => Type::Reserved, + -17 => Type::Empty, + -24 => Type::Principal, + _ => unreachable!(), + }) + } +} +impl ConsType { + fn to_type(&self, len: u64) -> Result { + Ok(match &self { + ConsType::Opt(ref ind) => Type::Opt(Box::new(ind.to_type(len)?)), + ConsType::Vec(ref ind) => Type::Vec(Box::new(ind.to_type(len)?)), + ConsType::Record(fs) | ConsType::Variant(fs) => { + let mut res = Vec::new(); + let mut prev = None; + for f in fs.inner.iter() { + if let Some(prev) = prev { + if prev >= f.id { + return Err(anyhow!("field id {} collision or not sorted", f.id)); + } + } + prev = Some(f.id); + let field = Field { + id: Label::Id(f.id), + ty: f.index.to_type(len)?, + }; + res.push(field); + } + if matches!(&self, ConsType::Record(_)) { + Type::Record(res) + } else { + Type::Variant(res) + } + } + ConsType::Func(f) => { + let mut args = Vec::new(); + let mut rets = Vec::new(); + for arg in f.args.iter() { + args.push(arg.to_type(len)?); + } + for ret in f.rets.iter() { + rets.push(ret.to_type(len)?); + } + Type::Func(Function { + modes: f.ann.iter().map(|x| x.inner.clone()).collect(), + args, + rets, + }) + } + ConsType::Service(serv) => { + let mut res = Vec::new(); + let mut prev = None; + for m in serv.meths.iter() { + if let Some(prev) = prev { + if prev >= &m.name { + return Err(anyhow!("method name {} duplicate or not sorted", m.name)); + } + } + prev = Some(&m.name); + res.push((m.name.clone(), m.ty.to_type(len)?)); + } + Type::Service(res) + } + }) + } +} +impl Table { + fn to_env(&self, len: u64) -> Result { + use std::collections::BTreeMap; + let mut env = BTreeMap::new(); + for (i, t) in self.table.iter().enumerate() { + let ty = t + .to_type(len) + .with_context(|| format!("Invalid table entry {}: {:?}", i, t))?; + env.insert(index_to_var(i as i64), ty); + } + // validate method has func type + for (_, t) in env.iter() { + if let Type::Service(ms) = t { + for (name, ty) in ms.iter() { + if let Type::Var(id) = ty { + if matches!(env.get(id), Some(Type::Func(_))) { + continue; + } + } + return Err(anyhow!("Method {} has a non-function type {}", name, ty)); + } + } + } + Ok(TypeEnv(env)) + } +} +impl Header { + pub fn to_types(&self) -> Result<(TypeEnv, Vec)> { + let len = self.table.len; + let mut env = self.table.to_env(len)?; + env.replace_empty()?; + let mut args = Vec::new(); + for (i, t) in self.args.iter().enumerate() { + args.push( + t.to_type(len) + .with_context(|| format!("Invalid argument entry {}: {:?}", i, t))?, + ); + } + Ok((env, args)) + } +} diff --git a/rust/candid/src/bindings/candid.rs b/rust/candid/src/bindings/candid.rs index 9f360e40..323e5bab 100644 --- a/rust/candid/src/bindings/candid.rs +++ b/rust/candid/src/bindings/candid.rs @@ -115,7 +115,7 @@ pub fn pp_ty(ty: &Type) -> RcDoc { } } Knot(ref id) => RcDoc::text(format!("{}", id)), - Unknown => unreachable!(), + Unknown => str("unknown"), } } diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 5276f44d..59b021e9 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -1,16 +1,19 @@ //! Deserialize Candid binary format to Rust data structures use super::error::{Error, Result}; -use super::types::internal::Opcode; -use super::{idl_hash, Int, Nat}; +use super::{ + parser::typing::TypeEnv, + types::{Field, Label, Type}, + CandidType, Int, Nat, +}; +use crate::binary_parser::{BoolValue, Header, Len, PrincipalBytes}; +use crate::types::subtype::{subtype, Gamma}; +use anyhow::{anyhow, Context}; +use binread::BinRead; use byteorder::{LittleEndian, ReadBytesExt}; -use leb128::read::{signed as sleb128_decode, unsigned as leb128_decode}; -use serde::de::{self, Deserialize, Visitor}; -use std::collections::{BTreeMap, VecDeque}; -use std::convert::TryFrom; -use std::io::Read; - -const MAGIC_NUMBER: &[u8; 4] = b"DIDL"; +use serde::de::{self, Visitor}; +use std::collections::VecDeque; +use std::io::Cursor; /// Use this struct to deserialize a sequence of Rust values (heterogeneous) from IDL binary message. pub struct IDLDeserialize<'de> { @@ -19,432 +22,282 @@ pub struct IDLDeserialize<'de> { impl<'de> IDLDeserialize<'de> { /// Create a new deserializer with IDL binary message. pub fn new(bytes: &'de [u8]) -> Result { - let de = Deserializer::from_bytes(bytes)?; + let de = Deserializer::from_bytes(bytes) + .with_context(|| format!("Cannot parse header {}", &hex::encode(bytes)))?; Ok(IDLDeserialize { de }) } /// Deserialize one value from deserializer. pub fn get_value(&mut self) -> Result where - T: de::Deserialize<'de>, + T: de::Deserialize<'de> + CandidType, + { + self.de.is_untyped = false; + self.deserialize_with_type(T::ty()) + } + pub fn get_value_with_type( + &mut self, + env: &TypeEnv, + expected_type: &Type, + ) -> Result { + self.de.table.merge(env)?; + self.de.is_untyped = true; + self.deserialize_with_type(expected_type.clone()) + } + fn deserialize_with_type(&mut self, expected_type: Type) -> Result + where + T: de::Deserialize<'de> + CandidType, { - let ty = self - .de - .table - .types - .pop_front() - .ok_or_else(|| Error::msg("No more values to deserialize"))?; - self.de.table.current_type.push_back(ty); + let expected_type = self.de.table.trace_type(&expected_type)?; + if self.de.types.is_empty() { + if matches!(expected_type, Type::Opt(_) | Type::Reserved | Type::Null) { + self.de.expect_type = expected_type; + self.de.wire_type = Type::Null; + return T::deserialize(&mut self.de); + } else { + return Err(Error::msg(format!( + "No more values on the wire, the expected type {} is not opt, reserved or null", + expected_type + ))); + } + } - let v = T::deserialize(&mut self.de).map_err(|e| self.de.dump_error_state(e))?; - if self.de.table.current_type.is_empty() && self.de.field_name.is_none() { - Ok(v) + let (ind, ty) = self.de.types.pop_front().unwrap(); + self.de.expect_type = if matches!(expected_type, Type::Unknown) { + self.de.is_untyped = true; + ty.clone() } else { - Err(Error::msg("Trailing type after deserializing a value")) - .map_err(|e| self.de.dump_error_state(e)) - } + expected_type.clone() + }; + self.de.wire_type = ty.clone(); + self.de + .check_subtype() + .with_context(|| self.de.dump_state()) + .with_context(|| { + format!( + "Fail to decode argument {} from {} to {}", + ind, ty, expected_type + ) + })?; + + let v = T::deserialize(&mut self.de) + .with_context(|| self.de.dump_state()) + .with_context(|| { + format!( + "Fail to decode argument {} from {} to {}", + ind, ty, expected_type + ) + })?; + Ok(v) } /// Check if we finish deserializing all values. pub fn is_done(&self) -> bool { - self.de.table.types.is_empty() + self.de.types.is_empty() } /// Return error if there are unprocessed bytes in the input. pub fn done(mut self) -> Result<()> { while !self.is_done() { self.get_value::()?; } - if !self.de.input.0.is_empty() { - return Err(Error::msg("Trailing value after finishing deserialization")) - .map_err(|e| self.de.dump_error_state(e)); + let ind = self.de.input.position() as usize; + let rest = &self.de.input.get_ref()[ind..]; + if !rest.is_empty() { + return Err(anyhow!(self.de.dump_state())) + .context("Trailing value after finishing deserialization")?; } Ok(()) } } -#[derive(Clone, Debug, PartialEq, Eq)] -enum RawValue { - I(i64), - U(u32), -} -impl RawValue { - fn get_i64(&self) -> Result { - match *self { - RawValue::I(i) => Ok(i), - _ => Err(Error::msg(format!("get_i64 fail: {:?}", *self))), - } - } - fn get_u32(&self) -> Result { - match *self { - RawValue::U(u) => Ok(u), - _ => Err(Error::msg(format!("get_u32 fail: {:?}", *self))), - } - } -} - -struct Bytes<'a>(&'a [u8]); -impl<'a> Bytes<'a> { - fn from(input: &'a [u8]) -> Self { - Bytes(input) - } - fn leb128_read(&mut self) -> Result { - leb128_decode(&mut self.0).map_err(Error::msg) - } - fn sleb128_read(&mut self) -> Result { - sleb128_decode(&mut self.0).map_err(Error::msg) - } - fn parse_byte(&mut self) -> Result { - let mut buf = [0u8; 1]; - self.0.read_exact(&mut buf)?; - Ok(buf[0]) - } - fn parse_bytes(&mut self, len: usize) -> Result> { - if self.0.len() < len { - return Err(Error::msg("unexpected end of message")); - } - let mut buf = vec![0; len]; - self.0.read_exact(&mut buf)?; - Ok(buf) - } - fn parse_string(&mut self, len: usize) -> Result { - let buf = self.parse_bytes(len)?; - String::from_utf8(buf).map_err(Error::msg) - } - fn parse_magic(&mut self) -> Result<()> { - let mut buf = [0u8; 4]; - match self.0.read(&mut buf) { - Ok(4) if buf == *MAGIC_NUMBER => Ok(()), - _ => Err(Error::msg(format!("wrong magic number {:?}", buf))), - } - } -} - -struct TypeTable { - // Raw value of the type description table - table: Vec>, - // Value types for deserialization - types: VecDeque, - // The front of current_type queue always points to the type of the value we are deserailizing. - // The type info is cloned from table. Someone more familiar with Rust should see if we can - // rewrite this to avoid copying. - current_type: VecDeque, -} -impl TypeTable { - // Parse the type table and return the remaining bytes - fn from_bytes(input: &[u8]) -> Result<(Self, &[u8])> { - let mut bytes = Bytes::from(input); - let mut table: Vec> = Vec::new(); - let mut types = VecDeque::new(); - - bytes.parse_magic()?; - let len = bytes.leb128_read()? as usize; - let mut expect_func = std::collections::HashSet::new(); - for i in 0..len { - let mut buf = Vec::new(); - let ty = bytes.sleb128_read()?; - buf.push(RawValue::I(ty)); - if expect_func.contains(&i) && ty != -22 { - return Err(Error::msg(format!( - "Expect function opcode, but got {}", - ty - ))); - } - match Opcode::try_from(ty) { - Ok(Opcode::Opt) | Ok(Opcode::Vec) => { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - buf.push(RawValue::I(ty)); - } - Ok(Opcode::Record) | Ok(Opcode::Variant) => { - let obj_len = u32::try_from(bytes.leb128_read()?) - .map_err(|_| Error::msg(Error::msg("length out of u32")))?; - buf.push(RawValue::U(obj_len)); - let mut prev_hash = None; - for _ in 0..obj_len { - let hash = u32::try_from(bytes.leb128_read()?) - .map_err(|_| Error::msg(Error::msg("field hash out of u32")))?; - if let Some(prev_hash) = prev_hash { - if prev_hash >= hash { - return Err(Error::msg("field id collision or not sorted")); - } - } - prev_hash = Some(hash); - buf.push(RawValue::U(hash)); - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - buf.push(RawValue::I(ty)); - } - } - Ok(Opcode::Service) => { - let obj_len = u32::try_from(bytes.leb128_read()?) - .map_err(|_| Error::msg(Error::msg("length out of u32")))?; - // Push one element to the table to ensure it's a non-primitive type - buf.push(RawValue::U(obj_len)); - let mut prev = None; - for _ in 0..obj_len { - let mlen = bytes.leb128_read()? as usize; - let meth = bytes.parse_string(mlen)?; - if let Some(prev) = prev { - if prev >= meth { - return Err(Error::msg("method name collision or not sorted")); - } - } - prev = Some(meth); - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - // Check for method type - if ty >= 0 { - let idx = ty as usize; - if idx < table.len() && table[idx][0] != RawValue::I(-22) { - return Err(Error::msg("not a function type")); - } else { - expect_func.insert(idx); - } - } else { - return Err(Error::msg("not a function type")); - } - } - } - Ok(Opcode::Func) => { - let arg_len = bytes.leb128_read()?; - // Push one element to the table to ensure it's a non-primitive type - buf.push(RawValue::U(arg_len as u32)); - for _ in 0..arg_len { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - } - let ret_len = bytes.leb128_read()?; - for _ in 0..ret_len { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - } - let ann_len = bytes.leb128_read()?; - for _ in 0..ann_len { - let ann = bytes.parse_byte()?; - if ann > 2u8 { - return Err(Error::msg("Unknown function annotation")); - } - } - } - _ => { - return Err(Error::msg(format!( - "Unsupported op_code {} in type table", - ty - ))) - } - }; - table.push(buf); - } - let len = bytes.leb128_read()?; - for _i in 0..len { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, table.len())?; - types.push_back(RawValue::I(ty)); - } - Ok(( - TypeTable { - table, - types, - current_type: VecDeque::new(), - }, - bytes.0, - )) - } - fn pop_current_type(&mut self) -> Result { - self.current_type - .pop_front() - .ok_or_else(|| Error::msg("empty current_type")) - } - fn peek_current_type(&self) -> Result<&RawValue> { - self.current_type - .front() - .ok_or_else(|| Error::msg("empty current_type")) - } - fn rawvalue_to_opcode(&self, v: &RawValue) -> Result { - let mut op = v.get_i64()?; - if op >= 0 && op < self.table.len() as i64 { - op = self.table[op as usize][0].get_i64()?; - } - Opcode::try_from(op).map_err(|_| Error::msg(format!("Unknown opcode {}", op))) - } - // Pop type opcode from the front of current_type. - // If the opcode is an index (>= 0), we push the corresponding entry from table, - // to current_type queue, and pop the opcode from the front. - fn parse_type(&mut self) -> Result { - let mut op = self.pop_current_type()?.get_i64()?; - if op >= 0 && op < self.table.len() as i64 { - let ty = &self.table[op as usize]; - for x in ty.iter().rev() { - self.current_type.push_front(x.clone()); - } - op = self.pop_current_type()?.get_i64()?; - } - let r = Opcode::try_from(op).map_err(|_| Error::msg(format!("Unknown opcode {}", op)))?; - Ok(r) - } - // Same logic as parse_type, but not poping the current_type queue. - fn peek_type(&self) -> Result { - let op = self.peek_current_type()?; - self.rawvalue_to_opcode(op) - } - // Check if current_type matches the provided type - fn check_type(&mut self, expected: Opcode) -> Result<()> { - let wire_type = self.parse_type()?; - if wire_type != expected { +macro_rules! assert { + ( false ) => {{ + return Err(Error::msg(format!( + "Internal error at {}:{}. Please file a bug.", + file!(), + line!() + ))); + }}; + ( $pred:expr ) => {{ + if !$pred { return Err(Error::msg(format!( - "Type mismatch. Type on the wire: {:?}; Expected type: {:?}", - wire_type, expected + "Internal error at {}:{}. Please file a bug.", + file!(), + line!() ))); } - Ok(()) - } -} -fn is_primitive_type(ty: i64) -> bool { - ty < 0 && (ty >= -17 || ty == -24) -} -fn validate_type_range(ty: i64, len: usize) -> Result<()> { - if ty >= 0 && (ty as usize) < len || is_primitive_type(ty) { - Ok(()) - } else { - Err(Error::msg(format!("unknown type {}", ty))) - } -} -#[derive(Debug)] -enum FieldLabel { - Named(&'static str), - Id(u32), - Variant(String), - Skip, + }}; } struct Deserializer<'de> { - input: Bytes<'de>, - table: TypeTable, + input: Cursor<&'de [u8]>, + table: TypeEnv, + types: VecDeque<(usize, Type)>, + wire_type: Type, + expect_type: Type, + // Memo table for subtyping relation + gamma: Gamma, // field_name tells deserialize_identifier which field name to process. // This field should always be set by set_field_name function. - field_name: Option, - // The record nesting depth should be bounded by the length of table to avoid infinite loop. - record_nesting_depth: usize, + field_name: Option