From b68d5266778a08f84611b2ebf11258b7f424b32a Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Tue, 23 Mar 2021 13:59:16 -0700 Subject: [PATCH 01/51] refactor deserialization --- Cargo.lock | 22 ++++++++++ rust/candid/Cargo.toml | 1 + rust/candid/src/binary_parser.rs | 65 +++++++++++++++++++++++++++++ rust/candid/src/error.rs | 70 +++++++++++++++++++++++++++++++- rust/candid/src/lib.rs | 1 + 5 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 rust/candid/src/binary_parser.rs diff --git a/Cargo.lock b/Cargo.lock index 3a2c28ca..d48d30e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,6 +130,27 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6736e2428df2ca2848d846c43e88745121a6654696e349ce0054a420815a7409" +[[package]] +name = "binread" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f9def047620f016f4d5eb5a4c9b925f03175870772d354c696ec9062e942886" +dependencies = [ + "binread_derive", +] + +[[package]] +name = "binread_derive" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c575d9a28eb4c2d61747b23d50271c6699b941448c35a738af35267f90b3d4d2" +dependencies = [ + "either", + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.62", +] + [[package]] name = "bit-set" version = "0.5.2" @@ -233,6 +254,7 @@ name = "candid" version = "0.6.19" dependencies = [ "arbitrary", + "binread", "byteorder", "candid_derive", "codespan-reporting", diff --git a/rust/candid/Cargo.toml b/rust/candid/Cargo.toml index e907ce4e..cdd77474 100644 --- a/rust/candid/Cargo.toml +++ b/rust/candid/Cargo.toml @@ -35,6 +35,7 @@ pretty = "0.10.0" serde = { version = "1.0.118", features = ["derive"] } serde_bytes = "0.11" thiserror = "1.0.20" +binread = "2.0" arbitrary = { version = "0.4.7", optional = true } serde_dhall = { version = "0.10.0", optional = true } diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs new file mode 100644 index 00000000..1ee75487 --- /dev/null +++ b/rust/candid/src/binary_parser.rs @@ -0,0 +1,65 @@ +use binread::io::{Read, Seek, SeekFrom}; +use binread::{BinRead, BinResult, Error, ReadOptions}; + +fn read_leb(reader: &mut R, _ro: &ReadOptions, _: ()) -> BinResult { + let pos = reader.seek(SeekFrom::Current(0))?; + leb128::read::unsigned(reader).map_err(|_| Error::Custom { + pos, + err: Box::new("Invalid leb128"), + }) +} +fn read_sleb(reader: &mut R, _ro: &ReadOptions, _: ()) -> BinResult { + let pos = reader.seek(SeekFrom::Current(0))?; + leb128::read::signed(reader).map_err(|_| Error::Custom { + pos, + err: Box::new("Invalid sleb128"), + }) +} + +#[derive(BinRead, Debug)] +#[br(magic = b"DIDL")] +struct Table { + #[br(parse_with = read_leb)] + len: u64, + #[br(count = len)] + table: Vec, +} +#[derive(BinRead, Debug)] +enum ConsType { + #[br(magic = 0x6eu8)] + Opt(Box), + #[br(magic = 0x6du8)] + Vec(Box), + #[br(magic = 0x6cu8)] + Record(Fields), +} +#[derive(BinRead, Debug)] +struct IndexType { + #[br(parse_with = read_sleb)] + index: i64, +} +#[derive(BinRead, Debug)] +struct Fields { + #[br(parse_with = read_leb)] + len: u64, + #[br(count = len)] + inner: Vec, +} +#[derive(BinRead, Debug)] +struct FieldType { + #[br(parse_with = read_leb)] + id: u64, + index: IndexType, +} + +#[test] +fn parse() -> crate::Result<()> { + use binread::io::Cursor; + let mut reader = Cursor::new(b"DIDL\x03\x6e\x00\x6d\x7f\x6c\x01\x00\x7e".as_ref()); + //let table = Table::read(&mut reader)?; + let table = crate::error::pretty_read::(&mut reader)?; + println!("{:?}", table); + let rest = reader.position() as usize; + println!("remaining {:02x?}", &reader.into_inner()[rest..]); + Ok(()) +} diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index b1c868a0..63f1e1ad 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -11,11 +11,14 @@ use thiserror::Error; pub type Result = std::result::Result; -#[derive(Debug, Error, Eq, PartialEq)] +#[derive(Debug, Error)] pub enum Error { - #[error("Candid parser error: {0:}")] + #[error("Candid parser error: {0}")] Parse(#[from] token::ParserError), + #[error("Binary parser error: {0}")] + Binread(#[from] binread::Error), + #[error("Deserialize error: {0}")] Deserialize(String, String), @@ -56,12 +59,60 @@ impl Error { }; diag.with_labels(vec![label]) } + Error::Binread(e) => { + let diag = Diagnostic::error().with_message("decoding error"); + let labels = get_binread_labels(e); + diag.with_labels(labels) + } Error::Deserialize(e, _) => Diagnostic::error().with_message(e), Error::Custom(e) => Diagnostic::error().with_message(e), } } } +fn get_binread_labels(e: &binread::Error) -> Vec> { + use binread::Error::*; + match e { + BadMagic { pos, .. } => { + let pos = (pos * 2) as usize; + vec![Label::primary((), pos..pos + 2).with_message("Unexpected bytes")] + } + Custom { pos, err } => { + let pos = (pos * 2) as usize; + let err = err.downcast_ref::<&str>().unwrap(); + vec![Label::primary((), pos..pos + 2).with_message(err.to_string())] + } + EnumErrors { + pos, + variant_errors, + } => { + let pos = (pos * 2) as usize; + let variant = variant_errors + .iter() + .find(|(_, e)| !matches!(e, BadMagic { .. })); + // Should be only one non-magic error + match variant { + None => vec![Label::primary((), pos..pos + 2).with_message("Unknown opcode")], + Some((id, e)) => { + let mut labels = get_binread_labels(e); + labels.push(Label::secondary((), pos..pos + 2).with_message(id.to_string())); + labels + } + } + } + NoVariantMatch { pos } => { + let pos = (pos * 2) as usize; + vec![Label::primary((), pos..pos + 2).with_message("No variant match")] + } + AssertFail { pos, message } => { + let pos = (pos * 2) as usize; + vec![Label::primary((), pos..pos + 2).with_message(message)] + } + Io(e) => vec![Label::primary((), 0..0).with_message(e.to_string())], + _ => Vec::new(), + } +} + fn report_expected(expected: &[String]) -> Vec { if expected.is_empty() { return Vec::new(); @@ -124,3 +175,18 @@ where Err(e) }) } + +pub fn pretty_read(reader: &mut std::io::Cursor<&[u8]>) -> Result +where + T: binread::BinRead, +{ + T::read(reader).or_else(|e| { + let e = Error::Binread(e); + let writer = StandardStream::stderr(term::termcolor::ColorChoice::Auto); + let config = term::Config::default(); + let str = hex::encode(&reader.get_ref()); + let file = SimpleFile::new("binary", &str); + term::emit(&mut writer.lock(), &config, &file, &e.report())?; + Err(e) + }) +} diff --git a/rust/candid/src/lib.rs b/rust/candid/src/lib.rs index 54bbcccd..ea449441 100644 --- a/rust/candid/src/lib.rs +++ b/rust/candid/src/lib.rs @@ -270,6 +270,7 @@ pub use parser::types::IDLProg; pub use parser::typing::{check_prog, TypeEnv}; pub use parser::value::IDLArgs; +pub mod binary_parser; pub mod de; pub use de::{decode_args, decode_one}; pub mod ser; From 6f80299704ff41e999e0ee6e2c090c1f6cda1d67 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Sat, 27 Mar 2021 14:44:07 -0700 Subject: [PATCH 02/51] checkpoint --- Cargo.lock | 1 + rust/candid/Cargo.toml | 2 +- rust/candid/src/binary_parser.rs | 98 +++++++++++++++++++++++++++----- rust/candid/src/parser/typing.rs | 9 ++- 4 files changed, 94 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d48d30e5..28d34fa9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -137,6 +137,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f9def047620f016f4d5eb5a4c9b925f03175870772d354c696ec9062e942886" dependencies = [ "binread_derive", + "lazy_static", ] [[package]] diff --git a/rust/candid/Cargo.toml b/rust/candid/Cargo.toml index cdd77474..c09ad71b 100644 --- a/rust/candid/Cargo.toml +++ b/rust/candid/Cargo.toml @@ -35,7 +35,7 @@ pretty = "0.10.0" serde = { version = "1.0.118", features = ["derive"] } serde_bytes = "0.11" thiserror = "1.0.20" -binread = "2.0" +binread = { version = "2.0", features = ["debug_template"] } arbitrary = { version = "0.4.7", optional = true } serde_dhall = { version = "0.10.0", optional = true } diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index 1ee75487..079c0b4d 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -1,24 +1,36 @@ +use crate::parser::typing::TypeEnv; +use crate::types::internal::{Field, Label, Type}; use binread::io::{Read, Seek, SeekFrom}; use binread::{BinRead, BinResult, Error, ReadOptions}; +use std::convert::TryInto; -fn read_leb(reader: &mut R, _ro: &ReadOptions, _: ()) -> BinResult { +fn read_leb(reader: &mut R, ro: &ReadOptions, _: ()) -> BinResult { let pos = reader.seek(SeekFrom::Current(0))?; leb128::read::unsigned(reader).map_err(|_| Error::Custom { pos, - err: Box::new("Invalid leb128"), + err: Box::new(ro.variable_name.unwrap_or("Invalid leb128")), }) } -fn read_sleb(reader: &mut R, _ro: &ReadOptions, _: ()) -> BinResult { +fn read_sleb(reader: &mut R, ro: &ReadOptions, _: ()) -> BinResult { let pos = reader.seek(SeekFrom::Current(0))?; leb128::read::signed(reader).map_err(|_| Error::Custom { pos, - err: Box::new("Invalid sleb128"), + err: Box::new(ro.variable_name.unwrap_or("Invalid sleb128")), }) } #[derive(BinRead, Debug)] #[br(magic = b"DIDL")] -struct Table { +pub struct Header { + table: Table, + #[br(parse_with = read_leb)] + len: u64, + #[br(count = len)] + args: Vec, +} + +#[derive(BinRead, Debug)] +pub struct Table { #[br(parse_with = read_leb)] len: u64, #[br(count = len)] @@ -32,33 +44,91 @@ enum ConsType { Vec(Box), #[br(magic = 0x6cu8)] Record(Fields), + #[br(magic = 0x6bu8)] + Variant(Fields), } #[derive(BinRead, Debug)] struct IndexType { - #[br(parse_with = read_sleb)] + #[br(parse_with = read_sleb, assert(index >= -17 || index == -24, "unknown opcode {}", index))] index: i64, } #[derive(BinRead, Debug)] struct Fields { - #[br(parse_with = read_leb)] - len: u64, + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into())] + len: u32, #[br(count = len)] inner: Vec, } #[derive(BinRead, Debug)] struct FieldType { - #[br(parse_with = read_leb)] - id: u64, + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into())] + id: u32, index: IndexType, } +impl IndexType { + pub fn to_type(&self) -> Type { + match self.index { + v if v >= 0 => Type::Var(v.to_string()), + -1 => Type::Null, + -2 => Type::Bool, + -3 => Type::Nat, + -4 => Type::Int, + -5 => Type::Nat8, + -6 => Type::Nat16, + -7 => Type::Nat32, + -8 => Type::Nat64, + -9 => Type::Int8, + -10 => Type::Int16, + -11 => Type::Int32, + -12 => Type::Int64, + -13 => Type::Float32, + -14 => Type::Float64, + -15 => Type::Text, + -16 => Type::Reserved, + -17 => Type::Empty, + -24 => Type::Principal, + _ => unreachable!(), + } + } +} +impl FieldType { + pub fn to_field(&self) -> Field { + Field { + id: Label::Id(self.id), + ty: self.index.to_type(), + } + } +} +impl ConsType { + pub fn to_type(&self) -> Type { + match &self { + ConsType::Opt(ref ind) => Type::Opt(Box::new(ind.to_type())), + ConsType::Vec(ref ind) => Type::Vec(Box::new(ind.to_type())), + ConsType::Record(fs) => Type::Record(fs.inner.iter().map(|f| f.to_field()).collect()), + ConsType::Variant(fs) => Type::Variant(fs.inner.iter().map(|f| f.to_field()).collect()), + } + } +} +impl Table { + pub fn to_env(&self) -> TypeEnv { + TypeEnv( + self.table + .iter() + .enumerate() + .map(|(i, t)| (i.to_string(), t.to_type())) + .collect(), + ) + } +} + #[test] fn parse() -> crate::Result<()> { use binread::io::Cursor; - let mut reader = Cursor::new(b"DIDL\x03\x6e\x00\x6d\x7f\x6c\x01\x00\x7e".as_ref()); - //let table = Table::read(&mut reader)?; - let table = crate::error::pretty_read::
(&mut reader)?; - println!("{:?}", table); + let mut reader = + Cursor::new(b"DIDL\x03\x6e\x00\x6d\x6f\x6c\x02\x00\x7e\x01\x7a\x02\x02\x03".as_ref()); + let header = crate::error::pretty_read::
(&mut reader)?; + println!("{}", header.table.to_env()); let rest = reader.position() as usize; println!("remaining {:02x?}", &reader.into_inner()[rest..]); Ok(()) diff --git a/rust/candid/src/parser/typing.rs b/rust/candid/src/parser/typing.rs index 82165ba6..7aa2a537 100644 --- a/rust/candid/src/parser/typing.rs +++ b/rust/candid/src/parser/typing.rs @@ -75,7 +75,14 @@ impl TypeEnv { Err(Error::msg(format!("cannot find method {}", id))) } } - +impl std::fmt::Display for TypeEnv { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (k, v) in self.0.iter() { + writeln!(f, "type {} = {}", k, v)?; + } + Ok(()) + } +} fn check_prim(prim: &PrimType) -> Type { match prim { PrimType::Nat => Type::Nat, From 1a133c4f739f8947c9d27c10603d09b749254cd0 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Sat, 27 Mar 2021 15:14:18 -0700 Subject: [PATCH 03/51] checkpoint --- rust/candid/src/binary_parser.rs | 94 +++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index 079c0b4d..0ba1cdf6 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -1,19 +1,20 @@ use crate::parser::typing::TypeEnv; use crate::types::internal::{Field, Label, Type}; +use crate::{Error, Result}; use binread::io::{Read, Seek, SeekFrom}; -use binread::{BinRead, BinResult, Error, ReadOptions}; +use binread::{BinRead, BinResult, Error as BError, ReadOptions}; use std::convert::TryInto; fn read_leb(reader: &mut R, ro: &ReadOptions, _: ()) -> BinResult { let pos = reader.seek(SeekFrom::Current(0))?; - leb128::read::unsigned(reader).map_err(|_| Error::Custom { + leb128::read::unsigned(reader).map_err(|_| BError::Custom { pos, err: Box::new(ro.variable_name.unwrap_or("Invalid leb128")), }) } fn read_sleb(reader: &mut R, ro: &ReadOptions, _: ()) -> BinResult { let pos = reader.seek(SeekFrom::Current(0))?; - leb128::read::signed(reader).map_err(|_| Error::Custom { + leb128::read::signed(reader).map_err(|_| BError::Custom { pos, err: Box::new(ro.variable_name.unwrap_or("Invalid sleb128")), }) @@ -67,9 +68,14 @@ struct FieldType { } impl IndexType { - pub fn to_type(&self) -> Type { - match self.index { - v if v >= 0 => Type::Var(v.to_string()), + fn to_type(&self, len: u64) -> Result { + Ok(match self.index { + v if v >= 0 => { + if v >= len as i64 { + return Err(Error::msg("type index out of range")); + } + Type::Var(v.to_string()) + } -1 => Type::Null, -2 => Type::Bool, -3 => Type::Nat, @@ -89,46 +95,70 @@ impl IndexType { -17 => Type::Empty, -24 => Type::Principal, _ => unreachable!(), - } + }) } } -impl FieldType { - pub fn to_field(&self) -> Field { - Field { - id: Label::Id(self.id), - ty: self.index.to_type(), - } +impl ConsType { + fn to_type(&self, len: u64) -> Result { + Ok(match &self { + ConsType::Opt(ref ind) => Type::Opt(Box::new(ind.to_type(len)?)), + ConsType::Vec(ref ind) => Type::Vec(Box::new(ind.to_type(len)?)), + ConsType::Record(fs) | ConsType::Variant(fs) => { + let mut res = Vec::new(); + let mut prev = None; + for f in fs.inner.iter() { + if let Some(prev) = prev { + if prev >= f.id { + return Err(Error::msg("unsorted or duplicate fields")); + } + } + prev = Some(f.id); + let field = Field { + id: Label::Id(f.id), + ty: f.index.to_type(len)?, + }; + res.push(field); + } + if matches!(&self, ConsType::Record(_)) { + Type::Record(res) + } else { + Type::Variant(res) + } + } + }) } } -impl ConsType { - pub fn to_type(&self) -> Type { - match &self { - ConsType::Opt(ref ind) => Type::Opt(Box::new(ind.to_type())), - ConsType::Vec(ref ind) => Type::Vec(Box::new(ind.to_type())), - ConsType::Record(fs) => Type::Record(fs.inner.iter().map(|f| f.to_field()).collect()), - ConsType::Variant(fs) => Type::Variant(fs.inner.iter().map(|f| f.to_field()).collect()), +impl Table { + fn to_env(&self, len: u64) -> Result { + use std::collections::BTreeMap; + let mut env = BTreeMap::new(); + for (i, t) in self.table.iter().enumerate() { + env.insert(i.to_string(), t.to_type(len)?); } + Ok(TypeEnv(env)) } } -impl Table { - pub fn to_env(&self) -> TypeEnv { - TypeEnv( - self.table - .iter() - .enumerate() - .map(|(i, t)| (i.to_string(), t.to_type())) - .collect(), - ) +impl Header { + pub fn to_types(&self) -> Result<(TypeEnv, Vec)> { + let len = self.table.len; + let env = self.table.to_env(len)?; + let mut args = Vec::new(); + for t in self.args.iter() { + args.push(t.to_type(len)?); + } + Ok((env, args)) } } #[test] -fn parse() -> crate::Result<()> { +fn parse() -> Result<()> { use binread::io::Cursor; let mut reader = - Cursor::new(b"DIDL\x03\x6e\x00\x6d\x6f\x6c\x02\x00\x7e\x01\x7a\x02\x02\x03".as_ref()); + Cursor::new(b"DIDL\x03\x6e\x00\x6d\x6f\x6c\x02\x00\x7e\x01\x7a\x02\x02\x7a".as_ref()); let header = crate::error::pretty_read::
(&mut reader)?; - println!("{}", header.table.to_env()); + let (env, types) = header.to_types()?; + println!("env {}", env); + println!("types {:?}", types); let rest = reader.position() as usize; println!("remaining {:02x?}", &reader.into_inner()[rest..]); Ok(()) From 7628accc0fc6614dd39adaed025e3062efd2b7c4 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Sat, 27 Mar 2021 15:30:50 -0700 Subject: [PATCH 04/51] fix --- rust/candid/src/binary_parser.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index 0ba1cdf6..585c2169 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -32,7 +32,7 @@ pub struct Header { #[derive(BinRead, Debug)] pub struct Table { - #[br(parse_with = read_leb)] + #[br(parse_with = read_leb, assert(len <= i64::MAX as u64, "type table size out of range"))] len: u64, #[br(count = len)] table: Vec, From 8e8a81fce05b9b414257ddd59ba91e2cd41053e6 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Sat, 27 Mar 2021 22:00:19 -0700 Subject: [PATCH 05/51] fix --- rust/candid/src/binary_parser.rs | 14 +++++--------- rust/candid/src/error.rs | 12 ++++++++---- rust/candid/src/lib.rs | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index 585c2169..85054acc 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -29,9 +29,8 @@ pub struct Header { #[br(count = len)] args: Vec, } - #[derive(BinRead, Debug)] -pub struct Table { +struct Table { #[br(parse_with = read_leb, assert(len <= i64::MAX as u64, "type table size out of range"))] len: u64, #[br(count = len)] @@ -109,7 +108,7 @@ impl ConsType { for f in fs.inner.iter() { if let Some(prev) = prev { if prev >= f.id { - return Err(Error::msg("unsorted or duplicate fields")); + return Err(Error::msg("field id collision or not sorted")); } } prev = Some(f.id); @@ -152,14 +151,11 @@ impl Header { #[test] fn parse() -> Result<()> { - use binread::io::Cursor; - let mut reader = - Cursor::new(b"DIDL\x03\x6e\x00\x6d\x6f\x6c\x02\x00\x7e\x01\x7a\x02\x02\x7a".as_ref()); - let header = crate::error::pretty_read::
(&mut reader)?; + let bytes = b"DIDL\x03\x6e\x00\x6d\x6f\x6c\x02\x00\x7e\x01\x7a\x02\x02\x7a\x01"; + let (header, rest) = crate::pretty_read::
(bytes.as_ref())?; let (env, types) = header.to_types()?; println!("env {}", env); println!("types {:?}", types); - let rest = reader.position() as usize; - println!("remaining {:02x?}", &reader.into_inner()[rest..]); + println!("rest {:02x?}", rest); Ok(()) } diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index 63f1e1ad..c34b391c 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -90,7 +90,7 @@ fn get_binread_labels(e: &binread::Error) -> Vec> { let variant = variant_errors .iter() .find(|(_, e)| !matches!(e, BadMagic { .. })); - // Should be only one non-magic error + // Should have at most one non-magic error match variant { None => vec![Label::primary((), pos..pos + 2).with_message("Unknown opcode")], Some((id, e)) => { @@ -176,11 +176,12 @@ where }) } -pub fn pretty_read(reader: &mut std::io::Cursor<&[u8]>) -> Result +pub fn pretty_read(bytes: &[u8]) -> Result<(T, &[u8])> where T: binread::BinRead, { - T::read(reader).or_else(|e| { + let mut reader = std::io::Cursor::new(bytes); + let res = T::read(&mut reader).or_else(|e| { let e = Error::Binread(e); let writer = StandardStream::stderr(term::termcolor::ColorChoice::Auto); let config = term::Config::default(); @@ -188,5 +189,8 @@ where let file = SimpleFile::new("binary", &str); term::emit(&mut writer.lock(), &config, &file, &e.report())?; Err(e) - }) + })?; + let ind = reader.position() as usize; + let rest = &reader.into_inner()[ind..]; + Ok((res, rest)) } diff --git a/rust/candid/src/lib.rs b/rust/candid/src/lib.rs index ea449441..4dc49140 100644 --- a/rust/candid/src/lib.rs +++ b/rust/candid/src/lib.rs @@ -255,7 +255,7 @@ pub use codegen::generate_code; pub mod bindings; pub mod error; -pub use error::{pretty_parse, Error, Result}; +pub use error::{pretty_parse, pretty_read, Error, Result}; pub mod types; pub use types::CandidType; From 1ce2b1cc8e01a159a408c00aa88d8fcd42767033 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Sun, 28 Mar 2021 11:46:29 -0700 Subject: [PATCH 06/51] fix --- rust/candid/src/binary_parser.rs | 4 ++-- rust/candid/src/error.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index 85054acc..07ee5b11 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -152,10 +152,10 @@ impl Header { #[test] fn parse() -> Result<()> { let bytes = b"DIDL\x03\x6e\x00\x6d\x6f\x6c\x02\x00\x7e\x01\x7a\x02\x02\x7a\x01"; - let (header, rest) = crate::pretty_read::
(bytes.as_ref())?; + let (header, rest): (Header, usize) = crate::pretty_read(bytes)?; let (env, types) = header.to_types()?; println!("env {}", env); println!("types {:?}", types); - println!("rest {:02x?}", rest); + println!("rest {:02x?}", &bytes[rest..]); Ok(()) } diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index c34b391c..7d58aeb0 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -176,9 +176,10 @@ where }) } -pub fn pretty_read(bytes: &[u8]) -> Result<(T, &[u8])> +pub fn pretty_read(bytes: R) -> Result<(T, usize)> where T: binread::BinRead, + R: AsRef<[u8]>, { let mut reader = std::io::Cursor::new(bytes); let res = T::read(&mut reader).or_else(|e| { @@ -190,7 +191,6 @@ where term::emit(&mut writer.lock(), &config, &file, &e.report())?; Err(e) })?; - let ind = reader.position() as usize; - let rest = &reader.into_inner()[ind..]; + let rest = reader.position() as usize; Ok((res, rest)) } From 74f67291f42e22e089a5b9653f6a9fd7a732311b Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Sun, 28 Mar 2021 22:54:35 -0700 Subject: [PATCH 07/51] first stab at deserialization --- rust/candid/src/de.rs | 1276 ++----------------------------- rust/candid/src/error.rs | 10 +- rust/candid/src/lib.rs | 2 +- rust/candid/src/parser/value.rs | 4 +- 4 files changed, 91 insertions(+), 1201 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 8d43283d..550e22d9 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -1,16 +1,14 @@ //! Deserialize Candid binary format to Rust data structures -use super::error::{Error, Result}; -use super::types::internal::Opcode; -use super::{idl_hash, Int, Nat}; +use super::error::{pretty_read, Error, Result}; +use super::{idl_hash, parser::typing::TypeEnv, types::Type, CandidType, Int, Nat}; +use crate::binary_parser::Header; +use binread::BinRead; use byteorder::{LittleEndian, ReadBytesExt}; -use leb128::read::{signed as sleb128_decode, unsigned as leb128_decode}; use serde::de::{self, Deserialize, Visitor}; use std::collections::{BTreeMap, VecDeque}; use std::convert::TryFrom; -use std::io::Read; - -const MAGIC_NUMBER: &[u8; 4] = b"DIDL"; +use std::io::Cursor; /// Use this struct to deserialize a sequence of Rust values (heterogeneous) from IDL binary message. pub struct IDLDeserialize<'de> { @@ -25,1245 +23,141 @@ impl<'de> IDLDeserialize<'de> { /// Deserialize one value from deserializer. pub fn get_value(&mut self) -> Result where - T: de::Deserialize<'de>, + T: de::Deserialize<'de> + CandidType, { let ty = self .de - .table .types .pop_front() .ok_or_else(|| Error::msg("No more values to deserialize"))?; - self.de.table.current_type.push_back(ty); + let expected_type = T::ty(); + self.de.expect_type = if matches!(expected_type, Type::Unknown) { + ty.clone() + } else { + expected_type + }; + self.de.wire_type = ty; - let v = T::deserialize(&mut self.de).map_err(|e| self.de.dump_error_state(e))?; - if self.de.table.current_type.is_empty() && self.de.field_name.is_none() { + let v = T::deserialize(&mut self.de)?; //.map_err(|e| self.de.dump_error_state(e))?; + Ok(v) + /*if self.de.table.current_type.is_empty() && self.de.field_name.is_none() { Ok(v) } else { Err(Error::msg("Trailing type after deserializing a value")) .map_err(|e| self.de.dump_error_state(e)) - } + }*/ } /// Check if we finish deserializing all values. pub fn is_done(&self) -> bool { - self.de.table.types.is_empty() + self.de.types.is_empty() } /// Return error if there are unprocessed bytes in the input. pub fn done(mut self) -> Result<()> { while !self.is_done() { self.get_value::()?; } - if !self.de.input.0.is_empty() { - return Err(Error::msg("Trailing value after finishing deserialization")) - .map_err(|e| self.de.dump_error_state(e)); - } - Ok(()) - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -enum RawValue { - I(i64), - U(u32), -} -impl RawValue { - fn get_i64(&self) -> Result { - match *self { - RawValue::I(i) => Ok(i), - _ => Err(Error::msg(format!("get_i64 fail: {:?}", *self))), - } - } - fn get_u32(&self) -> Result { - match *self { - RawValue::U(u) => Ok(u), - _ => Err(Error::msg(format!("get_u32 fail: {:?}", *self))), - } - } -} - -struct Bytes<'a>(&'a [u8]); -impl<'a> Bytes<'a> { - fn from(input: &'a [u8]) -> Self { - Bytes(input) - } - fn leb128_read(&mut self) -> Result { - leb128_decode(&mut self.0).map_err(Error::msg) - } - fn sleb128_read(&mut self) -> Result { - sleb128_decode(&mut self.0).map_err(Error::msg) - } - fn parse_byte(&mut self) -> Result { - let mut buf = [0u8; 1]; - self.0.read_exact(&mut buf)?; - Ok(buf[0]) - } - fn parse_bytes(&mut self, len: usize) -> Result> { - if self.0.len() < len { - return Err(Error::msg("unexpected end of message")); - } - let mut buf = vec![0; len]; - self.0.read_exact(&mut buf)?; - Ok(buf) - } - fn parse_string(&mut self, len: usize) -> Result { - let buf = self.parse_bytes(len)?; - String::from_utf8(buf).map_err(Error::msg) - } - fn parse_magic(&mut self) -> Result<()> { - let mut buf = [0u8; 4]; - match self.0.read(&mut buf) { - Ok(4) if buf == *MAGIC_NUMBER => Ok(()), - _ => Err(Error::msg(format!("wrong magic number {:?}", buf))), - } - } -} - -struct TypeTable { - // Raw value of the type description table - table: Vec>, - // Value types for deserialization - types: VecDeque, - // The front of current_type queue always points to the type of the value we are deserailizing. - // The type info is cloned from table. Someone more familiar with Rust should see if we can - // rewrite this to avoid copying. - current_type: VecDeque, -} -impl TypeTable { - // Parse the type table and return the remaining bytes - fn from_bytes(input: &[u8]) -> Result<(Self, &[u8])> { - let mut bytes = Bytes::from(input); - let mut table: Vec> = Vec::new(); - let mut types = VecDeque::new(); - - bytes.parse_magic()?; - let len = bytes.leb128_read()? as usize; - let mut expect_func = std::collections::HashSet::new(); - for i in 0..len { - let mut buf = Vec::new(); - let ty = bytes.sleb128_read()?; - buf.push(RawValue::I(ty)); - if expect_func.contains(&i) && ty != -22 { - return Err(Error::msg(format!( - "Expect function opcode, but got {}", - ty - ))); - } - match Opcode::try_from(ty) { - Ok(Opcode::Opt) | Ok(Opcode::Vec) => { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - buf.push(RawValue::I(ty)); - } - Ok(Opcode::Record) | Ok(Opcode::Variant) => { - let obj_len = u32::try_from(bytes.leb128_read()?) - .map_err(|_| Error::msg(Error::msg("length out of u32")))?; - buf.push(RawValue::U(obj_len)); - let mut prev_hash = None; - for _ in 0..obj_len { - let hash = u32::try_from(bytes.leb128_read()?) - .map_err(|_| Error::msg(Error::msg("field hash out of u32")))?; - if let Some(prev_hash) = prev_hash { - if prev_hash >= hash { - return Err(Error::msg("field id collision or not sorted")); - } - } - prev_hash = Some(hash); - buf.push(RawValue::U(hash)); - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - buf.push(RawValue::I(ty)); - } - } - Ok(Opcode::Service) => { - let obj_len = u32::try_from(bytes.leb128_read()?) - .map_err(|_| Error::msg(Error::msg("length out of u32")))?; - // Push one element to the table to ensure it's a non-primitive type - buf.push(RawValue::U(obj_len)); - let mut prev = None; - for _ in 0..obj_len { - let mlen = bytes.leb128_read()? as usize; - let meth = bytes.parse_string(mlen)?; - if let Some(prev) = prev { - if prev >= meth { - return Err(Error::msg("method name collision or not sorted")); - } - } - prev = Some(meth); - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - // Check for method type - if ty >= 0 { - let idx = ty as usize; - if idx < table.len() && table[idx][0] != RawValue::I(-22) { - return Err(Error::msg("not a function type")); - } else { - expect_func.insert(idx); - } - } else { - return Err(Error::msg("not a function type")); - } - } - } - Ok(Opcode::Func) => { - let arg_len = bytes.leb128_read()?; - // Push one element to the table to ensure it's a non-primitive type - buf.push(RawValue::U(arg_len as u32)); - for _ in 0..arg_len { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - } - let ret_len = bytes.leb128_read()?; - for _ in 0..ret_len { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, len)?; - } - let ann_len = bytes.leb128_read()?; - for _ in 0..ann_len { - let ann = bytes.parse_byte()?; - if ann > 2u8 { - return Err(Error::msg("Unknown function annotation")); - } - } - } - _ => { - return Err(Error::msg(format!( - "Unsupported op_code {} in type table", - ty - ))) - } - }; - table.push(buf); - } - let len = bytes.leb128_read()?; - for _i in 0..len { - let ty = bytes.sleb128_read()?; - validate_type_range(ty, table.len())?; - types.push_back(RawValue::I(ty)); - } - Ok(( - TypeTable { - table, - types, - current_type: VecDeque::new(), - }, - bytes.0, - )) - } - fn pop_current_type(&mut self) -> Result { - self.current_type - .pop_front() - .ok_or_else(|| Error::msg("empty current_type")) - } - fn peek_current_type(&self) -> Result<&RawValue> { - self.current_type - .front() - .ok_or_else(|| Error::msg("empty current_type")) - } - fn rawvalue_to_opcode(&self, v: &RawValue) -> Result { - let mut op = v.get_i64()?; - if op >= 0 && op < self.table.len() as i64 { - op = self.table[op as usize][0].get_i64()?; - } - Opcode::try_from(op).map_err(|_| Error::msg(format!("Unknown opcode {}", op))) - } - // Pop type opcode from the front of current_type. - // If the opcode is an index (>= 0), we push the corresponding entry from table, - // to current_type queue, and pop the opcode from the front. - fn parse_type(&mut self) -> Result { - let mut op = self.pop_current_type()?.get_i64()?; - if op >= 0 && op < self.table.len() as i64 { - let ty = &self.table[op as usize]; - for x in ty.iter().rev() { - self.current_type.push_front(x.clone()); - } - op = self.pop_current_type()?.get_i64()?; - } - let r = Opcode::try_from(op).map_err(|_| Error::msg(format!("Unknown opcode {}", op)))?; - Ok(r) - } - // Same logic as parse_type, but not poping the current_type queue. - fn peek_type(&self) -> Result { - let op = self.peek_current_type()?; - self.rawvalue_to_opcode(op) - } - // Check if current_type matches the provided type - fn check_type(&mut self, expected: Opcode) -> Result<()> { - let wire_type = self.parse_type()?; - if wire_type != expected { - return Err(Error::msg(format!( - "Type mismatch. Type on the wire: {:?}; Expected type: {:?}", - wire_type, expected - ))); + let ind = self.de.input.position() as usize; + let rest = &self.de.input.get_ref()[ind..]; + if !rest.is_empty() { + return Err(Error::msg("Trailing value after finishing deserialization")); + //.map_err(|e| self.de.dump_error_state(e)); } Ok(()) } } -fn is_primitive_type(ty: i64) -> bool { - ty < 0 && (ty >= -17 || ty == -24) -} -fn validate_type_range(ty: i64, len: usize) -> Result<()> { - if ty >= 0 && (ty as usize) < len || is_primitive_type(ty) { - Ok(()) - } else { - Err(Error::msg(format!("unknown type {}", ty))) - } -} -#[derive(Debug)] -enum FieldLabel { - Named(&'static str), - Id(u32), - Variant(String), - Skip, -} struct Deserializer<'de> { - input: Bytes<'de>, - table: TypeTable, - // field_name tells deserialize_identifier which field name to process. - // This field should always be set by set_field_name function. - field_name: Option, - // The record nesting depth should be bounded by the length of table to avoid infinite loop. + input: Cursor<&'de [u8]>, + table: TypeEnv, + types: VecDeque, + wire_type: Type, + expect_type: Type, record_nesting_depth: usize, } impl<'de> Deserializer<'de> { - fn from_bytes(input: &'de [u8]) -> Result { - let (table, input) = TypeTable::from_bytes(input)?; + fn from_bytes(bytes: &'de [u8]) -> Result { + let mut reader = Cursor::new(bytes); + let header: Header = pretty_read(&mut reader)?; + let (env, types) = header.to_types()?; Ok(Deserializer { - input: Bytes::from(input), - table, - field_name: None, + input: reader, + table: env, + types: types.into(), + wire_type: Type::Unknown, + expect_type: Type::Unknown, record_nesting_depth: 0, }) } - - fn dump_error_state(&self, e: Error) -> Error { - let mut str = format!("Trailing type: {:?}\n", self.table.current_type); - str.push_str(&format!("Trailing value: {:02x?}\n", self.input.0)); - if self.field_name.is_some() { - str.push_str(&format!("Trailing field_name: {:?}\n", self.field_name)); - } - str.push_str(&format!("Type table: {:?}\n", self.table.table)); - str.push_str(&format!("Remaining value types: {:?}", self.table.types)); - e.with_states(str) - } - - // Should always call set_field_name to set the field_name. After deserialize_identifier - // processed the field_name, field_name will be reset to None. - fn set_field_name(&mut self, field: FieldLabel) { - if self.field_name.is_some() { - panic!(format!("field_name already taken {:?}", self.field_name)); - } - self.field_name = Some(field); - } - // Customize deserailization methods - // Several deserialize functions will call visit_byte_buf. - // We reserve the first byte to be a tag to distinguish between different callers: - // int(0), nat(1), principal(2), reserved(3) - // This is necessary for deserializing IDLValue because - // it has only one visitor and we need a way to know who called the visitor. - fn deserialize_int<'a, V>(&'a mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Int)?; - let v = Int::decode(&mut self.input.0).map_err(Error::msg)?; - let bytes = v.0.to_signed_bytes_le(); - let mut tagged = vec![0u8]; - tagged.extend_from_slice(&bytes); - visitor.visit_byte_buf(tagged) - } - fn deserialize_nat<'a, V>(&'a mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Nat)?; - let v = Nat::decode(&mut self.input.0).map_err(Error::msg)?; - let bytes = v.0.to_bytes_le(); - let mut tagged = vec![1u8]; - tagged.extend_from_slice(&bytes); - visitor.visit_byte_buf(tagged) - } - fn decode_principal(&mut self) -> Result> { - let bit = self.input.parse_byte()?; - if bit != 1u8 { - return Err(Error::msg("Opaque reference not supported")); - } - let len = self.input.leb128_read()? as usize; - self.input.parse_bytes(len) - } - fn deserialize_principal<'a, V>(&'a mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Principal)?; - let vec = self.decode_principal()?; - let mut tagged = vec![2u8]; - tagged.extend_from_slice(&vec); - visitor.visit_byte_buf(tagged) - } - fn deserialize_service<'a, V>(&'a mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Service)?; - self.table.pop_current_type()?; - let vec = self.decode_principal()?; - let mut tagged = vec![4u8]; - tagged.extend_from_slice(&vec); - visitor.visit_byte_buf(tagged) - } - fn deserialize_function<'a, V>(&'a mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Func)?; - self.table.pop_current_type()?; - let bit = self.input.parse_byte()?; - if bit != 1u8 { - return Err(Error::msg("Opaque reference not supported")); + fn check_type(&self, expect: &Type) -> Result<()> { + if self.wire_type == self.expect_type && self.wire_type == *expect { + Ok(()) + } else { + Err(Error::msg(format!( + "Type mismatch. Expect {}, but found {}", + expect, self.wire_type + ))) } - let vec = self.decode_principal()?; - let len = self.input.leb128_read()? as usize; - let meth = self.input.parse_bytes(len)?; - let mut tagged = vec![5u8]; - // TODO: find a better way - leb128::write::unsigned(&mut tagged, len as u64)?; - tagged.extend_from_slice(&meth); - tagged.extend_from_slice(&vec); - visitor.visit_byte_buf(tagged) - } - fn deserialize_reserved<'a, V>(&'a mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Reserved)?; - let tagged = vec![3u8]; - visitor.visit_byte_buf(tagged) - } - fn deserialize_empty<'a, V>(&'a mut self, _visitor: V) -> Result - where - V: Visitor<'de>, - { - Err(Error::msg("Cannot decode empty type")) } } -macro_rules! primitive_impl { - ($ty:ident, $opcode:expr, $($value:tt)*) => { - paste::item! { - fn [](self, visitor: V) -> Result - where V: Visitor<'de> { - self.record_nesting_depth = 0; - self.table.check_type($opcode)?; - let value = (self.input.0).$($value)*().map_err(|_| Error::msg(format!("cannot read {} value", stringify!($opcode))))?; - visitor.[](value) - } - } - }; -} - impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; - - // Skipping unused field types - fn deserialize_ignored_any(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - if self.field_name.is_some() { - return self.deserialize_identifier(visitor); - } - let t = self.table.peek_type()?; - if t != Opcode::Record { - self.record_nesting_depth = 0; - } - match t { - Opcode::Int => self.deserialize_int(visitor), - Opcode::Nat => self.deserialize_nat(visitor), - Opcode::Nat8 => self.deserialize_u8(visitor), - Opcode::Nat16 => self.deserialize_u16(visitor), - Opcode::Nat32 => self.deserialize_u32(visitor), - Opcode::Nat64 => self.deserialize_u64(visitor), - Opcode::Int8 => self.deserialize_i8(visitor), - Opcode::Int16 => self.deserialize_i16(visitor), - Opcode::Int32 => self.deserialize_i32(visitor), - Opcode::Int64 => self.deserialize_i64(visitor), - Opcode::Float32 => self.deserialize_f32(visitor), - Opcode::Float64 => self.deserialize_f64(visitor), - Opcode::Bool => self.deserialize_bool(visitor), - Opcode::Text => self.deserialize_string(visitor), - Opcode::Null => self.deserialize_unit(visitor), - Opcode::Reserved => self.deserialize_reserved(visitor), - Opcode::Empty => self.deserialize_empty(visitor), - Opcode::Vec => self.deserialize_seq(visitor), - Opcode::Opt => self.deserialize_option(visitor), - Opcode::Record => self.deserialize_struct("_", &[], visitor), - Opcode::Variant => self.deserialize_enum("_", &[], visitor), - Opcode::Principal => self.deserialize_principal(visitor), - Opcode::Service => self.deserialize_service(visitor), - Opcode::Func => self.deserialize_function(visitor), - } - } - - // Used for deserializing to IDLValue - fn deserialize_any(mut self, visitor: V) -> Result + fn deserialize_any(self, visitor: V) -> Result where V: Visitor<'de>, { - if self.field_name.is_some() { - return self.deserialize_identifier(visitor); - } - let t = self.table.peek_type()?; - if t != Opcode::Record { - self.record_nesting_depth = 0; - } + let t = self.table.trace_type(&self.expect_type)?; match t { - Opcode::Int => self.deserialize_int(visitor), - Opcode::Nat => self.deserialize_nat(visitor), - Opcode::Nat8 => self.deserialize_u8(visitor), - Opcode::Nat16 => self.deserialize_u16(visitor), - Opcode::Nat32 => self.deserialize_u32(visitor), - Opcode::Nat64 => self.deserialize_u64(visitor), - Opcode::Int8 => self.deserialize_i8(visitor), - Opcode::Int16 => self.deserialize_i16(visitor), - Opcode::Int32 => self.deserialize_i32(visitor), - Opcode::Int64 => self.deserialize_i64(visitor), - Opcode::Float32 => self.deserialize_f32(visitor), - Opcode::Float64 => self.deserialize_f64(visitor), - Opcode::Bool => self.deserialize_bool(visitor), - Opcode::Text => self.deserialize_string(visitor), - Opcode::Null => self.deserialize_unit(visitor), - Opcode::Reserved => self.deserialize_reserved(visitor), - Opcode::Empty => self.deserialize_empty(visitor), - Opcode::Vec => self.deserialize_seq(visitor), - Opcode::Opt => self.deserialize_option(visitor), - Opcode::Record => { - let old_nesting = self.record_nesting_depth; - self.record_nesting_depth += 1; - if self.record_nesting_depth > self.table.table.len() { - return Err(Error::msg("There is an infinite loop in the record definition, the type is isomorphic to an empty type")); - } - self.table.check_type(Opcode::Record)?; - let len = self.table.pop_current_type()?.get_u32()?; - let mut fs = BTreeMap::new(); - for i in 0..len { - let hash = self.table.current_type[2 * i as usize].get_u32()?; - if fs.insert(hash, None) != None { - return Err(Error::msg(format!("hash collision {}", hash))); - } - } - let res = visitor.visit_map(Compound::new(&mut self, Style::Struct { len, fs })); - self.record_nesting_depth = old_nesting; - res - } - Opcode::Variant => { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Variant)?; - let len = self.table.pop_current_type()?.get_u32()?; - let mut fs = BTreeMap::new(); - for i in 0..len { - let hash = self.table.current_type[2 * i as usize].get_u32()?; - if fs.insert(hash, None) != None { - return Err(Error::msg(format!("hash collision {}", hash))); - } - } - visitor.visit_enum(Compound::new(&mut self, Style::Enum { len, fs })) - } - Opcode::Principal => self.deserialize_principal(visitor), - Opcode::Service => self.deserialize_service(visitor), - Opcode::Func => self.deserialize_function(visitor), + Type::Bool => self.deserialize_bool(visitor), + Type::Nat8 => self.deserialize_u8(visitor), + Type::Vec(_) => self.deserialize_seq(visitor), + _ => unreachable!(), } } - primitive_impl!(i8, Opcode::Int8, read_i8); - primitive_impl!(i16, Opcode::Int16, read_i16::); - primitive_impl!(i32, Opcode::Int32, read_i32::); - primitive_impl!(i64, Opcode::Int64, read_i64::); - primitive_impl!(u8, Opcode::Nat8, read_u8); - primitive_impl!(u16, Opcode::Nat16, read_u16::); - primitive_impl!(u32, Opcode::Nat32, read_u32::); - primitive_impl!(u64, Opcode::Nat64, read_u64::); - primitive_impl!(f32, Opcode::Float32, read_f32::); - primitive_impl!(f64, Opcode::Float64, read_f64::); - - fn deserialize_i128(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - use std::convert::TryInto; - self.record_nesting_depth = 0; - let value: i128 = match self.table.parse_type()? { - Opcode::Int => { - let v = Int::decode(&mut self.input.0).map_err(Error::msg)?; - v.0.try_into().map_err(Error::msg)? - } - Opcode::Nat => { - let v = Nat::decode(&mut self.input.0).map_err(Error::msg)?; - v.0.try_into().map_err(Error::msg)? - } - t => { - return Err(Error::msg(format!( - "Type mismatch. Type on the wire: {:?}; Expected type: int", - t - ))) - } - }; - visitor.visit_i128(value) - } - - fn deserialize_u128(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - use std::convert::TryInto; - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Nat)?; - let v = Nat::decode(&mut self.input.0).map_err(Error::msg)?; - let value: u128 = v.0.try_into().map_err(Error::msg)?; - visitor.visit_u128(value) - } - fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, { + #[derive(BinRead)] + struct BoolValue( + #[br(try_map = |x:u8| match x { 0u8 => Ok(false), | 1u8 => Ok(true), | _ => Err("Not a boolean") } )] + bool, + ); self.record_nesting_depth = 0; - self.table.check_type(Opcode::Bool)?; - let byte = self.input.parse_byte()?; - if byte > 1u8 { - return Err(de::Error::custom("not a boolean value")); - } - let value = byte == 1u8; - visitor.visit_bool(value) - } - - fn deserialize_string(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Text)?; - let len = self.input.leb128_read()? as usize; - let value = self.input.parse_string(len)?; - visitor.visit_string(value) - } - - fn deserialize_str(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Text)?; - let len = self.input.leb128_read()? as usize; - let value: Result<&str> = - std::str::from_utf8(&self.input.0[0..len]).map_err(de::Error::custom); - self.input.0 = &self.input.0[len..]; - visitor.visit_borrowed_str(value?) - } - - fn deserialize_option(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - match self.table.peek_type()? { - Opcode::Opt => { - self.table.parse_type()?; - match self.input.parse_byte()? { - 0 => { - // Skip the type T of Option - self.table.pop_current_type()?; - visitor.visit_none() - } - // TODO handle subtyping failure - 1 => visitor.visit_some(self), - _ => Err(de::Error::custom("not an option tag")), - } - } - Opcode::Null | Opcode::Reserved => { - self.table.parse_type()?; - visitor.visit_none() - } - _ => visitor.visit_some(self), - } - } - fn deserialize_unit(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Null)?; - visitor.visit_unit() - } - fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_unit(visitor) - } - fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result - where - V: Visitor<'de>, - { - visitor.visit_newtype_struct(self) - } - fn deserialize_byte_buf>(self, visitor: V) -> Result { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Vec)?; - self.table.check_type(Opcode::Nat8)?; - let len = self.input.leb128_read()?; - let bytes = self.input.parse_bytes(len as usize)?; - visitor.visit_byte_buf(bytes) - } - fn deserialize_bytes>(self, visitor: V) -> Result { - self.record_nesting_depth = 0; - match self.table.peek_type()? { - Opcode::Principal => self.deserialize_principal(visitor), - Opcode::Vec => { - self.table.check_type(Opcode::Vec)?; - self.table.check_type(Opcode::Nat8)?; - let len = self.input.leb128_read()? as usize; - let bytes: &[u8] = &self.input.0[0..len]; - self.input.0 = &self.input.0[len..]; - visitor.visit_borrowed_bytes(bytes) - } - _ => Err(Error::msg("bytes only takes principal or vec nat8")), - } - } - fn deserialize_seq(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - match self.table.parse_type()? { - Opcode::Vec => { - let len = self.input.leb128_read()?; - let value = visitor.visit_seq(Compound::new(&mut self, Style::Vector { len })); - // Skip the type T of Vec. - self.table.pop_current_type()?; - value - } - Opcode::Record => { - let len = self.table.pop_current_type()?.get_u32()?; - visitor.visit_seq(Compound::new(&mut self, Style::Tuple { len, index: 0 })) - } - _ => Err(Error::msg("seq only takes vector or tuple")), - } - } - fn deserialize_map(mut self, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Vec)?; - let len = self.input.leb128_read()?; - let ty = self.table.peek_current_type()?.clone(); - let value = visitor.visit_map(Compound::new(&mut self, Style::Map { len, ty })); - self.table.pop_current_type()?; - value - } - fn deserialize_tuple(self, _len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - self.deserialize_seq(visitor) - } - fn deserialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.deserialize_seq(visitor) - } - fn deserialize_struct( - mut self, - _name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - let old_nesting = self.record_nesting_depth; - self.record_nesting_depth += 1; - if self.record_nesting_depth > self.table.table.len() { - return Err(Error::msg("There is an infinite loop in the record definition, the type is isomorphic to an empty type")); - } - self.table.check_type(Opcode::Record)?; - let len = self.table.pop_current_type()?.get_u32()?; - let mut fs = BTreeMap::new(); - for s in fields.iter() { - if fs.insert(idl_hash(s), Some(*s)) != None { - return Err(Error::msg(format!("hash collision {}", s))); - } - } - let value = visitor.visit_map(Compound::new(&mut self, Style::Struct { len, fs }))?; - self.record_nesting_depth = old_nesting; - Ok(value) - } - - fn deserialize_enum( - mut self, - _name: &'static str, - variants: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'de>, - { - self.record_nesting_depth = 0; - self.table.check_type(Opcode::Variant)?; - let len = self.table.pop_current_type()?.get_u32()?; - let mut fs = BTreeMap::new(); - for s in variants.iter() { - if fs.insert(idl_hash(s), Some(*s)) != None { - return Err(Error::msg(format!("hash collision {}", s))); - } - } - let value = visitor.visit_enum(Compound::new(&mut self, Style::Enum { len, fs }))?; - Ok(value) - } - /// Deserialize identifier. - /// # Panics - /// *Will Panic* when identifier name is None. - fn deserialize_identifier(self, visitor: V) -> Result - where - V: Visitor<'de>, - { - // N.B. Here we want to panic as it indicates a logical error. - let label = self.field_name.as_ref().unwrap(); - let v = match label { - FieldLabel::Named(name) => visitor.visit_str(name), - FieldLabel::Id(hash) => visitor.visit_u32(*hash), - FieldLabel::Variant(variant) => visitor.visit_str(variant), - FieldLabel::Skip => visitor.visit_str("_"), - }; - self.field_name = None; - v + self.check_type(&Type::Bool)?; + let res: BoolValue = pretty_read(&mut self.input)?; + visitor.visit_bool(res.0) } serde::forward_to_deserialize_any! { + u8 + u16 + u32 + u64 + i8 + i16 + i32 + i64 + f32 + f64 char + str + string + unit + option + bytes + byte_buf + unit_struct + newtype_struct + tuple_struct + struct + identifier + tuple + enum + seq + map + ignored_any } } - -#[derive(Debug)] -enum Style { - Tuple { - len: u32, - index: u32, - }, - Vector { - len: u64, // non-vector length can only be u32, because field ids is u32. - }, - Struct { - len: u32, - fs: BTreeMap>, - }, - Enum { - len: u32, - fs: BTreeMap>, - }, - Map { - len: u64, - ty: RawValue, - }, -} - -struct Compound<'a, 'de> { - de: &'a mut Deserializer<'de>, - style: Style, -} - -impl<'a, 'de> Compound<'a, 'de> { - fn new(de: &'a mut Deserializer<'de>, style: Style) -> Self { - Compound { de, style } - } -} - -impl<'de, 'a> de::SeqAccess<'de> for Compound<'a, 'de> { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: de::DeserializeSeed<'de>, - { - match self.style { - Style::Tuple { - ref len, - ref mut index, - } => { - if *index == *len { - return Ok(None); - } - let t_idx = self.de.table.pop_current_type()?.get_u32()?; - if t_idx != *index { - return Err(Error::msg(format!( - "Expect vector index {}, but get {}", - index, t_idx - ))); - } - *index += 1; - seed.deserialize(&mut *self.de).map(Some) - } - Style::Vector { ref mut len } => { - if *len == 0 { - return Ok(None); - } - let ty = self.de.table.peek_current_type()?.clone(); - self.de.table.current_type.push_front(ty); - *len -= 1; - seed.deserialize(&mut *self.de).map(Some) - } - _ => Err(Error::msg("expect vector or tuple")), - } - } -} - -impl<'de, 'a> de::MapAccess<'de> for Compound<'a, 'de> { - type Error = Error; - fn next_key_seed(&mut self, seed: K) -> Result> - where - K: de::DeserializeSeed<'de>, - { - match self.style { - Style::Struct { - ref mut len, - ref fs, - } => { - if *len == 0 { - return Ok(None); - } - *len -= 1; - let hash = self.de.table.pop_current_type()?.get_u32()?; - match fs.get(&hash) { - Some(None) => self.de.set_field_name(FieldLabel::Id(hash)), - Some(Some(field)) => self.de.set_field_name(FieldLabel::Named(field)), - None => { - // This triggers seed.deserialize to call deserialize_ignore_any - // to skip both type and value of this unknown field. - self.de.set_field_name(FieldLabel::Skip); - } - } - seed.deserialize(&mut *self.de).map(Some) - } - Style::Map { ref mut len, .. } => { - // This only comes from deserialize_map - if *len == 0 { - return Ok(None); - } - self.de.table.check_type(Opcode::Record)?; - assert_eq!(2, self.de.table.pop_current_type()?.get_u32()?); - assert_eq!(0, self.de.table.pop_current_type()?.get_u32()?); - *len -= 1; - seed.deserialize(&mut *self.de).map(Some) - } - _ => Err(Error::msg("expect struct or map")), - } - } - fn next_value_seed(&mut self, seed: V) -> Result - where - V: de::DeserializeSeed<'de>, - { - match self.style { - Style::Map { ref ty, .. } => { - assert_eq!(1, self.de.table.pop_current_type()?.get_u32()?); - let res = seed.deserialize(&mut *self.de); - self.de.table.current_type.push_front(ty.clone()); - res - } - _ => seed.deserialize(&mut *self.de), - } - } -} - -impl<'de, 'a> de::EnumAccess<'de> for Compound<'a, 'de> { - type Error = Error; - type Variant = Self; - - fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> - where - V: de::DeserializeSeed<'de>, - { - match self.style { - Style::Enum { len, ref fs } => { - let index = u32::try_from(self.de.input.leb128_read()?) - .map_err(|_| Error::msg("variant index out of u32"))?; - if index >= len { - return Err(Error::msg(format!( - "variant index {} larger than length {}", - index, len - ))); - } - let mut index_ty = None; - for i in 0..len { - let hash = self.de.table.pop_current_type()?.get_u32()?; - let ty = self.de.table.pop_current_type()?; - if i == index { - match fs.get(&hash) { - Some(None) => { - let opcode = self.de.table.rawvalue_to_opcode(&ty)?; - let accessor = match opcode { - Opcode::Null => "unit", - Opcode::Record => "struct", - _ => "newtype", - }; - self.de.set_field_name(FieldLabel::Variant(format!( - "{},{}", - hash, accessor - ))); - } - Some(Some(field)) => { - self.de.set_field_name(FieldLabel::Named(field)); - } - None => { - if !fs.is_empty() { - return Err(Error::msg(format!( - "Unknown variant hash {}", - hash - ))); - } else { - // Actual enum won't have empty fs. This can only be generated - // from deserialize_ignored_any - self.de.set_field_name(FieldLabel::Skip); - } - } - } - index_ty = Some(ty); - } - } - // Okay to unwrap, as index_ty always has a value here. - self.de.table.current_type.push_front(index_ty.unwrap()); - let val = seed.deserialize(&mut *self.de)?; - Ok((val, self)) - } - _ => Err(Error::msg("expect enum")), - } - } -} - -impl<'de, 'a> de::VariantAccess<'de> for Compound<'a, 'de> { - type Error = Error; - - fn unit_variant(self) -> Result<()> { - self.de.table.check_type(Opcode::Null)?; - Ok(()) - } - - fn newtype_variant_seed(self, seed: T) -> Result - where - T: de::DeserializeSeed<'de>, - { - seed.deserialize(self.de) - } - - fn tuple_variant(self, _len: usize, visitor: V) -> Result - where - V: Visitor<'de>, - { - de::Deserializer::deserialize_seq(self.de, visitor) - } - - fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result - where - V: Visitor<'de>, - { - if fields.is_empty() { - de::Deserializer::deserialize_any(self.de, visitor) - } else { - de::Deserializer::deserialize_struct(self.de, "_", fields, visitor) - } - } -} - -/// Allow decoding of any sized argument. -pub trait ArgumentDecoder<'a>: Sized { - /// Decodes a value of type [Self], modifying the deserializer (values are consumed). - fn decode(de: &mut IDLDeserialize<'a>) -> Result; -} - -/// Decode an empty tuple. -impl<'a> ArgumentDecoder<'a> for () { - fn decode(_de: &mut IDLDeserialize<'a>) -> Result<()> { - Ok(()) - } -} - -// Create implementation of [ArgumentDecoder] for up to 16 value tuples. -macro_rules! decode_impl { - ( $($id: ident : $typename: ident),* ) => { - impl<'a, $( $typename ),*> ArgumentDecoder<'a> for ($($typename,)*) - where - $( $typename: Deserialize<'a> ),* - { - fn decode(de: &mut IDLDeserialize<'a>) -> Result { - $( - let $id: $typename = de.get_value()?; - )* - - Ok(($( $id, )*)) - } - } - } -} - -decode_impl!(a: A); -decode_impl!(a: A, b: B); -decode_impl!(a: A, b: B, c: C); -decode_impl!(a: A, b: B, c: C, d: D); -decode_impl!(a: A, b: B, c: C, d: D, e: E); -decode_impl!(a: A, b: B, c: C, d: D, e: E, f: F); -decode_impl!(a: A, b: B, c: C, d: D, e: E, f: F, g: G); -decode_impl!(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H); -decode_impl!(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H, i: I); -decode_impl!(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H, i: I, j: J); -decode_impl!( - a: A, - b: B, - c: C, - d: D, - e: E, - f: F, - g: G, - h: H, - i: I, - j: J, - k: K -); -decode_impl!( - a: A, - b: B, - c: C, - d: D, - e: E, - f: F, - g: G, - h: H, - i: I, - j: J, - k: K, - l: L -); -decode_impl!( - a: A, - b: B, - c: C, - d: D, - e: E, - f: F, - g: G, - h: H, - i: I, - j: J, - k: K, - l: L, - m: M -); -decode_impl!( - a: A, - b: B, - c: C, - d: D, - e: E, - f: F, - g: G, - h: H, - i: I, - j: J, - k: K, - l: L, - m: M, - n: N -); -decode_impl!( - a: A, - b: B, - c: C, - d: D, - e: E, - f: F, - g: G, - h: H, - i: I, - j: J, - k: K, - l: L, - m: M, - n: N, - o: O -); -decode_impl!( - a: A, - b: B, - c: C, - d: D, - e: E, - f: F, - g: G, - h: H, - i: I, - j: J, - k: K, - l: L, - m: M, - n: N, - o: O, - p: P -); - -/// Decode a series of arguments, represented as a tuple. There is a maximum of 16 arguments -/// supported. -/// -/// Example: -/// -/// ``` -/// # use candid::Encode; -/// # use candid::de::decode_args; -/// let golden1 = 123u64; -/// let golden2 = "456"; -/// let bytes = Encode!(&golden1, &golden2).unwrap(); -/// let (value1, value2): (u64, String) = decode_args(&bytes).unwrap(); -/// -/// assert_eq!(golden1, value1); -/// assert_eq!(golden2, value2); -/// ``` -pub fn decode_args<'a, Tuple>(bytes: &'a [u8]) -> Result -where - Tuple: ArgumentDecoder<'a>, -{ - let mut de = IDLDeserialize::new(bytes)?; - let res = ArgumentDecoder::decode(&mut de)?; - de.done()?; - Ok(res) -} - -/// Decode a single argument. -/// -/// Example: -/// -/// ``` -/// # use candid::Encode; -/// # use candid::de::decode_one; -/// let golden1 = 123u64; -/// let bytes = Encode!(&golden1).unwrap(); -/// let value1: u64 = decode_one(&bytes).unwrap(); -/// -/// assert_eq!(golden1, value1); -/// ``` -pub fn decode_one<'a, T>(bytes: &'a [u8]) -> Result -where - T: Deserialize<'a>, -{ - let (res,) = decode_args(bytes)?; - Ok(res) -} diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index 7d58aeb0..acb4b83c 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -176,13 +176,11 @@ where }) } -pub fn pretty_read(bytes: R) -> Result<(T, usize)> +pub fn pretty_read(reader: &mut std::io::Cursor<&[u8]>) -> Result where T: binread::BinRead, - R: AsRef<[u8]>, { - let mut reader = std::io::Cursor::new(bytes); - let res = T::read(&mut reader).or_else(|e| { + T::read(reader).or_else(|e| { let e = Error::Binread(e); let writer = StandardStream::stderr(term::termcolor::ColorChoice::Auto); let config = term::Config::default(); @@ -190,7 +188,5 @@ where let file = SimpleFile::new("binary", &str); term::emit(&mut writer.lock(), &config, &file, &e.report())?; Err(e) - })?; - let rest = reader.position() as usize; - Ok((res, rest)) + }) } diff --git a/rust/candid/src/lib.rs b/rust/candid/src/lib.rs index 4dc49140..5c6ecac4 100644 --- a/rust/candid/src/lib.rs +++ b/rust/candid/src/lib.rs @@ -272,7 +272,7 @@ pub use parser::value::IDLArgs; pub mod binary_parser; pub mod de; -pub use de::{decode_args, decode_one}; +//pub use de::{decode_args, decode_one}; pub mod ser; pub use ser::{encode_args, encode_one}; diff --git a/rust/candid/src/parser/value.rs b/rust/candid/src/parser/value.rs index c66b80bf..586ad0f4 100644 --- a/rust/candid/src/parser/value.rs +++ b/rust/candid/src/parser/value.rs @@ -323,13 +323,13 @@ impl IDLValue { impl crate::CandidType for IDLValue { fn ty() -> Type { - unreachable!() + Type::Unknown } fn id() -> crate::types::TypeId { unreachable!(); } fn _ty() -> Type { - unreachable!() + Type::Unknown } fn idl_serialize(&self, serializer: S) -> std::result::Result<(), S::Error> where From b57ea04e76add4c5593a4624cd74020d330f5fe8 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Mon, 29 Mar 2021 11:22:35 -0700 Subject: [PATCH 08/51] int --- rust/candid/src/de.rs | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 550e22d9..c88148e9 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -89,15 +89,42 @@ impl<'de> Deserializer<'de> { record_nesting_depth: 0, }) } - fn check_type(&self, expect: &Type) -> Result<()> { - if self.wire_type == self.expect_type && self.wire_type == *expect { - Ok(()) - } else { + fn expect_type(&self, expect: &Type) -> Result<()> { + if *expect != self.expect_type { Err(Error::msg(format!( - "Type mismatch. Expect {}, but found {}", - expect, self.wire_type + "Internal error. Expect {}, but expect_type is {}", + expect, self.expect_type ))) + } else { + Ok(()) + } + } + fn deserialize_int<'a, V>(&'a mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.record_nesting_depth = 0; + self.expect_type(&Type::Int)?; + let mut bytes = Vec::new(); + match &self.wire_type { + Type::Int => { + bytes.push(0u8); + let int = Int::decode(&mut self.input).map_err(Error::msg)?; + bytes.extend_from_slice(&int.0.to_signed_bytes_le()); + } + Type::Nat => { + bytes.push(1u8); + let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; + bytes.extend_from_slice(&nat.0.to_bytes_le()); + } + t => { + return Err(Error::msg(format!( + "Type mismatch. Expect int, but found {}", + t + ))) + } } + visitor.visit_byte_buf(bytes) } } @@ -110,7 +137,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { let t = self.table.trace_type(&self.expect_type)?; match t { Type::Bool => self.deserialize_bool(visitor), - Type::Nat8 => self.deserialize_u8(visitor), + Type::Int => self.deserialize_int(visitor), Type::Vec(_) => self.deserialize_seq(visitor), _ => unreachable!(), } @@ -126,8 +153,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { bool, ); self.record_nesting_depth = 0; - self.check_type(&Type::Bool)?; let res: BoolValue = pretty_read(&mut self.input)?; + self.expect_type(&Type::Bool)?; visitor.visit_bool(res.0) } From 20a1280db334758c8817b8ef7e50f276585e1ebd Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Mon, 29 Mar 2021 18:03:34 -0700 Subject: [PATCH 09/51] anyhow --- Cargo.lock | 1 + rust/candid/Cargo.toml | 1 + rust/candid/src/binary_parser.rs | 20 +++++++++++----- rust/candid/src/bindings/candid.rs | 2 +- rust/candid/src/de.rs | 38 +++++++++++++++++++++++------- rust/candid/src/error.rs | 15 ++++-------- 6 files changed, 50 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28d34fa9..24af1ccd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -254,6 +254,7 @@ checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" name = "candid" version = "0.6.19" dependencies = [ + "anyhow", "arbitrary", "binread", "byteorder", diff --git a/rust/candid/Cargo.toml b/rust/candid/Cargo.toml index c09ad71b..f87e13b7 100644 --- a/rust/candid/Cargo.toml +++ b/rust/candid/Cargo.toml @@ -35,6 +35,7 @@ pretty = "0.10.0" serde = { version = "1.0.118", features = ["derive"] } serde_bytes = "0.11" thiserror = "1.0.20" +anyhow = "1.0" binread = { version = "2.0", features = ["debug_template"] } arbitrary = { version = "0.4.7", optional = true } diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index 07ee5b11..e2694cac 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -1,6 +1,7 @@ use crate::parser::typing::TypeEnv; use crate::types::internal::{Field, Label, Type}; -use crate::{Error, Result}; +//use crate::{Error, Result}; +use anyhow::{anyhow, Context, Result}; use binread::io::{Read, Seek, SeekFrom}; use binread::{BinRead, BinResult, Error as BError, ReadOptions}; use std::convert::TryInto; @@ -71,7 +72,7 @@ impl IndexType { Ok(match self.index { v if v >= 0 => { if v >= len as i64 { - return Err(Error::msg("type index out of range")); + return Err(anyhow!("type index {} out of range", v)); } Type::Var(v.to_string()) } @@ -108,7 +109,7 @@ impl ConsType { for f in fs.inner.iter() { if let Some(prev) = prev { if prev >= f.id { - return Err(Error::msg("field id collision or not sorted")); + return Err(anyhow!("field id {} collision or not sorted", f.id)); } } prev = Some(f.id); @@ -132,7 +133,11 @@ impl Table { use std::collections::BTreeMap; let mut env = BTreeMap::new(); for (i, t) in self.table.iter().enumerate() { - env.insert(i.to_string(), t.to_type(len)?); + env.insert( + i.to_string(), + t.to_type(len) + .with_context(|| format!("Invalid table entry {}: {:?}", i, t))?, + ); } Ok(TypeEnv(env)) } @@ -142,8 +147,11 @@ impl Header { let len = self.table.len; let env = self.table.to_env(len)?; let mut args = Vec::new(); - for t in self.args.iter() { - args.push(t.to_type(len)?); + for (i, t) in self.args.iter().enumerate() { + args.push( + t.to_type(len) + .with_context(|| format!("Invalid argument entry {}: {:?}", i, t))?, + ); } Ok((env, args)) } diff --git a/rust/candid/src/bindings/candid.rs b/rust/candid/src/bindings/candid.rs index 9f360e40..323e5bab 100644 --- a/rust/candid/src/bindings/candid.rs +++ b/rust/candid/src/bindings/candid.rs @@ -115,7 +115,7 @@ pub fn pp_ty(ty: &Type) -> RcDoc { } } Knot(ref id) => RcDoc::text(format!("{}", id)), - Unknown => unreachable!(), + Unknown => str("unknown"), } } diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index c88148e9..025838dd 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -3,6 +3,7 @@ use super::error::{pretty_read, Error, Result}; use super::{idl_hash, parser::typing::TypeEnv, types::Type, CandidType, Int, Nat}; use crate::binary_parser::Header; +use anyhow::{anyhow, Context}; use binread::BinRead; use byteorder::{LittleEndian, ReadBytesExt}; use serde::de::{self, Deserialize, Visitor}; @@ -17,7 +18,8 @@ pub struct IDLDeserialize<'de> { impl<'de> IDLDeserialize<'de> { /// Create a new deserializer with IDL binary message. pub fn new(bytes: &'de [u8]) -> Result { - let de = Deserializer::from_bytes(bytes)?; + let de = Deserializer::from_bytes(bytes) + .with_context(|| format!("Cannot parse header {}", &hex::encode(bytes)))?; Ok(IDLDeserialize { de }) } /// Deserialize one value from deserializer. @@ -25,20 +27,24 @@ impl<'de> IDLDeserialize<'de> { where T: de::Deserialize<'de> + CandidType, { - let ty = self + let (ind, ty) = self .de .types .pop_front() - .ok_or_else(|| Error::msg("No more values to deserialize"))?; + .context("No more values to deserialize")?; let expected_type = T::ty(); self.de.expect_type = if matches!(expected_type, Type::Unknown) { ty.clone() } else { expected_type }; - self.de.wire_type = ty; + self.de.wire_type = ty.clone(); - let v = T::deserialize(&mut self.de)?; //.map_err(|e| self.de.dump_error_state(e))?; + let v = T::deserialize(&mut self.de) + .with_context(|| self.de.dump_state()) + .with_context(|| { + format!("Fail to decode argument {} from {} to {}", ind, ty, T::ty()) + })?; Ok(v) /*if self.de.table.current_type.is_empty() && self.de.field_name.is_none() { Ok(v) @@ -59,8 +65,8 @@ impl<'de> IDLDeserialize<'de> { let ind = self.de.input.position() as usize; let rest = &self.de.input.get_ref()[ind..]; if !rest.is_empty() { - return Err(Error::msg("Trailing value after finishing deserialization")); - //.map_err(|e| self.de.dump_error_state(e)); + return Err(anyhow!(self.de.dump_state())) + .context("Trailing value after finishing deserialization")?; } Ok(()) } @@ -69,7 +75,7 @@ impl<'de> IDLDeserialize<'de> { struct Deserializer<'de> { input: Cursor<&'de [u8]>, table: TypeEnv, - types: VecDeque, + types: VecDeque<(usize, Type)>, wire_type: Type, expect_type: Type, record_nesting_depth: usize, @@ -83,12 +89,26 @@ impl<'de> Deserializer<'de> { Ok(Deserializer { input: reader, table: env, - types: types.into(), + types: types.into_iter().enumerate().collect(), wire_type: Type::Unknown, expect_type: Type::Unknown, record_nesting_depth: 0, }) } + fn dump_state(&self) -> String { + let hex = hex::encode(self.input.get_ref()); + let pos = self.input.position() as usize * 2; + let (before, after) = hex.split_at(pos); + let mut res = format!("input: {}_{}\n", before, after); + if !self.table.0.is_empty() { + res += &format!("table: {}", self.table); + } + res += &format!( + "wire_type: {}, expect_type: {}", + self.wire_type, self.expect_type + ); + res + } fn expect_type(&self, expect: &Type) -> Result<()> { if *expect != self.expect_type { Err(Error::msg(format!( diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index acb4b83c..cb2409be 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -19,19 +19,13 @@ pub enum Error { #[error("Binary parser error: {0}")] Binread(#[from] binread::Error), - #[error("Deserialize error: {0}")] - Deserialize(String, String), - #[error("{0}")] - Custom(String), + Custom(#[from] anyhow::Error), } impl Error { pub fn msg(msg: T) -> Self { - Error::Custom(msg.to_string()) - } - pub fn with_states(&self, states: String) -> Self { - Error::Deserialize(self.to_string(), states) + Error::Custom(anyhow::anyhow!(msg.to_string())) } pub fn report(&self) -> Diagnostic<()> { match self { @@ -64,8 +58,7 @@ impl Error { let labels = get_binread_labels(e); diag.with_labels(labels) } - Error::Deserialize(e, _) => Diagnostic::error().with_message(e), - Error::Custom(e) => Diagnostic::error().with_message(e), + Error::Custom(e) => Diagnostic::error().with_message(e.to_string()), } } } @@ -109,7 +102,7 @@ fn get_binread_labels(e: &binread::Error) -> Vec> { vec![Label::primary((), pos..pos + 2).with_message(message)] } Io(e) => vec![Label::primary((), 0..0).with_message(e.to_string())], - _ => Vec::new(), + _ => unreachable!(), } } From 6c0d2e3c713e7a143e62007e2f0d6e8f040f3488 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Mon, 29 Mar 2021 22:21:12 -0700 Subject: [PATCH 10/51] fix anyhow --- rust/candid/src/binary_parser.rs | 1 - rust/candid/src/error.rs | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index e2694cac..d8e25de5 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -1,6 +1,5 @@ use crate::parser::typing::TypeEnv; use crate::types::internal::{Field, Label, Type}; -//use crate::{Error, Result}; use anyhow::{anyhow, Context, Result}; use binread::io::{Read, Seek, SeekFrom}; use binread::{BinRead, BinResult, Error as BError, ReadOptions}; diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index cb2409be..aeb4c87c 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -16,10 +16,10 @@ pub enum Error { #[error("Candid parser error: {0}")] Parse(#[from] token::ParserError), - #[error("Binary parser error: {0}")] + #[error(transparent)] Binread(#[from] binread::Error), - #[error("{0}")] + #[error(transparent)] Custom(#[from] anyhow::Error), } From 5efe8f55f72901c7f901ca9d1527ab749b5674d5 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Wed, 31 Mar 2021 12:52:52 -0700 Subject: [PATCH 11/51] subtype --- rust/candid/src/de.rs | 51 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 025838dd..f1cd4cc3 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -3,6 +3,7 @@ use super::error::{pretty_read, Error, Result}; use super::{idl_hash, parser::typing::TypeEnv, types::Type, CandidType, Int, Nat}; use crate::binary_parser::Header; +use crate::types::subtype::{subtype, Gamma}; use anyhow::{anyhow, Context}; use binread::BinRead; use byteorder::{LittleEndian, ReadBytesExt}; @@ -39,6 +40,20 @@ impl<'de> IDLDeserialize<'de> { expected_type }; self.de.wire_type = ty.clone(); + if !subtype( + &mut self.de.gamma, + &self.de.table, + &ty, + &self.de.table, + &self.de.expect_type, + ) { + return Err(Error::msg(format!( + "Fail to decode argument {}, because {} is not subtype of {}", + ind, + ty, + T::ty() + ))); + } let v = T::deserialize(&mut self.de) .with_context(|| self.de.dump_state()) @@ -72,12 +87,24 @@ impl<'de> IDLDeserialize<'de> { } } +macro_rules! assert { + ( $self:ident, false ) => {{ + return Err(anyhow!($self.dump_state())).context("Internal error")?; + }}; + ( $self:ident, $pred:expr ) => {{ + if !$pred { + return Err(anyhow!($self.dump_state())).context("Internal error")?; + } + }}; +} + struct Deserializer<'de> { input: Cursor<&'de [u8]>, table: TypeEnv, types: VecDeque<(usize, Type)>, wire_type: Type, expect_type: Type, + gamma: Gamma, record_nesting_depth: usize, } @@ -92,6 +119,7 @@ impl<'de> Deserializer<'de> { types: types.into_iter().enumerate().collect(), wire_type: Type::Unknown, expect_type: Type::Unknown, + gamma: Gamma::default(), record_nesting_depth: 0, }) } @@ -109,22 +137,12 @@ impl<'de> Deserializer<'de> { ); res } - fn expect_type(&self, expect: &Type) -> Result<()> { - if *expect != self.expect_type { - Err(Error::msg(format!( - "Internal error. Expect {}, but expect_type is {}", - expect, self.expect_type - ))) - } else { - Ok(()) - } - } fn deserialize_int<'a, V>(&'a mut self, visitor: V) -> Result where V: Visitor<'de>, { self.record_nesting_depth = 0; - self.expect_type(&Type::Int)?; + assert!(self, self.expect_type == Type::Int); let mut bytes = Vec::new(); match &self.wire_type { Type::Int => { @@ -137,12 +155,7 @@ impl<'de> Deserializer<'de> { let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; bytes.extend_from_slice(&nat.0.to_bytes_le()); } - t => { - return Err(Error::msg(format!( - "Type mismatch. Expect int, but found {}", - t - ))) - } + _ => assert!(self, false), } visitor.visit_byte_buf(bytes) } @@ -159,7 +172,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { Type::Bool => self.deserialize_bool(visitor), Type::Int => self.deserialize_int(visitor), Type::Vec(_) => self.deserialize_seq(visitor), - _ => unreachable!(), + _ => assert!(self, false), } } @@ -173,8 +186,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { bool, ); self.record_nesting_depth = 0; + assert!(self, self.expect_type == Type::Bool); let res: BoolValue = pretty_read(&mut self.input)?; - self.expect_type(&Type::Bool)?; visitor.visit_bool(res.0) } From 4768588ec3980e7b27fdd0629518af3c65f9b151 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Thu, 1 Apr 2021 15:09:17 -0700 Subject: [PATCH 12/51] fix --- rust/candid/src/de.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index f1cd4cc3..96498074 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -88,12 +88,12 @@ impl<'de> IDLDeserialize<'de> { } macro_rules! assert { - ( $self:ident, false ) => {{ - return Err(anyhow!($self.dump_state())).context("Internal error")?; + ( false ) => {{ + return Err(Error::msg("Internal error. Please file a bug.")); }}; - ( $self:ident, $pred:expr ) => {{ + ( $pred:expr ) => {{ if !$pred { - return Err(anyhow!($self.dump_state())).context("Internal error")?; + return Err(Error::msg("Internal error. Please file a bug.")); } }}; } @@ -142,7 +142,7 @@ impl<'de> Deserializer<'de> { V: Visitor<'de>, { self.record_nesting_depth = 0; - assert!(self, self.expect_type == Type::Int); + assert!(self.expect_type == Type::Int); let mut bytes = Vec::new(); match &self.wire_type { Type::Int => { @@ -155,7 +155,8 @@ impl<'de> Deserializer<'de> { let nat = Nat::decode(&mut self.input).map_err(Error::msg)?; bytes.extend_from_slice(&nat.0.to_bytes_le()); } - _ => assert!(self, false), + // We already did subtype checking before deserialize, so this is unreachable code + _ => assert!(false), } visitor.visit_byte_buf(bytes) } @@ -172,7 +173,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { Type::Bool => self.deserialize_bool(visitor), Type::Int => self.deserialize_int(visitor), Type::Vec(_) => self.deserialize_seq(visitor), - _ => assert!(self, false), + _ => assert!(false), } } @@ -186,7 +187,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { bool, ); self.record_nesting_depth = 0; - assert!(self, self.expect_type == Type::Bool); + assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool); let res: BoolValue = pretty_read(&mut self.input)?; visitor.visit_bool(res.0) } From ba0a80f4fb489871a728bc3fc9118e18bac6190d Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Thu, 1 Apr 2021 17:37:45 -0700 Subject: [PATCH 13/51] opt --- rust/candid/src/binary_parser.rs | 8 ++- rust/candid/src/de.rs | 87 +++++++++++++++++++++++++++----- rust/candid/src/error.rs | 2 +- 3 files changed, 82 insertions(+), 15 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index d8e25de5..b056d8e2 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -66,6 +66,10 @@ struct FieldType { index: IndexType, } +fn index_to_var(ind: i64) -> String { + format!("var{}", ind) +} + impl IndexType { fn to_type(&self, len: u64) -> Result { Ok(match self.index { @@ -73,7 +77,7 @@ impl IndexType { if v >= len as i64 { return Err(anyhow!("type index {} out of range", v)); } - Type::Var(v.to_string()) + Type::Var(index_to_var(v)) } -1 => Type::Null, -2 => Type::Bool, @@ -133,7 +137,7 @@ impl Table { let mut env = BTreeMap::new(); for (i, t) in self.table.iter().enumerate() { env.insert( - i.to_string(), + index_to_var(i as i64), t.to_type(len) .with_context(|| format!("Invalid table entry {}: {:?}", i, t))?, ); diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index 96498074..e62c37a5 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -40,7 +40,8 @@ impl<'de> IDLDeserialize<'de> { expected_type }; self.de.wire_type = ty.clone(); - if !subtype( + self.de.check_subtype()?; + /*if !subtype( &mut self.de.gamma, &self.de.table, &ty, @@ -53,7 +54,7 @@ impl<'de> IDLDeserialize<'de> { ty, T::ty() ))); - } + }*/ let v = T::deserialize(&mut self.de) .with_context(|| self.de.dump_state()) @@ -89,11 +90,19 @@ impl<'de> IDLDeserialize<'de> { macro_rules! assert { ( false ) => {{ - return Err(Error::msg("Internal error. Please file a bug.")); + return Err(Error::msg(format!( + "Internal error at {}:{}. Please file a bug.", + file!(), + line!() + ))); }}; ( $pred:expr ) => {{ if !$pred { - return Err(Error::msg("Internal error. Please file a bug.")); + return Err(Error::msg(format!( + "Internal error at {}:{}. Please file a bug.", + file!(), + line!() + ))); } }}; } @@ -137,6 +146,27 @@ impl<'de> Deserializer<'de> { ); res } + fn check_subtype(&mut self) -> Result<()> { + if !subtype( + &mut self.gamma, + &self.table, + &self.wire_type, + &self.table, + &self.expect_type, + ) { + Err(Error::msg(format!( + "{} is not subtype of {}", + self.wire_type, self.expect_type, + ))) + } else { + Ok(()) + } + } + fn unroll_type(&mut self) -> Result<()> { + self.expect_type = self.table.trace_type(&self.expect_type)?; + self.wire_type = self.table.trace_type(&self.wire_type)?; + Ok(()) + } fn deserialize_int<'a, V>(&'a mut self, visitor: V) -> Result where V: Visitor<'de>, @@ -162,6 +192,12 @@ impl<'de> Deserializer<'de> { } } +#[derive(BinRead)] +struct BoolValue( + #[br(try_map = |x:u8| match x { 0u8 => Ok(false), | 1u8 => Ok(true), | _ => Err("Expect 00 or 01") } )] + bool, +); + impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; fn deserialize_any(self, visitor: V) -> Result @@ -172,25 +208,53 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { match t { Type::Bool => self.deserialize_bool(visitor), Type::Int => self.deserialize_int(visitor), - Type::Vec(_) => self.deserialize_seq(visitor), + Type::Opt(_) => self.deserialize_option(visitor), _ => assert!(false), } } - fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, { - #[derive(BinRead)] - struct BoolValue( - #[br(try_map = |x:u8| match x { 0u8 => Ok(false), | 1u8 => Ok(true), | _ => Err("Not a boolean") } )] - bool, - ); self.record_nesting_depth = 0; assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool); let res: BoolValue = pretty_read(&mut self.input)?; visitor.visit_bool(res.0) } + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.record_nesting_depth = 0; + self.unroll_type()?; + if let Type::Opt(ref t) = self.expect_type { + self.expect_type = *t.clone(); + } else { + assert!(false); + } + match self.wire_type { + Type::Null | Type::Reserved => visitor.visit_none(), + Type::Opt(ref t) => { + self.wire_type = *t.clone(); + if pretty_read::(&mut self.input)?.0 { + if self.check_subtype().is_ok() { + visitor.visit_some(self) + } else { + visitor.visit_none() + } + } else { + visitor.visit_none() + } + } + _ => { + if self.check_subtype().is_ok() { + visitor.visit_some(self) + } else { + visitor.visit_none() + } + } + } + } serde::forward_to_deserialize_any! { u8 @@ -207,7 +271,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { str string unit - option bytes byte_buf unit_struct diff --git a/rust/candid/src/error.rs b/rust/candid/src/error.rs index aeb4c87c..af7da326 100644 --- a/rust/candid/src/error.rs +++ b/rust/candid/src/error.rs @@ -101,7 +101,7 @@ fn get_binread_labels(e: &binread::Error) -> Vec> { let pos = (pos * 2) as usize; vec![Label::primary((), pos..pos + 2).with_message(message)] } - Io(e) => vec![Label::primary((), 0..0).with_message(e.to_string())], + Io(_) => vec![], _ => unreachable!(), } } From 5b9f9402bf79cd220e7636c78ca99940a1bb0044 Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Thu, 1 Apr 2021 18:19:39 -0700 Subject: [PATCH 14/51] vec --- rust/candid/src/de.rs | 86 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 80 insertions(+), 6 deletions(-) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index e62c37a5..a09f32b4 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -7,6 +7,7 @@ use crate::types::subtype::{subtype, Gamma}; use anyhow::{anyhow, Context}; use binread::BinRead; use byteorder::{LittleEndian, ReadBytesExt}; +use leb128::read::{signed as sleb128_decode, unsigned as leb128_decode}; use serde::de::{self, Deserialize, Visitor}; use std::collections::{BTreeMap, VecDeque}; use std::convert::TryFrom; @@ -206,12 +207,22 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { let t = self.table.trace_type(&self.expect_type)?; match t { + Type::Null => self.deserialize_unit(visitor), Type::Bool => self.deserialize_bool(visitor), Type::Int => self.deserialize_int(visitor), Type::Opt(_) => self.deserialize_option(visitor), + Type::Vec(_) => self.deserialize_seq(visitor), _ => assert!(false), } } + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.record_nesting_depth = 0; + assert!(self.expect_type == Type::Null && self.wire_type == Type::Null); + visitor.visit_unit() + } fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de>, @@ -227,10 +238,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { { self.record_nesting_depth = 0; self.unroll_type()?; - if let Type::Opt(ref t) = self.expect_type { - self.expect_type = *t.clone(); - } else { - assert!(false); + match self.expect_type { + Type::Opt(ref t) => self.expect_type = *t.clone(), + _ => assert!(false), } match self.wire_type { Type::Null | Type::Reserved => visitor.visit_none(), @@ -255,6 +265,22 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } } } + fn deserialize_seq(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.record_nesting_depth = 0; + self.unroll_type()?; + match (&self.expect_type, &self.wire_type) { + (Type::Vec(ref e), Type::Vec(ref w)) => { + self.expect_type = *e.clone(); + self.wire_type = *w.clone(); + let len = leb128_decode(&mut self.input).map_err(Error::msg)?; + visitor.visit_seq(Compound::new(&mut self, Style::Vector { len })) + } + _ => assert!(false), + } + } serde::forward_to_deserialize_any! { u8 @@ -270,7 +296,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { char str string - unit bytes byte_buf unit_struct @@ -280,8 +305,57 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { identifier tuple enum - seq map ignored_any } } + +#[derive(Debug)] +enum Style { + Tuple { + len: u32, + index: u32, + }, + Vector { + len: u64, // non-vector length can only be u32, because field ids is u32. + }, + Struct { + len: u32, + fs: BTreeMap>, + }, + Enum { + len: u32, + fs: BTreeMap>, + }, +} + +struct Compound<'a, 'de> { + de: &'a mut Deserializer<'de>, + style: Style, +} + +impl<'a, 'de> Compound<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>, style: Style) -> Self { + Compound { de, style } + } +} + +impl<'de, 'a> de::SeqAccess<'de> for Compound<'a, 'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + match self.style { + Style::Vector { ref mut len } => { + if *len == 0 { + return Ok(None); + } + *len -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + _ => Err(Error::msg("expect vector")), + } + } +} From 3c05675a95153ed0e159545c436a8de5235e525a Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Thu, 1 Apr 2021 20:06:58 -0700 Subject: [PATCH 15/51] fix --- rust/candid/src/binary_parser.rs | 12 +++++++++++- rust/candid/src/de.rs | 25 ++----------------------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index b056d8e2..e32ce71c 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -66,10 +66,20 @@ struct FieldType { index: IndexType, } +#[derive(BinRead)] +pub struct BoolValue( + #[br(try_map = |x:u8| match x { 0u8 => Ok(false), | 1u8 => Ok(true), | _ => Err("Expect 00 or 01") } )] + pub bool, +); +#[derive(BinRead)] +pub struct Len( + #[br(parse_with = read_leb)] + pub u64 +); + fn index_to_var(ind: i64) -> String { format!("var{}", ind) } - impl IndexType { fn to_type(&self, len: u64) -> Result { Ok(match self.index { diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index a09f32b4..af5ee885 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -2,10 +2,9 @@ use super::error::{pretty_read, Error, Result}; use super::{idl_hash, parser::typing::TypeEnv, types::Type, CandidType, Int, Nat}; -use crate::binary_parser::Header; +use crate::binary_parser::{Header, BoolValue, Len}; use crate::types::subtype::{subtype, Gamma}; use anyhow::{anyhow, Context}; -use binread::BinRead; use byteorder::{LittleEndian, ReadBytesExt}; use leb128::read::{signed as sleb128_decode, unsigned as leb128_decode}; use serde::de::{self, Deserialize, Visitor}; @@ -42,20 +41,6 @@ impl<'de> IDLDeserialize<'de> { }; self.de.wire_type = ty.clone(); self.de.check_subtype()?; - /*if !subtype( - &mut self.de.gamma, - &self.de.table, - &ty, - &self.de.table, - &self.de.expect_type, - ) { - return Err(Error::msg(format!( - "Fail to decode argument {}, because {} is not subtype of {}", - ind, - ty, - T::ty() - ))); - }*/ let v = T::deserialize(&mut self.de) .with_context(|| self.de.dump_state()) @@ -193,12 +178,6 @@ impl<'de> Deserializer<'de> { } } -#[derive(BinRead)] -struct BoolValue( - #[br(try_map = |x:u8| match x { 0u8 => Ok(false), | 1u8 => Ok(true), | _ => Err("Expect 00 or 01") } )] - bool, -); - impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; fn deserialize_any(self, visitor: V) -> Result @@ -275,7 +254,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { (Type::Vec(ref e), Type::Vec(ref w)) => { self.expect_type = *e.clone(); self.wire_type = *w.clone(); - let len = leb128_decode(&mut self.input).map_err(Error::msg)?; + let len = pretty_read::(&mut self.input)?.0; visitor.visit_seq(Compound::new(&mut self, Style::Vector { len })) } _ => assert!(false), From 46643f88ee3b758b525f28c6b0b78e02ac10ec9b Mon Sep 17 00:00:00 2001 From: chenyan-dfinity Date: Fri, 2 Apr 2021 13:59:38 -0700 Subject: [PATCH 16/51] struct, skipping not tested --- rust/candid/src/binary_parser.rs | 9 +- rust/candid/src/de.rs | 160 +++++++++++++++++++++++++++++-- rust/candid/src/error.rs | 4 +- 3 files changed, 160 insertions(+), 13 deletions(-) diff --git a/rust/candid/src/binary_parser.rs b/rust/candid/src/binary_parser.rs index e32ce71c..549acc11 100644 --- a/rust/candid/src/binary_parser.rs +++ b/rust/candid/src/binary_parser.rs @@ -54,14 +54,14 @@ struct IndexType { } #[derive(BinRead, Debug)] struct Fields { - #[br(parse_with = read_leb, try_map = |x:u64| x.try_into())] + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into().map_err(|_| "field length out of 32-bit range"))] len: u32, #[br(count = len)] inner: Vec, } #[derive(BinRead, Debug)] struct FieldType { - #[br(parse_with = read_leb, try_map = |x:u64| x.try_into())] + #[br(parse_with = read_leb, try_map = |x:u64| x.try_into().map_err(|_| "field id out of 32-bit range"))] id: u32, index: IndexType, } @@ -72,10 +72,7 @@ pub struct BoolValue( pub bool, ); #[derive(BinRead)] -pub struct Len( - #[br(parse_with = read_leb)] - pub u64 -); +pub struct Len(#[br(parse_with = read_leb)] pub u64); fn index_to_var(ind: i64) -> String { format!("var{}", ind) diff --git a/rust/candid/src/de.rs b/rust/candid/src/de.rs index af5ee885..f9c5afee 100644 --- a/rust/candid/src/de.rs +++ b/rust/candid/src/de.rs @@ -1,8 +1,13 @@ //! Deserialize Candid binary format to Rust data structures use super::error::{pretty_read, Error, Result}; -use super::{idl_hash, parser::typing::TypeEnv, types::Type, CandidType, Int, Nat}; -use crate::binary_parser::{Header, BoolValue, Len}; +use super::{ + idl_hash, + parser::typing::TypeEnv, + types::{Field, Label, Type}, + CandidType, Int, Nat, +}; +use crate::binary_parser::{BoolValue, Header, Len}; use crate::types::subtype::{subtype, Gamma}; use anyhow::{anyhow, Context}; use byteorder::{LittleEndian, ReadBytesExt}; @@ -100,6 +105,9 @@ struct Deserializer<'de> { wire_type: Type, expect_type: Type, gamma: Gamma, + // field_name tells deserialize_identifier which field name to process. + // This field should always be set by set_field_name function. + field_name: Option