Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add comments and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Aug 4, 2018
1 parent 63726f0 commit 953e1bb
Showing 1 changed file with 37 additions and 51 deletions.
88 changes: 37 additions & 51 deletions src/operator/nn/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
return dshape.Size() != 0;
}

// Concat for RNN param deals with the reverse shape inference from output
// for the special case of concatenating RNN parameters.
// The first (and sometimes the second) input may be unknown on the target axis.
// If the two inputs are unknown, they always have the same shape.
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
Expand Down Expand Up @@ -283,6 +287,34 @@ struct ConcatGrad {

DMLC_REGISTER_PARAMETER(ConcatParam);

#define CONCAT_FORWARD_ATTRS \
.set_num_inputs([](const NodeAttrs& attrs) { \
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
return params.num_args; \
}) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ConcatParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
std::vector<std::string> ret; \
for (int i = 0; i < params.num_args; ++i) { \
ret.push_back(std::string("arg") + std::to_string(i)); \
} \
return ret; \
}) \
.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"output"}; \
}) \
.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
.set_attr<std::string>("key_var_num_args", "num_args")


NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
.add_alias("concat")
Expand Down Expand Up @@ -323,37 +355,13 @@ Example::
[ 5., 5., 8., 8.]]
)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
std::vector<std::string> ret;
for (int i = 0; i < params.num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

Expand All @@ -375,39 +383,17 @@ NNVM_REGISTER_OP(_backward_Concat)
#endif
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);


// _rnn_param_concat is a custom concat op with specialized infer_shape,
// which handles the case where the first one or two inputs may have
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
.set_num_inputs([](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
std::vector<std::string> ret;
for (int i = 0; i < params.num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", RNNParamConcatShape)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

Expand Down

0 comments on commit 953e1bb

Please sign in to comment.