Skip to content

Commit

Permalink
more rust bindings
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
hypercubestart committed Dec 12, 2020
1 parent f154ffa commit 5a8c94f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
13 changes: 13 additions & 0 deletions rust/tvm/src/ir/relay/attrs/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,16 @@ pub struct AvgPool2DAttrsNode {
pub ceil_mode: bool,
pub count_include_pad: bool
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "UpSamplingAttrs"]
#[type_key = "relay.attrs.UpSamplingAttrs"]
pub struct UpSamplingAttrsNode {
pub base: BaseAttrsNode,
pub scale_h: f64,
pub scale_w: f64,
pub layout: TString,
pub method: TString,
pub align_corners: bool
}
52 changes: 52 additions & 0 deletions rust/tvm/src/ir/relay/attrs/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
*/

use crate::ir::attrs::BaseAttrsNode;
use crate::ir::PrimExpr;
use crate::runtime::array::Array;
use crate::runtime::ObjectRef;
use tvm_macros::Object;

type IndexExpr = PrimExpr;

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ExpandDimsAttrs"]
Expand All @@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode {
pub axis: i32,
pub num_newaxis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ConcatenateAttrs"]
#[type_key = "relay.attrs.ConcatenateAttrs"]
pub struct ConcatenateAttrsNode {
pub base: BaseAttrsNode,
pub axis: i32
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ReshapeAttrs"]
#[type_key = "relay.attrs.ReshapeAttrs"]
pub struct ReshapeAttrsNode {
pub base: BaseAttrsNode,
pub newshape: Array<IndexExpr>,
pub reverse: bool
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "SplitAttrs"]
#[type_key = "relay.attrs.SplitAttrs"]
pub struct SplitAttrsNode {
pub base: BaseAttrsNode,
pub indices_or_sections: ObjectRef,
pub axis: i32
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TransposeAttrs"]
#[type_key = "relay.attrs.TransposeAttrs"]
pub struct TransposeAttrsNode {
pub base: BaseAttrsNode,
pub axes: Array<IndexExpr>
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "SqueezeAttrs"]
#[type_key = "relay.attrs.SqueezeAttrs"]
pub struct SqueezeAttrsNode {
pub base: BaseAttrsNode,
pub axis: Array<IndexExpr>
}

0 comments on commit 5a8c94f

Please sign in to comment.