From 5a8c94f8f26a5fc07e00dd34e44e3a9da7f28186 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sat, 12 Dec 2020 12:44:01 -0800 Subject: [PATCH] more rust bindings fix --- rust/tvm/src/ir/relay/attrs/nn.rs | 13 ++++++ rust/tvm/src/ir/relay/attrs/transform.rs | 52 ++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index ff523dcb0302..41e28f2a281f 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -129,3 +129,16 @@ pub struct AvgPool2DAttrsNode { pub ceil_mode: bool, pub count_include_pad: bool } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "UpSamplingAttrs"] +#[type_key = "relay.attrs.UpSamplingAttrs"] +pub struct UpSamplingAttrsNode { + pub base: BaseAttrsNode, + pub scale_h: f64, + pub scale_w: f64, + pub layout: TString, + pub method: TString, + pub align_corners: bool +} diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index c459f96b2d2f..aafd258a4a48 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -18,8 +18,13 @@ */ use crate::ir::attrs::BaseAttrsNode; +use crate::ir::PrimExpr; +use crate::runtime::array::Array; +use crate::runtime::ObjectRef; use tvm_macros::Object; +type IndexExpr = PrimExpr; + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "ExpandDimsAttrs"] @@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode { pub axis: i32, pub num_newaxis: i32, } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ConcatenateAttrs"] +#[type_key = "relay.attrs.ConcatenateAttrs"] +pub struct ConcatenateAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32 +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ReshapeAttrs"] +#[type_key = "relay.attrs.ReshapeAttrs"] +pub struct ReshapeAttrsNode { + pub base: BaseAttrsNode, + pub newshape: Array, + pub reverse: bool +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "SplitAttrs"] +#[type_key = "relay.attrs.SplitAttrs"] +pub struct SplitAttrsNode { + pub base: BaseAttrsNode, + pub indices_or_sections: ObjectRef, + pub axis: i32 +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "TransposeAttrs"] +#[type_key = "relay.attrs.TransposeAttrs"] +pub struct TransposeAttrsNode { + pub base: BaseAttrsNode, + pub axes: Array +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "SqueezeAttrs"] +#[type_key = "relay.attrs.SqueezeAttrs"] +pub struct SqueezeAttrsNode { + pub base: BaseAttrsNode, + pub axis: Array +}