diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index cadbf98e896..0d84f3e250c 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -411,6 +411,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { auto fence_async = IrBuilder::create(); registerInsertBefore(expr, fence_async, scope); } + + // An mma operation is added to async mma pipeline. + fill_async_mma_pipeline_ = true; + // async mma pipeline has not been flushed yet. + flush_async_mma_pipeline_ = false; } } else if (ir_utils::isCpAsyncBulkStore(expr)) { // Add a fence before TMA store so that writes in the generic proxy is @@ -420,7 +425,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { registerInsertBefore(expr, fence_async, scope); } - // Insert sync exprs before async ops. For example, insert + // Insert sync exprs after async ops. For example, insert // wgmma.commit_group.sync.aligned // wgmma.wait_group.sync.aligned 0 // before the use of mma results. Note that cp.async is not handled @@ -430,13 +435,34 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { for (auto inp : expr->inputs()) { auto def = inp->definition(); auto async_type = ir_utils::getAsyncOpType(def); + + NVF_ERROR( + !flush_async_mma_pipeline_ || !fill_async_mma_pipeline_, + "The async mma pipeline cannot be filled without encountering ", + "another mma op after it is flushed with a RAW sync."); + + // Detect a expression that consumes async mma operation. + // The async mma pipeline is already flushed and is empty. + // Adding a RAW wgmma.wait_group(0) is not necessary so skip it. + if (async_type == AsyncOpType::WgMma && !fill_async_mma_pipeline_ && + flush_async_mma_pipeline_) { + continue; + } + if (async_type != AsyncOpType::NotAsync && async_type != AsyncOpType::CpAsync) { input_async_ops[async_type].insert(def); + // async mma pipeline is flushed. + flush_async_mma_pipeline_ = true; + // No mma operations are active in the async mma pipeline. + fill_async_mma_pipeline_ = false; } } for (const auto& [async_type, ops] : input_async_ops) { - auto sync_exprs = lower_utils::getSyncExprs(async_type, 0); + auto sync_exprs = lower_utils::getSyncExprs( + async_type, + /*keep_stages=*/0, + /*requires_commit=*/async_type != AsyncOpType::WgMma); for (auto sync_expr : sync_exprs) { insertSyncExpr(ops, expr, sync_expr, nullptr); } @@ -758,7 +784,31 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { NVF_ERROR(sync_before_.empty(), "Didn't place all required syncs."); } + bool checkAsyncMmaPipeline() { + return fill_async_mma_pipeline_ == false; + } + private: + //! fill_async_mma_pipeline_ is true when any mma expression is issued. A RAW + //! sync is required before any consumer operations use the results of mma + //! expression. + //! + //! flush_async_mma_pipeline_ is true when a RAW sync is issued for async mma + //! pipeline. The RAW sync for async wgmma is `wgmma.wait_group(0)`. All prior + //! mma operations are completed after this operation. No additional RAW sync + //! are required for other consumer expressions unless another mma expression + //! occurs in the fusion. + //! + //! fill_async_mma_pipeline_ and flush_async_mma_pipeline_ cannot be true + //! simultaneously. + //! + //! fill_async_mma_pipeline_ is always false at end of `ReadAfterWriteSyncs`. + //! Detect mma op in async mma pipeline that require RAW sync. + bool fill_async_mma_pipeline_ = false; + + //! Only a single wgmma wait_group to flush async mma pipeline. 
+ bool flush_async_mma_pipeline_ = false; + //! Keep track of expressions that must be followed by syncthreads std::deque> sync_before_; @@ -788,6 +838,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { public: static std::vector insert(const std::vector& loop_nests) { ReadAfterWriteSyncs inserter(loop_nests); + NVF_ERROR( + inserter.checkAsyncMmaPipeline(), + "Async mma pipeline should be empty at end of cuda kernel."); return inserter.exprs_; } }; diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index c130dec1c1c..20453025b59 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2022,11 +2022,16 @@ bool allMmaInputsGuardedByMBarrier(const MmaOp* mma) { ir_utils::isCpAsyncBulkLoad(ir_utils::getTv(mma->inB())->definition()); } -std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { +std::vector getSyncExprs( + AsyncOpType async_type, + int64_t keep_stages, + bool requires_commit) { std::vector sync_exprs; sync_exprs.reserve(2); - auto commit = IrBuilder::create(async_type); - sync_exprs.push_back(commit); + if (requires_commit) { + auto commit = IrBuilder::create(async_type); + sync_exprs.push_back(commit); + } auto wait = IrBuilder::create(async_type, keep_stages); sync_exprs.push_back(wait); return sync_exprs; diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 3781cc40359..cefe23de542 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -375,7 +375,8 @@ bool allMmaInputsGuardedByMBarrier(const MmaOp* mma); // wgmma.wait_group.sync.aligned std::vector getSyncExprs( AsyncOpType async_type, - int64_t keep_stages = 0); + int64_t keep_stages = 0, + bool requires_commit = true); } // namespace lower_utils diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 40448bce008..eba71fd6ee9 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -68,10 +68,8 @@ void HostIrExecutor::compile(Fusion* fusion) { } } else { std::vector exprs = fusion->exprs(); - DeviceIdxType my_device_idx = communicator_ ? communicator_->deviceId() : 0; for (Expr* e : exprs) { - std::vector communications = - HostIrLower::lower(cloner.clone(e), my_device_idx); + std::vector communications = HostIrLower::lower(cloner.clone(e)); for (auto* communication : communications) { host_ir_container_->pushBackTopLevelExprs(communication); } diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index ae9d40ef111..9d23e3b0f3b 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -57,10 +57,6 @@ void lowerToScatter( std::vector& comms) { // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - receiver_mesh.rank() == 1, - "Gather only supported on a 1D mesh. Given ", - receiver_mesh); auto root = input_tv->getDeviceMesh().at(0); Team team = receiver_mesh.vector(); if (!receiver_mesh.has(root)) { @@ -74,7 +70,7 @@ void lowerToScatter( Adds zero or multiple Gather communications to the vector 'comms' Note that since the root of a Gather collective is a destination, we possibly -need multiple Gathers if the tensor is replicated in the receiver mesh. +need multiple Gather if the tensor is replicated in the receiver mesh. 
*/ void lowerToGather( TensorView* input_tv, @@ -82,10 +78,6 @@ void lowerToGather( std::vector& comms) { // we create as many 'Gathers' as there are devices in the receiver mesh const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); - NVF_ERROR( - sender_mesh.rank() == 1, - "Currently only lower Gather on a 1D mesh. Given ", - sender_mesh); for (auto root : output_tv->getDeviceMesh().vector()) { Team team = sender_mesh.vector(); if (!sender_mesh.has(root)) { @@ -100,12 +92,10 @@ void lowerToGather( void lowerToAllgather( TensorView* input_tv, TensorView* output_tv, - std::vector& comms, - DeviceIdxType my_device_idx) { - Team team = - input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx); + std::vector& comms) { + const DeviceMesh& mesh = input_tv->getDeviceMesh(); comms.push_back(IrBuilder::create( - CommunicationType::Allgather, output_tv, input_tv, team)); + CommunicationType::Allgather, output_tv, input_tv, mesh.vector())); } // Adds one or zero Broadcast communication to the vector 'comms' @@ -115,8 +105,6 @@ void lowerToBroadcast( DeviceIdxType root, std::vector& comms) { const DeviceMesh& mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - mesh.rank() == 1, "Broadcast only supported a 1D mesh. Given ", mesh); Team team = mesh.vector(); if (!mesh.has(root)) { team.push_back(root); @@ -135,14 +123,6 @@ void lowerToBroadcastOrSendRecv( std::vector& comms) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - sender_mesh.rank() == 1, - "Broadcast only supported a 1D mesh. Given ", - sender_mesh); - NVF_ERROR( - receiver_mesh.rank() == 1, - "Broadcast only supported a 1D mesh. Given ", - receiver_mesh); if (isSharded(input_tv) && sender_mesh.size() > 1) { // if the inputs and ouputs are parallelized, // we create as many Broadcast as that will be handled in parallel @@ -184,14 +164,6 @@ void lowerToReduce( std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); - NVF_ERROR( - sender_mesh.rank() == 1, - "Reduce only supported a 1D mesh. Given ", - sender_mesh); - NVF_ERROR( - receiver_mesh.rank() == 1, - "Reduce only supported a 1D mesh. 
Given ", - receiver_mesh); const auto reduce_op_type = getC10dReduceOpType(op_type); // we create as many Reduces as there are devices in the receiver mesh for (auto root : receiver_mesh.vector()) { @@ -213,15 +185,13 @@ void lowerToAllreduce( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector& comms, - DeviceIdxType my_device_idx) { - Team team = - input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx); + std::vector& comms) { + const DeviceMesh& mesh = input_tv->getDeviceMesh(); comms.push_back(IrBuilder::create( CommunicationType::Allreduce, output_tv, input_tv, - team, + mesh.vector(), /*root=*/-1, getC10dReduceOpType(op_type))); } @@ -230,10 +200,8 @@ void lowerToReduceScatter( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - std::vector& comms, - DeviceIdxType my_device_idx) { - Team team = - input_tv->getDeviceMesh().getSlice(my_device_idx, ParallelType::DIDx); + std::vector& comms) { + const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); // The output tensor is sharded on scattered_axis and needs to be mapped @@ -248,7 +216,7 @@ void lowerToReduceScatter( CommunicationType::ReduceScatter, output_tv, input_tv, - /*team=*/team, + /*team=*/mesh.vector(), /*root=*/-1, getC10dReduceOpType(op_type), scattered_axis)); @@ -265,7 +233,7 @@ void lowerToReduceScatter( sources *) Leverage the topology to ensure that the senders and recerivers are close */ -std::vector HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) { +std::vector HostIrLower::lower(Expr* c) { FusionGuard fg(c->fusion()); if (c->isOneOf()) { @@ -288,7 +256,7 @@ std::vector HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const bool same_mesh = sender_mesh == receiver_mesh; - // Stores whether the I/O has its first axis parallelized on DIDx + // Stores whether the I/O has its first axis parallelized on Didx const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1; const bool is_output_sharded = isSharded(output_tv) && receiver_mesh.size() > 1; @@ -314,11 +282,11 @@ std::vector HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) { NVF_ERROR( same_mesh, "ReduceScatter operation must have the same sender and receiver device mesh. 
" - "Insert a Set operation before or after the reduction to reshard to another device mesh"); - lowerToReduceScatter(input_tv, output_tv, op_type, comms, my_device_idx); + "Insert a Set operation before or after the reduction to reshard ot another device mesh"); + lowerToReduceScatter(input_tv, output_tv, op_type, comms); } else { if (same_mesh) { - lowerToAllreduce(input_tv, output_tv, op_type, comms, my_device_idx); + lowerToAllreduce(input_tv, output_tv, op_type, comms); } else { lowerToReduce(input_tv, output_tv, op_type, comms); } @@ -328,7 +296,7 @@ std::vector HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) { lowerToScatter(input_tv, output_tv, comms); } else if (is_input_sharded && !is_output_sharded) { if (same_mesh) { - lowerToAllgather(input_tv, output_tv, comms, my_device_idx); + lowerToAllgather(input_tv, output_tv, comms); } else { lowerToGather(input_tv, output_tv, comms); } @@ -538,7 +506,7 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( std::unique_ptr HostIrLower::lower( std::unique_ptr fusion, - DeviceIdxType my_device_index) { + int64_t my_device_index) { // Sharding PreSegmenter passes. // Note: passes run before PreSegmenter optimization passes. preseg_passes::OptimizationPass< @@ -596,8 +564,8 @@ std::unique_ptr HostIrLower::lower( NVF_ERROR( group->exprs().size() == 1, "Communication segments must contain only one Expr"); - for (auto* expr : HostIrLower::lower( - ir_cloner.clone(group->exprs().at(0)), my_device_index)) { + for (auto* expr : + HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) { // Allocate the recv buffers of communications if (expr->isA()) { auto* communication = expr->as(); diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h index dc185c9a769..02d120cb734 100644 --- a/csrc/host_ir/lower.h +++ b/csrc/host_ir/lower.h @@ -22,11 +22,11 @@ class HostIrLower { static bool canLower(Expr* expr, bool ignore_inner_resharding = false); // Lower a sharded Expr into a series of Communication. 
- static std::vector lower(Expr* c, DeviceIdxType my_device_index); + static std::vector lower(Expr* c); static std::unique_ptr lower( std::unique_ptr fusion, - DeviceIdxType my_device_index); + int64_t my_device_index); private: static std::vector lowerToCollectiveBasedPipelinedGemmComm(Expr* expr); diff --git a/csrc/multidevice/device_mesh.cpp b/csrc/multidevice/device_mesh.cpp index 2f23bc83b23..8d8f37f60be 100644 --- a/csrc/multidevice/device_mesh.cpp +++ b/csrc/multidevice/device_mesh.cpp @@ -18,29 +18,12 @@ namespace nvfuser { -DeviceMesh::DeviceMesh( - std::vector devices, - std::vector shape) { - if (shape.empty()) { - shape_ = {(int64_t)devices.size()}; - } else { - int64_t num_devices = - std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); - NVF_ERROR( - (int64_t)devices.size() == num_devices, - "Specified a list of device with ", - devices.size(), - " elements ", - " but shape contains ", - num_devices); - shape_ = std::move(shape); - } +DeviceMesh::DeviceMesh(std::vector devices) { setDevices(std::move(devices)); } DeviceMesh::DeviceMesh(std::initializer_list devices) { setDevices(std::vector(devices)); - shape_ = {(int64_t)vector_.size()}; } void DeviceMesh::setDevices(std::vector devices) { @@ -61,42 +44,8 @@ void DeviceMesh::setDevices(std::vector devices) { return DeviceMesh(devices); } -/*static*/ DeviceMesh DeviceMesh::createForShape( - const std::vector& shape) { - int64_t num_devices = - std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); - std::vector devices(num_devices); - std::iota(devices.begin(), devices.end(), 0); - return DeviceMesh(devices, shape); -} - std::ostream& operator<<(std::ostream& out, const DeviceMesh& mesh) { - out << "DeviceMesh"; - int64_t ndevices = std::accumulate( - mesh.shape().begin(), mesh.shape().end(), 1, std::multiplies<>()); - int64_t ndims = mesh.rank(); - std::vector strides = mesh.shape(); - for (auto i = ndims - 2; i >= 0; --i) { - strides[i] *= strides[i + 1]; - } - - for (auto i = 0; i < ndevices; i++) { - for (auto axis = 0; axis < ndims; axis++) { - if (i % strides[axis] == 0) { - out << "{"; - } - } - out << mesh.vector().at(i); - if ((i + 1) % strides[ndims - 1] != 0) { - out << " "; - } - for (auto axis = 0; axis < ndims; axis++) { - if ((i + 1) % strides[axis] == 0) { - out << "}"; - } - } - } - + out << "DeviceMesh{" << mesh.vector() << "}"; return out; } @@ -107,69 +56,4 @@ int64_t DeviceMesh::size(const ParallelType parallel_type) const { return size(); } -std::vector DeviceMesh::getIndices(const DeviceIdxType device) const { - auto global_idx = idxOf(device); - if (global_idx == -1) { - return {}; - } - std::vector indices(shape_.size()); - int64_t accumulated_size = 1; - for (int64_t i = (int64_t)shape_.size() - 1; i >= 0; i--) { - indices[i] = (global_idx / accumulated_size) % shape_[i]; - accumulated_size *= shape_[i]; - } - return indices; -} - -DeviceIdxType DeviceMesh::maxDeviceId() const { - return *std::max_element(vector_.begin(), vector_.end()); -} - -namespace { -int64_t ptypeToAxis(ParallelType ptype, int64_t ndims) { - NVF_ERROR( - isParallelTypeDeviceDim(ptype), - "Attempting to index into DeviceMesh with a non-device parallel type", - ptype); - int64_t offset = - static_cast(ptype) - static_cast(ParallelType::DIDx); - - NVF_ERROR( - offset < ndims, - "DeviceMesh has ", - ndims, - " dimensions, but requesting slice for ", - ptype); - return ndims - 1 - offset; -} -} // namespace - -std::vector DeviceMesh::getSlice( - DeviceIdxType deviceId, - ParallelType ptype) const { - 
int64_t axis = ptypeToAxis(ptype, rank()); - auto indices = getIndices(deviceId); - NVF_ERROR( - !indices.empty(), "Device ", deviceId, " is not in DeviceMesh ", vector_); - - int64_t offset = 0; - int64_t stride = 1; - int64_t accumulated_size = 1; - for (auto i = rank() - 1; i >= 0; i--) { - if (i > axis) { - stride *= shape_[i]; - } - if (i != axis) { - offset += indices[i] * accumulated_size; - } - accumulated_size *= shape_[i]; - } - - std::vector devices(shape_[axis]); - for (auto i : c10::irange(devices.size())) { - devices.at(i) = vector_.at(i * stride + offset); - } - return devices; -} - } // namespace nvfuser diff --git a/csrc/multidevice/device_mesh.h b/csrc/multidevice/device_mesh.h index 38c04164337..25f67aa24ad 100644 --- a/csrc/multidevice/device_mesh.h +++ b/csrc/multidevice/device_mesh.h @@ -17,9 +17,9 @@ namespace nvfuser { -// DeviceMesh represents a set of unique devices arranged as a dense -// n-dimensional tensor. DeviceMesh and device parallel types determine -// how a tensorview is sharded among devices. +// The class DeviceMesh represents a set of (unique) devices on which a Pipeline +// Stage will be executed. For now, we only support flat meshes, but later we +// will add support for n-dimensional meshes. class DeviceMesh final { public: // https://google.github.io/styleguide/cppguide.html#Implicit_Conversions @@ -32,39 +32,23 @@ class DeviceMesh final { // There are no such contention for std::initializer_list so I chose to // allow implicit conversion for that. This allows users to write `DeviceMesh // mesh = {1, 2};`, which is more concise. - // When no shape is specified, a 1D DeviceMesh is created by default. - explicit DeviceMesh( - std::vector devices = {}, - std::vector shape = {}); + explicit DeviceMesh(std::vector devices = {}); DeviceMesh(std::initializer_list devices); DeviceMesh(const DeviceMesh&) = default; DeviceMesh(DeviceMesh&&) = default; DeviceMesh& operator=(const DeviceMesh&) = default; DeviceMesh& operator=(DeviceMesh&&) = default; - // Creates a device mesh of [0 ... num_devices-1]. I didn't make it a + // Creates a device mesh of [0 .. num_devices-1]. I didn't make it a // constructor because single-element initializer lists would be directed to // use that instead of the constructor for vectors. static DeviceMesh createForNumDevices(int64_t num_devices); - // Creates a device mesh with the specified shape with devices numbered - // [0 ... num_devices-1]. - static DeviceMesh createForShape(const std::vector& shape); // Returns the number of devices in the mesh int64_t size() const { return static_cast(vector_.size()); } - // Return the size of an axis in the mesh - int64_t size(int64_t axis) const { - return shape_.at(axis); - } - - // Returns the shape of the device mesh - const std::vector& shape() const { - return shape_; - } - int64_t size(ParallelType parallel_type) const; // Returns a vector containing the device indices of the mesh @@ -77,8 +61,7 @@ class DeviceMesh final { return std::find(vector_.begin(), vector_.end(), device) != vector_.end(); } - // Returns the global index of device in the mesh, or -1 if device is not - // present. + // Returns the index of device in the mesh, or -1 if device is not present. 
int64_t idxOf(const DeviceIdxType device) const { auto it = std::find(vector_.begin(), vector_.end(), device); if (it != vector_.end()) { @@ -87,48 +70,24 @@ class DeviceMesh final { return -1; } - // Returns the indices of a multi-dimensional mesh, or an empty vector - // if device is not present - std::vector getIndices(const DeviceIdxType device) const; - // Returns the device at a particular index in the mesh DeviceIdxType at(int64_t index) const { return vector_.at(index); } - // Returns the rank (number of dimensions) of the mesh. - int64_t rank() const { - return static_cast(shape_.size()); - } - bool operator==(const DeviceMesh& other) const { - return vector_ == other.vector() && shape_ == other.shape(); + return vector_ == other.vector(); } bool operator!=(const DeviceMesh& other) const { - return vector_ != other.vector() || shape_ != other.shape(); + return vector_ != other.vector(); } - // Returns the max device id in the DeviceMesh. - DeviceIdxType maxDeviceId() const; - - // Returns a slice of the DeviceMesh accorinding to the device parallel type - // that contains the device - // Ex: [[0 1 2] - // [3 4 5]] - // getSlice(4, ParallelType::DIDx) = {3, 4, 5} - // getSlice(4, ParallelType::DIDy) = {1, 4} - // TODO: these might be worth caching per TV - std::vector getSlice(DeviceIdxType device, ParallelType ptype) - const; - private: void setDevices(std::vector devices); - // stores the flattened list of device indices + // stores the list of device indices std::vector vector_; - // shape of the device mesh - std::vector shape_; }; std::ostream& operator<<(std::ostream& out, const DeviceMesh& mesh); diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 07f73f6bc01..96ec87bc37f 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -582,7 +582,9 @@ int64_t requestedNumberOfDevices(Fusion* fusion) { DeviceIdxType max_index = 0; for (auto tv : fusion->allTvs()) { if (tv->hasDeviceMesh()) { - max_index = std::max(max_index, tv->getDeviceMesh().maxDeviceId()); + for (auto d_id : tv->getDeviceMesh().vector()) { + max_index = std::max(max_index, d_id); + } } } return static_cast(max_index + 1); @@ -642,4 +644,39 @@ void reorderDIDToFront(TensorView* tv) { tv->reorder(order_map); } +std::unordered_set getTvsWithDifferentSharding( + TensorView* ref, + const std::vector& tvs) { + std::unordered_set ret; + const auto& reference_dom = ref->getLoopDomain(); + FusionGuard fg(ref->fusion()); + auto ca_map = ComputeAtMap(FusionGuard::getCurFusion()); + std::unordered_map concrete_to_reference_map; + for (auto id : reference_dom) { + auto ca_id = + ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE); + concrete_to_reference_map[ca_id] = id; + } + + for (TensorView* tv : tvs) { + if (ref->getDeviceMesh().vector() != tv->getDeviceMesh().vector()) { + ret.insert(tv); + continue; + } + for (auto id : tv->getLoopDomain()) { + auto ca_id = + ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE); + if (concrete_to_reference_map.count(ca_id) > 0) { + auto ref_id = concrete_to_reference_map.at(ca_id); + if ((ref_id->isDeviceDim() || id->isDeviceDim()) && + ref_id->getParallelType() != id->getParallelType()) { + ret.insert(tv); + break; + } + } + } + } + return ret; +} + } // namespace nvfuser diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 45845d83659..050ceb4cc6e 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -42,41 +42,9 @@ int64_t numDeviceDims(const TensorView*); // Returns the subset of 
tvs which elements have the different multi-device // sharding as ref -template std::unordered_set getTvsWithDifferentSharding( TensorView* ref, - TvIterator tvs) { - std::unordered_set ret; - const auto& reference_dom = ref->getLoopDomain(); - FusionGuard fg(ref->fusion()); - auto ca_map = ComputeAtMap(FusionGuard::getCurFusion()); - std::unordered_map concrete_to_reference_map; - for (auto id : reference_dom) { - auto ca_id = - ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE); - concrete_to_reference_map[ca_id] = id; - } - - for (TensorView* tv : tvs) { - if (ref->getDeviceMesh().vector() != tv->getDeviceMesh().vector()) { - ret.insert(tv); - continue; - } - for (auto id : tv->getLoopDomain()) { - auto ca_id = - ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE_RESIZE); - if (concrete_to_reference_map.count(ca_id) > 0) { - auto ref_id = concrete_to_reference_map.at(ca_id); - if ((ref_id->isDeviceDim() || id->isDeviceDim()) && - ref_id->getParallelType() != id->getParallelType()) { - ret.insert(tv); - break; - } - } - } - } - return ret; -} + const std::vector& tvs); // Returns whether an Expr embeds multi-device resharding bool isResharding(const Expr* expr); diff --git a/csrc/type.cpp b/csrc/type.cpp index 6edb2ce6e39..b6714f10e15 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -710,10 +710,6 @@ static const char* parallel_type2string(ParallelType t) { switch (t) { case ParallelType::DIDx: return "deviceIdx.x"; - case ParallelType::DIDy: - return "deviceIdx.y"; - case ParallelType::DIDz: - return "deviceIdx.z"; case ParallelType::BIDz: return "blockIdx.z"; case ParallelType::BIDy: @@ -1556,8 +1552,7 @@ bool isParallelTypeBlockDim(ParallelType ptype) { } bool isParallelTypeDeviceDim(ParallelType ptype) { - return ptype == ParallelType::DIDx || ptype == ParallelType::DIDy || - ptype == ParallelType::DIDz; + return ptype == ParallelType::DIDx; } bool isParallelTypeThread(ParallelType ptype) { diff --git a/csrc/type.h b/csrc/type.h index 780e0b2a019..008c2fff69e 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -703,8 +703,6 @@ enum class TernaryOpType { Clamp, Lerp, Threshold, Where, Philox }; enum class ParallelType { DIDx, - DIDy, - DIDz, BIDz, BIDy, BIDx, @@ -743,10 +741,8 @@ static constexpr std::array kParallelTypeTIDs = { ParallelType::TIDy, ParallelType::TIDz}; -static constexpr std::array kParallelTypeDIDs = { - ParallelType::DIDx, - ParallelType::DIDy, - ParallelType::DIDz}; +static constexpr std::array kParallelTypeDIDs = { + ParallelType::DIDx}; enum class MemoryType { Local, Shared, Global, Tensor }; diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index a08650d712f..0c44b3957a7 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -1472,18 +1472,19 @@ TEST_P(LayerNormSharedMemoryTest, FusionLayerNormSharedMemoryBuffer_CUDA) { constexpr int64_t dim0 = 2048; std::vector input_shape{dim0, hidden_size}; std::vector norm_shape{hidden_size}; - auto input_half = makeContigTensor(2, dtype); - auto weight_half = makeContigTensor(1, dtype); - auto bias_half = makeContigTensor(1, dtype); - fusion.addInput(input_half); - fusion.addInput(weight_half); - fusion.addInput(bias_half); - auto input = castOp(DataType::Float, input_half); - auto weight = castOp(DataType::Float, weight_half); - auto bias = castOp(DataType::Float, bias_half); + + auto input = makeContigTensor(2, dtype); + auto weight = makeContigTensor(1, dtype); + auto bias = makeContigTensor(1, dtype); + fusion.addInput(input); 
+ fusion.addInput(weight); + fusion.addInput(bias); + input = maybeCastOp(DataType::Float, input); + weight = maybeCastOp(DataType::Float, weight); + bias = maybeCastOp(DataType::Float, bias); auto result = layer_norm(input, norm_shape, weight, bias, eps_ptr); - auto result_output = castOp(dtype, result.output); - fusion.addOutput(result_output); + result.output = maybeCastOp(dtype, result.output); + fusion.addOutput(result.output); fusion.addOutput(result.mean); fusion.addOutput(result.invstd); @@ -1534,18 +1535,9 @@ TEST_P(LayerNormSharedMemoryTest, FusionLayerNormSharedMemoryBuffer_CUDA) { auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); auto runtime = executor_cache.getMostRecentKernelRuntime(); if (has_enough_regs_smem) { - // For dtype float, no op scheduler is also used. - if (dtype == DataType::Float) { - EXPECT_THAT( - runtime->fusionSegments()->groups(), - UnorderedElementsAre( - HeuristicIs(SchedulerType::NoOp), - HeuristicIs(SchedulerType::InnerPersistent))); - } else { - EXPECT_THAT( - runtime->fusionSegments()->groups(), - UnorderedElementsAre(HeuristicIs(SchedulerType::InnerPersistent))); - } + EXPECT_THAT( + runtime->fusionSegments()->groups(), + UnorderedElementsAre(HeuristicIs(SchedulerType::InnerPersistent))); Fusion* scheduled_fusion = runtime->executors() .back() ->as() diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp index 211a43e8e59..62d14cf43d5 100644 --- a/tests/cpp/test_sharding.cpp +++ b/tests/cpp/test_sharding.cpp @@ -196,33 +196,6 @@ TEST_P(ShardingTest, ComputeIndex) { testValidate(fusion.get(), outputs, {a_tensor}, __LINE__, __FILE__); } -TEST_F(ShardingTest, MultiDimDeviceMesh) { - DeviceMesh mesh({3, 4, 1, 0, 8, 2}, {2, 3}); - // Shape not consistent with number of devices - EXPECT_ANY_THROW(DeviceMesh({1, 2}, {2, 3})); - // Duplicates in DeviceMesh - EXPECT_ANY_THROW(DeviceMesh({1, 2, 0, 2}, {2, 3})); - - std::vector local_indices_8 = {1, 1}; - std::vector local_indices_1 = {0, 2}; - EXPECT_EQ(mesh.getIndices(8), local_indices_8); - EXPECT_EQ(mesh.getIndices(1), local_indices_1); - - std::vector slice_didx_034 = {3, 4, 1}; - std::vector slice_didy_12 = {1, 2}; - EXPECT_EQ(mesh.getSlice(1, ParallelType::DIDx), slice_didx_034); - EXPECT_EQ(mesh.getSlice(1, ParallelType::DIDy), slice_didy_12); - EXPECT_EQ(mesh.getSlice(2, ParallelType::DIDy), slice_didy_12); - - DeviceMesh mesh3d = DeviceMesh::createForShape({2, 3, 4}); - std::vector slice_didz = {6, 18}; - std::vector slice_didy = {14, 18, 22}; - std::vector slice_didx = {16, 17, 18, 19}; - EXPECT_EQ(mesh3d.getSlice(18, ParallelType::DIDz), slice_didz); - EXPECT_EQ(mesh3d.getSlice(18, ParallelType::DIDy), slice_didy); - EXPECT_EQ(mesh3d.getSlice(18, ParallelType::DIDx), slice_didx); -} - INSTANTIATE_TEST_SUITE_P( , ShardingTest,
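
The insert_syncs.cpp hunks above track the async mma pipeline with two mutually exclusive flags: `fill_async_mma_pipeline_` is set when a wgmma is issued, and `flush_async_mma_pipeline_` is set once a RAW `wgmma.wait_group(0)` has been emitted, so later consumers do not get redundant waits. Below is a minimal, standalone C++ sketch of that state machine; it is illustrative only, not the nvFuser API (`SyncInserter`, `OpKind`, and `handle` are invented names, and plain strings stand in for the `kir` sync expressions).

```cpp
// Minimal sketch (assumed names, not nvFuser code) of the RAW-sync bookkeeping
// added to ReadAfterWriteSyncs in insert_syncs.cpp.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

enum class OpKind { WgMma, ConsumerOfMma, Other };

struct SyncInserter {
  bool fill_async_mma_pipeline_ = false;   // an mma was issued; a RAW sync is pending
  bool flush_async_mma_pipeline_ = false;  // pipeline already flushed by wait_group(0)

  std::vector<std::string> handle(OpKind op) {
    std::vector<std::string> emitted;
    if (op == OpKind::WgMma) {
      // An mma enters the async pipeline; any earlier flush is now stale.
      fill_async_mma_pipeline_ = true;
      flush_async_mma_pipeline_ = false;
    } else if (op == OpKind::ConsumerOfMma) {
      // The two flags are never true at the same time (mirrors the NVF_ERROR).
      assert(!(fill_async_mma_pipeline_ && flush_async_mma_pipeline_));
      if (flush_async_mma_pipeline_) {
        return emitted;  // already flushed: skip the redundant wgmma.wait_group(0)
      }
      // For wgmma the RAW path emits only a wait (requires_commit is false).
      emitted.push_back("wgmma.wait_group.sync.aligned 0;");
      flush_async_mma_pipeline_ = true;
      fill_async_mma_pipeline_ = false;
    }
    return emitted;
  }
};

int main() {
  SyncInserter s;
  for (OpKind op : {OpKind::WgMma, OpKind::ConsumerOfMma, OpKind::ConsumerOfMma,
                    OpKind::WgMma, OpKind::ConsumerOfMma}) {
    for (const auto& line : s.handle(op)) {
      std::cout << line << '\n';  // printed only for the first consumer after each mma
    }
  }
  // At the end of lowering the pipeline must be drained, which is what
  // checkAsyncMmaPipeline() asserts in the diff.
  assert(!s.fill_async_mma_pipeline_);
  return 0;
}
```

The final assertion corresponds to the new `checkAsyncMmaPipeline()` call in `ReadAfterWriteSyncs::insert`: every mma issued during lowering must have been consumed (and therefore flushed) before the kernel ends.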
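The companion change to `lower_utils::getSyncExprs` adds a `requires_commit` parameter so the RAW path can emit a wait without a preceding commit. The sketch below shows the shape of that contract under the same caveat as above: it is a simplified stand-in, with strings in place of the `kir::AsyncCommit` / `kir::AsyncWait` nodes the real helper builds.

```cpp
// Illustrative sketch of the updated getSyncExprs() contract; not the real IR builder.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

enum class AsyncOpType { CpAsyncBulk, WgMma };

std::vector<std::string> getSyncExprs(
    AsyncOpType async_type,
    int64_t keep_stages = 0,
    bool requires_commit = true) {
  std::vector<std::string> sync_exprs;
  if (requires_commit) {
    sync_exprs.push_back("commit_group");
  }
  sync_exprs.push_back("wait_group " + std::to_string(keep_stages));
  (void)async_type;  // the real helper uses this to pick the concrete PTX form
  return sync_exprs;
}

int main() {
  // The caller in insert_syncs.cpp passes
  // requires_commit = (async_type != AsyncOpType::WgMma), so a wgmma RAW sync
  // becomes a single wait_group(0) with no extra commit_group in front of it.
  for (AsyncOpType t : {AsyncOpType::CpAsyncBulk, AsyncOpType::WgMma}) {
    for (const auto& s : getSyncExprs(
             t, /*keep_stages=*/0, /*requires_commit=*/t != AsyncOpType::WgMma)) {
      std::cout << s << '\n';
    }
  }
  return 0;
}
```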