diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 6c2fc0d087ea..e0788386e6ea 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -7,12 +7,37 @@ #include #include #include +#include #include #include "graph_algorithm.h" namespace nnvm { namespace pass { namespace { + using namespace nnvm::top; +// Return bytes of data flag. +static int GetDTypeSize(int type_flag) { + switch (type_flag) { + case kUint8: + case kInt8: + return 1; + case kFloat16: + case kInt16: + case kUint16: + return 2; + case kFloat32: + case kInt32: + case kUint32: + return 4; + case kFloat64: + case kInt64: + case kUint64: + return 8; + default: + LOG(FATAL) << "unknown type_flag=" << type_flag; + return -1; + } +} // simple graph based allocator. class GraphAllocator { @@ -199,7 +224,8 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) && entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && - dtype_vec[eid_out] == dtype_vec[eid_in]) { + (dtype_vec[eid_out] == dtype_vec[eid_in] || + GetDTypeSize(dtype_vec[eid_out]) == GetDTypeSize(dtype_vec[eid_in]))) { // inplace optimization taken[kv.first] = true; storage[eid_out] = sid_in;