Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support reference types #153

Merged
merged 9 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 145 additions & 205 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rust/candid/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ num-bigint = "0.3.0"
num-traits = "0.2.12"
paste = "1.0.0"
pretty = "0.10.0"
serde = "1.0.115"
serde = { version = "1.0.118", features = ["derive"] }
thiserror = "1.0.20"

[dev-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions rust/candid/src/codegen/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub trait RustBindings {
} else {
format!(
"std::pin::Pin<std::boxed::Box<impl std::future::Future<Output = {}>>>",
if return_type == "" {
if return_type.is_empty() {
"()"
} else {
&return_type
Expand All @@ -117,7 +117,7 @@ pub trait RustBindings {
id = id,
arguments = arguments_list,
body = body,
return_type = if return_type == "" {
return_type = if return_type.is_empty() {
format!("")
} else {
format!(" -> {}", return_type)
Expand Down
116 changes: 108 additions & 8 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<'de> IDLDeserialize<'de> {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
enum RawValue {
I(i64),
U(u32),
Expand All @@ -80,8 +80,11 @@ impl RawValue {
}
}
}
fn validate_type_range(ty: i64, len: u64) -> Result<()> {
if ty >= 0 && (ty as u64) < len || Opcode::try_from(ty).is_ok() {
fn is_primitive_type(ty: i64) -> bool {
ty < 0 && (ty >= -17 || ty == -24)
}
fn validate_type_range(ty: i64, len: usize) -> Result<()> {
if ty >= 0 && (ty as usize) < len || is_primitive_type(ty) {
Ok(())
} else {
Err(Error::msg(format!("unknown type {}", ty)))
Expand Down Expand Up @@ -169,11 +172,18 @@ impl<'de> Deserializer<'de> {
// Parse magic number, type table, and type seq from input.
fn parse_table(&mut self) -> Result<()> {
self.parse_magic()?;
let len = self.leb128_read()?;
for _i in 0..len {
let len = self.leb128_read()? as usize;
let mut expect_func = std::collections::HashSet::new();
for i in 0..len {
let mut buf = Vec::new();
let ty = self.sleb128_read()?;
buf.push(RawValue::I(ty));
if expect_func.contains(&i) && ty != -22 {
return Err(Error::msg(format!(
"Expect function opcode, but got {}",
ty
)));
}
match Opcode::try_from(ty) {
Ok(Opcode::Opt) | Ok(Opcode::Vec) => {
let ty = self.sleb128_read()?;
Expand All @@ -200,6 +210,58 @@ impl<'de> Deserializer<'de> {
buf.push(RawValue::I(ty));
}
}
Ok(Opcode::Service) => {
let obj_len = u32::try_from(self.leb128_read()?)
.map_err(|_| Error::msg(Error::msg("length out of u32")))?;
// Push one element to the table to ensure it's a non-primitive type
buf.push(RawValue::U(obj_len));
let mut prev_hash = None;
for _ in 0..obj_len {
let mlen = self.leb128_read()? as usize;
let meth = self.parse_string(mlen)?;
let hash = crate::idl_hash(&meth);
if let Some(prev_hash) = prev_hash {
if prev_hash >= hash {
return Err(Error::msg("method name collision or not sorted"));
}
}
prev_hash = Some(hash);
let ty = self.sleb128_read()?;
validate_type_range(ty, len)?;
// Check for method type
if ty >= 0 {
let idx = ty as usize;
if idx < self.table.len() && self.table[idx][0] != RawValue::I(-22) {
return Err(Error::msg("not a function type"));
} else {
expect_func.insert(idx);
}
} else {
return Err(Error::msg("not a function type"));
}
}
}
Ok(Opcode::Func) => {
let arg_len = self.leb128_read()?;
// Push one element to the table to ensure it's a non-primitive type
buf.push(RawValue::U(arg_len as u32));
for _ in 0..arg_len {
let ty = self.sleb128_read()?;
validate_type_range(ty, len)?;
}
let ret_len = self.leb128_read()?;
for _ in 0..ret_len {
let ty = self.sleb128_read()?;
validate_type_range(ty, len)?;
}
let ann_len = self.leb128_read()?;
for _ in 0..ann_len {
let ann = self.parse_byte()?;
if ann > 2u8 {
return Err(Error::msg("Unknown function annotation"));
}
}
}
_ => {
return Err(Error::msg(format!(
"Unsupported op_code {} in type table",
Expand All @@ -212,6 +274,7 @@ impl<'de> Deserializer<'de> {
let len = self.leb128_read()?;
for _i in 0..len {
let ty = self.sleb128_read()?;
validate_type_range(ty, self.table.len())?;
self.types.push_back(RawValue::I(ty));
}
Ok(())
Expand Down Expand Up @@ -300,22 +363,55 @@ impl<'de> Deserializer<'de> {
tagged.extend_from_slice(&bytes);
visitor.visit_byte_buf(tagged)
}
fn decode_principal(&mut self) -> Result<Vec<u8>> {
let bit = self.parse_byte()?;
if bit != 1u8 {
return Err(Error::msg("Opaque reference not supported"));
}
let len = self.leb128_read()? as usize;
self.parse_bytes(len)
}
fn deserialize_principal<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check_type(Opcode::Principal)?;
let vec = self.decode_principal()?;
let mut tagged = vec![2u8];
tagged.extend_from_slice(&vec);
visitor.visit_byte_buf(tagged)
}
fn deserialize_service<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check_type(Opcode::Service)?;
self.pop_current_type()?;
let vec = self.decode_principal()?;
let mut tagged = vec![4u8];
tagged.extend_from_slice(&vec);
visitor.visit_byte_buf(tagged)
}
fn deserialize_function<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.check_type(Opcode::Func)?;
self.pop_current_type()?;
let bit = self.parse_byte()?;
if bit != 1u8 {
return Err(Error::msg("Opaque reference not supported"));
}
let vec = self.decode_principal()?;
let len = self.leb128_read()? as usize;
let vec = self.parse_bytes(len)?;
let mut tagged = vec![2u8];
let meth = self.parse_bytes(len)?;
let mut tagged = vec![5u8];
// TODO: find a better way
leb128::write::unsigned(&mut tagged, len as u64)?;
tagged.extend_from_slice(&meth);
tagged.extend_from_slice(&vec);
visitor.visit_byte_buf(tagged)
}

fn deserialize_reserved<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
Expand Down Expand Up @@ -383,6 +479,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
Opcode::Record => self.deserialize_struct("_", &[], visitor),
Opcode::Variant => self.deserialize_enum("_", &[], visitor),
Opcode::Principal => self.deserialize_principal(visitor),
Opcode::Service => self.deserialize_service(visitor),
Opcode::Func => self.deserialize_function(visitor),
}
}

Expand Down Expand Up @@ -450,6 +548,8 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_enum(Compound::new(&mut self, Style::Enum { len, fs }))
}
Opcode::Principal => self.deserialize_principal(visitor),
Opcode::Service => self.deserialize_service(visitor),
Opcode::Func => self.deserialize_function(visitor),
}
}

Expand Down
65 changes: 38 additions & 27 deletions rust/candid/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::value::{IDLField, IDLValue, IDLArgs};
use super::typing::{check_unique, TypeEnv};
use super::types::{IDLType, PrimType, TypeField, FuncType, FuncMode, Binding, Dec, IDLProg, IDLTypes};
use super::test::{Assert, Input, Test};
use super::token::{Token, error, error2, LexicalError, Span};
use super::token::{Token, error2, LexicalError, Span};
use crate::{idl_hash, Principal, types::Label};

grammar;
Expand Down Expand Up @@ -40,6 +40,7 @@ extern {
"{" => Token::LBrace,
"}" => Token::RBrace,
"," => Token::Comma,
"." => Token::Dot,
";" => Token::Semi,
":" => Token::Colon,
"->" => Token::Arrow,
Expand All @@ -61,9 +62,10 @@ pub Arg: IDLValue = {
"null" => IDLValue::Null,
"opt" <Arg> => IDLValue::Opt(Box::new(<>)),
"vec" "{" <SepBy<Arg, ";">> "}" => IDLValue::Vec(<>),
"record" "{" <SepBy<RecordField, ";">> "}" =>? {
"record" "{" <Sp<SepBy<RecordField, ";">>> "}" =>? {
let mut id: u32 = 0;
let mut fs: Vec<IDLField> = <>.into_iter().map(|f| {
let span = <>.1.clone();
let mut fs: Vec<IDLField> = <>.0.into_iter().map(|f| {
match f.id {
Label::Unnamed(_) => {
id = id + 1;
Expand All @@ -76,11 +78,16 @@ pub Arg: IDLValue = {
}
}).collect();
fs.sort_unstable_by_key(|IDLField { id, .. }| id.get_id());
check_unique(fs.iter().map(|f| &f.id)).map_err(error)?;
check_unique(fs.iter().map(|f| &f.id)).map_err(|e| error2(e, span))?;
Ok(IDLValue::Record(fs))
},
"variant" "{" <VariantField> "}" => IDLValue::Variant(Box::new(<>), 0),
"principal" <Text> =>? Ok(IDLValue::Principal(Principal::from_text(<>).map_err(error)?)),
"principal" <Sp<Text>> =>? Ok(IDLValue::Principal(Principal::from_text(&<>.0).map_err(|e| error2(e, <>.1))?)),
"service" <Sp<Text>> =>? Ok(IDLValue::Service(Principal::from_text(&<>.0).map_err(|e| error2(e, <>.1))?)),
"func" <id:Sp<Text>> "." <meth:Name> =>? {
let id = Principal::from_text(&id.0).map_err(|e| error2(e, id.1))?;
Ok(IDLValue::Func(id, meth))
},
"(" <AnnVal> ")" => <>,
}

Expand All @@ -105,10 +112,10 @@ Number: String = {

AnnVal: IDLValue = {
<Arg> => <>,
<arg:Arg> ":" <typ:Typ> =>? {
<arg:Sp<Arg>> ":" <typ:Sp<Typ>> =>? {
let env = TypeEnv::new();
let typ = env.ast_to_type(&typ).map_err(error)?;
arg.annotate_type(true, &env, &typ).map_err(error)
let typ = env.ast_to_type(&typ.0).map_err(|e| error2(e, typ.1))?;
arg.0.annotate_type(true, &env, &typ).map_err(|e| error2(e, arg.1))
}
}

Expand All @@ -120,12 +127,13 @@ NumLiteral: IDLValue = {
};
IDLValue::Number(num)
},
<sign:"sign"?> <n:"float"> =>? {
<sign:"sign"?> <n:Sp<"float">> =>? {
let span = n.1.clone();
let num = match sign {
Some('-') => format!("-{}", n),
_ => n,
Some('-') => format!("-{}", n.0),
_ => n.0,
};
let f = num.parse::<f64>().map_err(|_| error("not a float"))?;
let f = num.parse::<f64>().map_err(|_| error2("not a float", span))?;
Ok(IDLValue::Float64(f))
},
}
Expand Down Expand Up @@ -159,23 +167,25 @@ pub Typ: IDLType = {
"opt" <Typ> => IDLType::OptT(Box::new(<>)),
"vec" <Typ> => IDLType::VecT(Box::new(<>)),
"blob" => IDLType::VecT(Box::new(IDLType::PrimT(PrimType::Nat8))),
"record" "{" <SepBy<RecordFieldTyp, ";">> "}" =>? {
"record" "{" <Sp<SepBy<RecordFieldTyp, ";">>> "}" =>? {
let mut id: u32 = 0;
let mut fs: Vec<TypeField> = <>.iter().map(|f| {
let span = <>.1.clone();
let mut fs: Vec<TypeField> = <>.0.iter().map(|f| {
let label = match f.label {
Label::Unnamed(_) => { id = id + 1; Label::Unnamed(id - 1) },
ref l => { id = l.get_id() + 1; l.clone() },
};
TypeField { label, typ: f.typ.clone() }
}).collect();
fs.sort_unstable_by_key(|TypeField { label, .. }| label.get_id());
check_unique(fs.iter().map(|f| &f.label)).map_err(error)?;
check_unique(fs.iter().map(|f| &f.label)).map_err(|e| error2(e, span))?;
Ok(IDLType::RecordT(fs))
},
"variant" "{" <mut fs:SepBy<VariantFieldTyp, ";">> "}" =>? {
fs.sort_unstable_by_key(|TypeField { label, .. }| label.get_id());
check_unique(fs.iter().map(|f| &f.label)).map_err(error)?;
Ok(IDLType::VariantT(fs))
"variant" "{" <mut fs:Sp<SepBy<VariantFieldTyp, ";">>> "}" =>? {
let span = fs.1.clone();
fs.0.sort_unstable_by_key(|TypeField { label, .. }| label.get_id());
check_unique(fs.0.iter().map(|f| &f.label)).map_err(|e| error2(e, span))?;
Ok(IDLType::VariantT(fs.0))
},
"func" <FuncTyp> => IDLType::FuncT(<>),
"service" <ActorTyp> => IDLType::ServT(<>),
Expand Down Expand Up @@ -226,11 +236,12 @@ FuncMode: FuncMode = {
}

ActorTyp: Vec<Binding> = {
"{" <mut fs:SepBy<MethTyp, ";">> "}" =>? {
fs.sort_unstable_by_key(|Binding { id, .. }| idl_hash(id));
let labs: Vec<_> = fs.iter().map(|f| Label::Named(f.id.clone())).collect();
check_unique(labs.iter()).map_err(error)?;
Ok(fs)
"{" <mut fs:Sp<SepBy<MethTyp, ";">>> "}" =>? {
let span = fs.1.clone();
fs.0.sort_unstable_by_key(|Binding { id, .. }| idl_hash(id));
let labs: Vec<_> = fs.0.iter().map(|f| Label::Named(f.id.clone())).collect();
check_unique(labs.iter()).map_err(|e| error2(e, span))?;
Ok(fs.0)
}
}

Expand Down Expand Up @@ -266,9 +277,9 @@ Input: Input = {
Bytes => Input::Blob(<>),
}

Assert: Assert = <id:"id"> <assert:Assertion> =>? {
if id != "assert" {
Err(error("not an assert"))
Assert: Assert = <id:Sp<"id">> <assert:Assertion> =>? {
if id.0 != "assert" {
Err(error2("not an assert", id.1))
} else { Ok(assert) }
};

Expand Down
2 changes: 2 additions & 0 deletions rust/candid/src/parser/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub enum Token {
Semi,
#[token(",")]
Comma,
#[token(".", priority = 10)]
Dot,
#[token(":")]
Colon,
#[token("->")]
Expand Down
Loading