-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merged PR 1589: Add IOBinding interface.
Binding object to bind inputs/outputs. The overall objective is to ensure that the inputs are in the right place before Run() is called. Adding @<Ke Deng> since he is working on async copies to CUDA. Related work items: #465
- Loading branch information
Pranav Sharma
authored and
Pranav Sharma
committed
May 26, 2018
1 parent
768d026
commit 37d10d4
Showing
10 changed files
with
501 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "IOBinding.h" | ||
#include "core/graph/graph.h" // for LotusIR::ProviderType | ||
#include "core/common/logging/logging.h" | ||
|
||
namespace Lotus { | ||
IOBinding::IOBinding(IExecutionProvider* p_exec_provider, | ||
const Logging::Logger* p_logger) | ||
: p_exec_provider_(p_exec_provider), | ||
p_logger_(p_logger) { | ||
} | ||
|
||
/**
 * Binds `ml_value` as the feed for the input called `name`.
 *
 * - Non-tensor values are stored unchanged.
 * - Tensors whose allocator name already equals the name of the execution provider's
 *   allocator are stored unchanged (no copy is triggered).
 * - Any other tensor is copied, via IExecutionProvider::CopyTensor(), into a buffer obtained
 *   from the execution provider's allocator, and the copy is stored.
 *
 * Binding a name that was bound before replaces the earlier value, so the most recent
 * BindInput() call for a given name wins.
 *
 * @param name      name of the input being bound.
 * @param ml_value  value to bind.
 * @return Status::OK() on success, or the failing status returned by CopyTensor(); in the
 *         failure case any previous binding for `name` is left untouched.
 */
Common::Status IOBinding::BindInput(const std::string& name, const MLValue& ml_value) {
  if (!ml_value.IsTensor()) {
    // operator[] + assignment overwrites an existing entry; insert() would silently keep the old one.
    feeds_[name] = ml_value;
    return Status::OK();
  }

  const Tensor& src_tensor = ml_value.Get<Tensor>();
  const AllocatorInfo& src_location = src_tensor.Location();
  AllocatorPtr alloc = p_exec_provider_->GetAllocator();
  const AllocatorInfo& dst_location = alloc->Info();
  if (src_location.name == dst_location.name) {
    // no need to trigger a copy since the tensor is already at the desired location.
    feeds_[name] = ml_value;
    return Status::OK();
  }

  // create tensor at the desired location.
  auto element_type = src_tensor.DataType();
  auto& shape = src_tensor.Shape();
  void* buffer = alloc->Alloc(element_type->Size() * shape.Size());
  LOTUS_ENFORCE(buffer);
  // The allocator is handed to the Tensor together with the buffer it produced.
  std::unique_ptr<Tensor> dst_tensor = std::make_unique<Tensor>(element_type,
                                                                shape,
                                                                buffer,
                                                                dst_location,
                                                                alloc);
  Status st = p_exec_provider_->CopyTensor(src_tensor, *dst_tensor);
  if (!st.IsOK()) {
    return st;
  }

  // Hand the tensor over to an MLValue, registering the Tensor type's delete function with it.
  MLValue dst_mlvalue;
  dst_mlvalue.Init(dst_tensor.release(),
                   DataTypeImpl::GetType<Tensor>(),
                   DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
  feeds_[name] = dst_mlvalue;
  return Status::OK();
}
|
||
// Forwards to the execution provider's Sync() and returns whatever status it reports.
Common::Status IOBinding::SynchronizeInputs() {
  Common::Status sync_status = p_exec_provider_->Sync();
  return sync_status;
}
|
||
// Appends the output name and its MLValue to the two parallel output lists.
// Entries are kept in call order; the same index in both lists refers to the same output.
Common::Status IOBinding::BindOutput(const std::string& name, const MLValue& ml_value) {
  output_names_.emplace_back(name);
  outputs_.emplace_back(ml_value);
  return Common::Status::OK();
}
|
||
const std::vector<std::string>& IOBinding::GetOutputNames() const { | ||
return output_names_; | ||
} | ||
|
||
std::vector<MLValue>& IOBinding::GetOutputs() { | ||
return outputs_; | ||
} | ||
} // namespace Lotus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#pragma once | ||
|
||
#include "core/framework/execution_provider.h" | ||
#include "core/common/status.h" | ||
#include "core/framework/ml_value.h" | ||
#include "core/framework/inference_session.h" | ||
#include "core/common/logging/logging.h" | ||
|
||
namespace Lotus { | ||
/** | ||
* Input/Output binding. | ||
* Usage is as follows: | ||
* | ||
* InferenceSession session; | ||
* session.Load(); | ||
* session.Initialize(); | ||
* ... | ||
* shared_ptr<IOBinding> io_binding; | ||
* session.NewIOBinding("DML", &io_binding); | ||
* io_binding->BindInput(...); | ||
* io_binding->BindInput(...); | ||
* io_binding->SynchronizeInputs(); | ||
* | ||
* io_binding->BindOutput(...); | ||
* io_binding->BindOutput(...); | ||
* | ||
* session.Run(io_binding); | ||
* | ||
* vector<MLValue>& outputs = io_binding->GetOutputs(); | ||
*/ | ||
class IOBinding { | ||
public: | ||
/** | ||
* Call repeatedly to bind as many inputs as required. | ||
* If the input mlvalue is not at the desired location (specified by the execution provider), this will | ||
* copy it to the desired location. This copy may or may not be async. It depends on the exec provider. | ||
* For copying it leverages IExecutionProvider::CopyTensor(). | ||
*/ | ||
Common::Status BindInput(const std::string& name, const MLValue& ml_value); | ||
|
||
/** | ||
* If the BindInput calls are async this function acts as a barrier to ensure all inputs are fully copied | ||
* before you call the Run() method. There is no point calling Run() if you're inputs are not ready at the | ||
* desired location. | ||
* This is a blocking call and is a wrapper over IExecutionProvider::Sync(). | ||
* Call InferenceSession::Run() only after calling this method or else you'll end up wasting cycles inside Run(). | ||
*/ | ||
Common::Status SynchronizeInputs(); | ||
|
||
/** | ||
* This simply provides the names and optionally allocated output containers. | ||
*/ | ||
Common::Status BindOutput(const std::string& name, const MLValue& ml_value); | ||
|
||
/** | ||
* This simply collects the outputs obtained after calling Run() inside the @param outputs. | ||
*/ | ||
const std::vector<std::string>& GetOutputNames() const; | ||
std::vector<MLValue>& GetOutputs(); | ||
|
||
private: | ||
friend InferenceSession; | ||
|
||
IOBinding(IExecutionProvider* p_exec_provider, const Logging::Logger* p_logger); | ||
IExecutionProvider* p_exec_provider_ = nullptr; // owned by session | ||
std::unordered_map<std::string, MLValue> feeds_; | ||
std::vector<std::string> output_names_; | ||
std::vector<MLValue> outputs_; | ||
const Logging::Logger* p_logger_ = nullptr; // owned by session | ||
|
||
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(IOBinding); | ||
}; | ||
} // namespace Lotus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.