Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parallelize SPMD inputhandler and GetDataShards #5447

Merged
merged 2 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ class ComputationClient {
// wrapped inside a vector.
virtual std::vector<DataPtr> GetDataShards(DataPtr data) = 0;

// Returns data shard at a given index.
virtual DataPtr GetDataShard(DataPtr data, size_t index) = 0;

// Returns wrapped data shards as PjRtShardedData.
virtual DataPtr WrapDataShards(const std::vector<DataPtr>& shards,
std::string device, xla::Shape shape,
Expand Down
114 changes: 77 additions & 37 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ ComputationClient::DataPtr PjRtComputationClient::CreateDataPlaceholder(

std::vector<ComputationClient::DataPtr> PjRtComputationClient::GetDataShards(
ComputationClient::DataPtr data) {
tsl::profiler::TraceMe activity("PjRtComputationClient::GetDataShards",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::DataPtr> shards;
if (PjRtShardedData* sharded_data =
dynamic_cast<PjRtShardedData*>(data.get())) {
Expand All @@ -208,6 +210,23 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::GetDataShards(
return shards;
}

ComputationClient::DataPtr PjRtComputationClient::GetDataShard(
ComputationClient::DataPtr data, size_t index) {
tsl::profiler::TraceMe activity("PjRtComputationClient::GetDataShard",
tsl::profiler::TraceMeLevel::kInfo);
if (PjRtShardedData* sharded_data =
dynamic_cast<PjRtShardedData*>(data.get())) {
XLA_CHECK_LE(index, sharded_data->shards.size())
<< "GetDataShard out of range with index: " << index
<< " and num of shard: " << sharded_data->shards.size();
std::shared_ptr<PjRtData> shard = sharded_data->shards[index];
return std::make_shared<PjRtData>(shard->device(), shard->shape(),
shard->buffer);
} else {
return data;
}
}

ComputationClient::DataPtr PjRtComputationClient::WrapDataShards(
const std::vector<DataPtr>& shards, std::string device, xla::Shape shape,
xla::OpSharding sharding) {
Expand Down Expand Up @@ -603,22 +622,32 @@ PjRtComputationClient::ExecuteReplicated(
XLA_CHECK(devices.size() == arguments.size())
<< "ExecuteReplicated over " << devices.size() << " devices, but "
<< arguments.size() << " arguments devices.";

std::vector<std::vector<xla::PjRtBuffer*>> argument_handles;
for (int32_t i = 0; i < devices.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

std::vector<xla::PjRtBuffer*> buffers;
for (auto& argument : arguments[i]) {
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());

XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
<< pjrt_device->DebugString() << " vs "
<< pjrt_data->buffer->device()->DebugString();
buffers.push_back(pjrt_data->buffer.get());
auto mwait_argument = std::make_shared<util::MultiWait>(devices.size());
std::vector<std::vector<xla::PjRtBuffer*>> argument_handles(devices.size());
{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_argument_handle",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = 0; i < devices.size(); ++i) {
auto buffer_converter = [&, i]() {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString();

std::vector<xla::PjRtBuffer*> buffers;
for (auto& argument : arguments[i]) {
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(argument.get());

XLA_CHECK(pjrt_device == pjrt_data->buffer->device())
<< pjrt_device->DebugString() << " vs "
<< pjrt_data->buffer->device()->DebugString();
buffers.push_back(pjrt_data->buffer.get());
}
argument_handles[i] = std::move(buffers);
};
env::ScheduleIoClosure(util::MultiWait::Completer(
mwait_argument, std::move(buffer_converter)));
}
argument_handles.push_back(buffers);
mwait_argument->Wait();
}

xla::ExecuteOptions execute_options;
Expand All @@ -632,34 +661,45 @@ PjRtComputationClient::ExecuteReplicated(

std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures(
devices.size());
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
pjrt_computation.executable
->Execute(argument_handles, execute_options, returned_futures)
.value();
std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results;
{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_execute",
tsl::profiler::TraceMeLevel::kInfo);
results = pjrt_computation.executable
->Execute(std::move(argument_handles), execute_options,
returned_futures)
.value();
}

std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
data_handles.reserve(results.size());
std::vector<size_t> dims(results.size());

for (int32_t i = 0; i < results.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString() << " is not addressable.";

std::vector<ComputationClient::DataPtr> datas;
datas.reserve(results[i].size());
dims[i] = results[i].size();
for (int32_t j = 0; j < results[i].size(); ++j) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
XLA_CHECK(pjrt_device == buffer->device())
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
devices[i], buffer->on_device_shape(), std::move(buffer));
datas.push_back(data);
{
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_result_handle",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = 0; i < results.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString() << " is not addressable.";

std::vector<ComputationClient::DataPtr> datas;
datas.reserve(results[i].size());
dims[i] = results[i].size();
for (int32_t j = 0; j < results[i].size(); ++j) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
XLA_CHECK(pjrt_device == buffer->device())
<< "Exepcted device: " << pjrt_device->DebugString()
<< " vs. actual device: " << buffer->device()->DebugString();

std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
devices[i], buffer->on_device_shape(), std::move(buffer));
datas.push_back(data);
}
data_handles.push_back(datas);
}
data_handles.push_back(datas);
}

auto mwait = std::make_shared<util::MultiWait>(1);
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class PjRtComputationClient : public ComputationClient {

std::vector<DataPtr> GetDataShards(DataPtr data) override;

DataPtr GetDataShard(DataPtr data, size_t index) override;

DataPtr WrapDataShards(const std::vector<DataPtr>& shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) override;

Expand Down
29 changes: 24 additions & 5 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,15 +588,22 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
}

void XLAGraphExecutor::TensorCollectionBarrier(SyncTensorCollection* coll) {
tsl::profiler::TraceMe activity("TensorCollectionBarrier",
tsl::profiler::TraceMeLevel::kInfo);
TF_VLOG(4) << "waiting barrier for device " << coll->device.toString()
<< " start";
torch::lazy::LazyGraphExecutor::TensorCollectionBarrier(coll);
// TODO(yeounoh) lock SPMD device
TF_VLOG(4) << "waiting barrier for device " << coll->device.toString();
TF_VLOG(4) << "waiting barrier for device " << coll->device.toString()
<< " done";
}

std::vector<torch::lazy::BackendDataPtr>
XLAGraphExecutor::ExecuteComputationWithBarrier(
torch::lazy::hash_t hash, const std::vector<at::IValue>& graph_inputs,
const torch::lazy::BackendDevice& device) {
tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier",
tsl::profiler::TraceMeLevel::kInfo);
MaybeDumpGraph("dynamo", hash);
auto cachedComputation =
XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
Expand Down Expand Up @@ -624,7 +631,13 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(

SyncTensorCollection coll;
coll.device = device;
coll.unlocker = DeviceLockerArena::Get()->LockDevices({device});
{
tsl::profiler::TraceMe activity("DeviceBarrier",
tsl::profiler::TraceMeLevel::kInfo);
TF_VLOG(5) << "Lock device " << device.toString() << "...";
coll.unlocker = DeviceLockerArena::Get()->LockDevices({device});
TF_VLOG(5) << "Locking device " << device.toString() << " Done!";
}
std::vector<torch::lazy::BackendDataPtr> arguments;
{
// GetXlaData must be called within a lock region, otherwise it might
Expand Down Expand Up @@ -659,6 +672,8 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(

auto syncfn = [async, hash, sharding_specs]() {
try {
tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier_syncfn",
tsl::profiler::TraceMeLevel::kInfo);
TF_VLOG(3) << "Executing Dynamo IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " ...";
Expand Down Expand Up @@ -695,9 +710,13 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
}

// Updating placeholder with actual output handle.
for (size_t i = 0; i < results.size(); ++i) {
XLA_CHECK(async->tensors_data[i] != nullptr);
async->tensors_data[i]->Assign(*results[i]);
{
tsl::profiler::TraceMe activity("update_placeholder",
tsl::profiler::TraceMeLevel::kInfo);
for (size_t i = 0; i < results.size(); ++i) {
XLA_CHECK(async->tensors_data[i] != nullptr);
async->tensors_data[i]->Assign(*results[i]);
}
}
} catch (...) {
// There are two paths of discovery of an exception happening on an
Expand Down
35 changes: 24 additions & 11 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_util.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/execution_options_util.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/protobuf_util.h"
Expand Down Expand Up @@ -306,6 +309,8 @@ std::vector<std::vector<runtime::ComputationClient::DataPtr>>
ShardingUtil::InputHandler(
std::vector<runtime::ComputationClient::DataPtr> arguments,
std::vector<std::string> devices) {
tsl::profiler::TraceMe activity("InputHandler",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<std::vector<runtime::ComputationClient::DataPtr>>
arguments_by_device(
devices.size(),
Expand All @@ -314,18 +319,24 @@ ShardingUtil::InputHandler(
// the first local index with the first global device ordinal.
auto device_index = build_index_map(devices);

for (int64_t argument_i = 0; argument_i < arguments.size(); ++argument_i) {
auto shards =
runtime::GetComputationClient()->GetDataShards(arguments[argument_i]);
// With SPMD execution, all input is distributed across addressable devices,
// either by sharding or replication.
for (auto shard : shards) {
int global_ordinal = ParseDeviceString(shard->device()).ordinal();
int device_i = device_index[global_ordinal];
arguments_by_device[device_i][argument_i] = shard;
}
}
auto mwait = std::make_shared<runtime::util::MultiWait>(devices.size());

for (int i = 0; i < devices.size(); i++) {
auto argument_setter = [&, i]() {
for (int64_t argument_i = 0; argument_i < arguments.size();
++argument_i) {
runtime::ComputationClient::DataPtr shard =
runtime::GetComputationClient()->GetDataShard(arguments[argument_i],
i);
int global_ordinal = ParseDeviceString(shard->device()).ordinal();
int device_i = device_index[global_ordinal];
arguments_by_device[device_i][argument_i] = shard;
}
};
runtime::env::ScheduleIoClosure(
runtime::util::MultiWait::Completer(mwait, std::move(argument_setter)));
}
mwait->Wait();
return arguments_by_device;
}

Expand All @@ -334,6 +345,8 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
sharded_results,
std::vector<XLATensor::ShardingSpecPtr> sharding_specs,
bool replicated_output) {
tsl::profiler::TraceMe activity("OutputHandler",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<runtime::ComputationClient::DataPtr> outputs;
outputs.reserve(sharding_specs.size());
for (int i = 0; i < sharding_specs.size(); ++i) {
Expand Down