Skip to content

Commit

Permalink
Merged PR 1589: Add IOBinding interface.
Browse files Browse the repository at this point in the history
Binding object to bind inputs/outputs. The overall objective is to ensure that the inputs are in the right place before Run() is called.

Adding @<Ke Deng> since he is working on async copies to CUDA.

Related work items: #465
  • Loading branch information
Pranav Sharma authored and Pranav Sharma committed May 26, 2018
1 parent 768d026 commit 37d10d4
Show file tree
Hide file tree
Showing 10 changed files with 501 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cmake/lotus_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ set(onnx_test_lib
)

if(lotus_USE_CUDA)
set_source_files_properties("${LOTUS_ROOT}/test/onnx/runner.cc"
set_source_files_properties("${LOTUS_ROOT}/test/onnx/runner.cc" "${LOTUS_ROOT}/test/framework/inference_session_test.cc"
PROPERTIES
COMPILE_FLAGS "-DUSE_CUDA"
)
Expand Down
67 changes: 67 additions & 0 deletions lotus/core/framework/IOBinding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "IOBinding.h"
#include "core/graph/graph.h" // for LotusIR::ProviderType
#include "core/common/logging/logging.h"

namespace Lotus {
IOBinding::IOBinding(IExecutionProvider* p_exec_provider,
const Logging::Logger* p_logger)
: p_exec_provider_(p_exec_provider),
p_logger_(p_logger) {
}

/**
 * Binds `ml_value` as the feed for the input called `name`.
 *
 * - Non-tensor values are stored unchanged.
 * - A tensor whose allocator name already equals the name of the execution provider's
 *   allocator is stored unchanged (no copy).
 * - Any other tensor is copied, via IExecutionProvider::CopyTensor(), into a buffer obtained
 *   from the execution provider's allocator, and the copy is stored instead.
 *
 * Binding a name that is already bound replaces the previous value, so a binding object can be
 * refreshed with new inputs between runs.
 *
 * @param name      name of the input to bind.
 * @param ml_value  value to bind (possibly copied, see above).
 * @return Status::OK() on success, otherwise the status returned by CopyTensor(). When the copy
 *         fails nothing is bound and any earlier binding for `name` is left untouched.
 */
Common::Status IOBinding::BindInput(const std::string& name, const MLValue& ml_value) {
  if (!ml_value.IsTensor()) {
    // operator[] + assignment (rather than insert) so that re-binding overwrites a stale entry.
    feeds_[name] = ml_value;
    return Status::OK();
  }

  const Tensor& src_tensor = ml_value.Get<Tensor>();
  const AllocatorInfo& src_location = src_tensor.Location();
  AllocatorPtr alloc = p_exec_provider_->GetAllocator();
  const AllocatorInfo& dst_location = alloc->Info();
  // NOTE(review): locations are compared by allocator name only; confirm that this is enough to
  // tell apart distinct devices served by the same allocator kind.
  if (src_location.name == dst_location.name) {
    // no need to trigger a copy since the tensor is already at the desired location.
    feeds_[name] = ml_value;
    return Status::OK();
  }

  // create tensor at the desired location.
  auto element_type = src_tensor.DataType();
  auto& shape = src_tensor.Shape();
  // NOTE(review): for a tensor with zero elements this requests 0 bytes; if the allocator
  // returns nullptr in that case the enforce below fires - confirm against the allocators in use.
  void* buffer = alloc->Alloc(element_type->Size() * shape.Size());
  LOTUS_ENFORCE(buffer);
  std::unique_ptr<Tensor> dst_tensor = std::make_unique<Tensor>(element_type,
                                                                shape,
                                                                buffer,
                                                                dst_location,
                                                                alloc);
  Status st = p_exec_provider_->CopyTensor(src_tensor, *dst_tensor);
  if (!st.IsOK()) {
    // dst_tensor is destroyed here by the unique_ptr; feeds_ has not been modified.
    return st;
  }

  // Hand the tensor over to an MLValue, which takes the Tensor delete function along with it.
  MLValue dst_mlvalue;
  dst_mlvalue.Init(dst_tensor.release(),
                   DataTypeImpl::GetType<Tensor>(),
                   DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
  feeds_[name] = dst_mlvalue;
  return Status::OK();
}

// Delegates to the execution provider's Sync() and reports its status unchanged.
Common::Status IOBinding::SynchronizeInputs() {
  Common::Status sync_status = p_exec_provider_->Sync();
  return sync_status;
}

// Records an output to fetch: the name and its MLValue are appended, in that order, to the
// parallel output_names_ / outputs_ vectors. Always succeeds.
Common::Status IOBinding::BindOutput(const std::string& name, const MLValue& ml_value) {
  output_names_.emplace_back(name);
  outputs_.emplace_back(ml_value);
  return Status::OK();
}

const std::vector<std::string>& IOBinding::GetOutputNames() const {
return output_names_;
}

std::vector<MLValue>& IOBinding::GetOutputs() {
return outputs_;
}
} // namespace Lotus
73 changes: 73 additions & 0 deletions lotus/core/framework/IOBinding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#pragma once

#include "core/framework/execution_provider.h"
#include "core/common/status.h"
#include "core/framework/ml_value.h"
#include "core/framework/inference_session.h"
#include "core/common/logging/logging.h"

namespace Lotus {
/**
* Input/Output binding.
* Holds the named inputs (feeds), the requested output names and the output values for a call to
* InferenceSession::Run(). Instances are created through InferenceSession::NewIOBinding() only
* (the constructor is private).
*
* Usage is as follows:
*
* InferenceSession session;
* session.Load();
* session.Initialize();
* ...
* unique_ptr<IOBinding> io_binding;
* session.NewIOBinding("DML", &io_binding);
* io_binding->BindInput(...);
* io_binding->BindInput(...);
* io_binding->SynchronizeInputs();
*
* io_binding->BindOutput(...);
* io_binding->BindOutput(...);
*
* session.Run(*io_binding);
*
* vector<MLValue>& outputs = io_binding->GetOutputs();
*/
class IOBinding {
public:
/**
* Call repeatedly to bind as many inputs as required.
* Non-tensor values are stored as they are. If a tensor input is not at the desired location
* (specified by the execution provider's allocator), this will copy it to the desired location.
* This copy may or may not be async. It depends on the exec provider.
* For copying it leverages IExecutionProvider::CopyTensor().
* @param name name of the input being bound.
* @param ml_value the value to bind.
* @return OK on success, otherwise the status reported by CopyTensor().
*/
Common::Status BindInput(const std::string& name, const MLValue& ml_value);

/**
* If the BindInput calls are async this function acts as a barrier to ensure all inputs are fully copied
* before you call the Run() method. There is no point calling Run() if your inputs are not ready at the
* desired location.
* This is a blocking call and is a wrapper over IExecutionProvider::Sync().
* Call InferenceSession::Run() only after calling this method or else you'll end up wasting cycles inside Run().
* @return the status returned by IExecutionProvider::Sync().
*/
Common::Status SynchronizeInputs();

/**
* Registers an output to fetch. This simply records the name together with an MLValue that may
* optionally be a pre-allocated output container. Names and values are kept in the order in
* which they were bound.
*/
Common::Status BindOutput(const std::string& name, const MLValue& ml_value);

/**
* Returns the output names in the order in which they were passed to BindOutput().
*/
const std::vector<std::string>& GetOutputNames() const;

/**
* Returns the output values held by this binding. InferenceSession::Run(IOBinding&) passes this
* same vector to the session as the place to store the fetched results, so read it after Run().
*/
std::vector<MLValue>& GetOutputs();

private:
// InferenceSession constructs bindings and reads feeds_/output_names_/outputs_ directly.
friend InferenceSession;

IOBinding(IExecutionProvider* p_exec_provider, const Logging::Logger* p_logger);
IExecutionProvider* p_exec_provider_ = nullptr; // owned by session
// input name -> value to feed
std::unordered_map<std::string, MLValue> feeds_;
// output_names_[i] is the name belonging to outputs_[i]
std::vector<std::string> output_names_;
std::vector<MLValue> outputs_;
const Logging::Logger* p_logger_ = nullptr; // owned by session

LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(IOBinding);
};
} // namespace Lotus
9 changes: 9 additions & 0 deletions lotus/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ class IExecutionProvider {
// Example valid return values are: kCpuExecutionProvider, kCudaExecutionProvider
virtual std::string Type() const = 0;

/**
 * Waits for the device to finish every task requested so far.
 * The IOBinding object is the main user today: it calls this to make sure all inputs have
 * reached the device before execution starts.
 * The default implementation does nothing and reports success; providers that copy
 * asynchronously are expected to override it.
 */
virtual Status Sync() { return Status::OK(); }

protected:
AllocatorMap allocators_;
};
Expand Down
35 changes: 34 additions & 1 deletion lotus/core/framework/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/framework/op_kernel_abi_wrapper.h"
#include "core/graph/schema_registry.h"
#include "core/framework/IOBinding.h"

namespace Lotus {
class InferenceSession::Impl {
Expand Down Expand Up @@ -296,11 +297,12 @@ class InferenceSession::Impl {
return Run(run_options, feeds, output_names, p_fetches);
}

Common::Status Run(const RunOptions& run_options,
Common::Status Run(const RunOptions& run_options0,
const NameMLValMap& feeds,
const std::vector<std::string>& output_names,
std::vector<MLValue>* p_fetches) {
Common::Status retval;
const RunOptions run_options(run_options0);
try {
{
std::lock_guard<std::mutex> l(session_mutex_);
Expand Down Expand Up @@ -383,6 +385,26 @@ class InferenceSession::Impl {
return std::make_pair(Common::Status::OK(), &output_def_list_);
}

/**
 * Creates an IOBinding bound to the execution provider registered under `provider_type`.
 *
 * @param provider_type  type of an execution provider known to this session's state.
 * @param io_binding     out-parameter that receives the new binding; must not be null.
 * @return OK on success; FAIL if `io_binding` is null or if no execution provider of the
 *         requested type is found. On failure `*io_binding` is not modified.
 */
Common::Status NewIOBinding(LotusIR::ProviderType provider_type, std::unique_ptr<IOBinding>* io_binding) {
  // Guard the out-parameter: dereferencing a null pointer below would be undefined behaviour.
  if (!io_binding) {
    return Status(LOTUS, FAIL, "io_binding output pointer must not be null.");
  }

  IExecutionProvider* p_exec_provider = session_state_.GetExecutionProvider(provider_type);
  if (!p_exec_provider) {
    return Status(LOTUS, FAIL, "You did not register this execution provider before.");
  }

  // private constructor, can't use make_unique
  io_binding->reset(new IOBinding(p_exec_provider, session_logger_));
  return Status::OK();
}

// Runs the session with everything collected in `io_binding`: its feeds are the inputs, its
// output names select what to fetch, and the fetched values are stored in its outputs vector.
Common::Status Run(const RunOptions& run_options, IOBinding& io_binding) {
  // TODO decide whether Run() should invoke io_binding.SynchronizeInputs() itself or leave that
  // to the caller. For now it is NOT called here.
  auto& bound_feeds = io_binding.feeds_;
  auto& requested_names = io_binding.output_names_;
  auto* fetches = &io_binding.outputs_;
  return Run(run_options, bound_feeds, requested_names, fetches);
}

// Convenience overload: same as Run(run_options, io_binding) with default-constructed options.
Common::Status Run(IOBinding& io_binding) {
  RunOptions default_run_options;
  Common::Status run_status = Run(default_run_options, io_binding);
  return run_status;
}

private:
// assumes model has already been loaded before
Common::Status DoPostLoadProcessing(LotusIR::Model& model) {
Expand Down Expand Up @@ -847,4 +869,15 @@ Common::Status InferenceSession::RegisterCustomOpSet(std::vector<OpSchema>& sche
return impl_->RegisterCustomOpSet(schemas, domain, version);
}

// Public entry point; the work is done by the pimpl's NewIOBinding().
Common::Status InferenceSession::NewIOBinding(LotusIR::ProviderType provider_type,
                                              std::unique_ptr<IOBinding>* io_binding) {
  Common::Status status = impl_->NewIOBinding(provider_type, io_binding);
  return status;
}

// Public entry point; forwards the options and the binding to the pimpl's Run().
Common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
  Common::Status status = impl_->Run(run_options, io_binding);
  return status;
}

// Public entry point; forwards the binding to the pimpl's Run(IOBinding&).
Common::Status InferenceSession::Run(IOBinding& io_binding) {
  Common::Status status = impl_->Run(io_binding);
  return status;
}
} // namespace Lotus
11 changes: 11 additions & 0 deletions lotus/core/framework/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class OpSchema;
namespace Lotus {
class IExecutionProvider; // forward decl
class KernelDefBuilder;
class IOBinding;

enum class AllocationPlannerType {
SIMPLE_SEQUENTIAL_PLANNER,
Expand Down Expand Up @@ -205,6 +206,16 @@ class InferenceSession {
const std::vector<std::string>& output_names,
std::vector<MLValue>* p_fetches);

/**
* Creates a new binding object for binding inputs and outputs.
* @param provider_type specifies the location where the inputs need to be potentially copied. See IOBinding class
* for more info.
* @param io_binding receives the newly created binding object.
* @return OK on success; a FAIL status if no execution provider of the given type was registered.
* NOTE(review): the binding keeps raw pointers that are owned by the session, so it presumably
* must not outlive the session - confirm.
*/
Common::Status NewIOBinding(const std::string& provider_type, std::unique_ptr<IOBinding>* io_binding);

/**
* Runs the model using the inputs and output names bound on io_binding; the fetched values are
* stored in the binding and can be read with IOBinding::GetOutputs().
* This does not call IOBinding::SynchronizeInputs(); call it yourself before Run().
*/
Common::Status Run(const RunOptions& run_options, IOBinding& io_binding);
/**
* Same as above, using default-constructed RunOptions.
*/
Common::Status Run(IOBinding& io_binding);

/**
* TEST ONLY: This API exists to facilitate testing only since today the ONNX model
* input/outputs don't have names. Issue: https://github.com/onnx/onnx/issues/679.
Expand Down
5 changes: 5 additions & 0 deletions lotus/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
}
}

// Issues cudaDeviceSynchronize() through CUDA_CALL and maps its boolean outcome to a Status:
// OK when the call succeeded, FAIL ("Sync failed.") otherwise.
Common::Status CUDAExecutionProvider::Sync() {
  const bool synced = CUDA_CALL(cudaDeviceSynchronize());
  if (synced) {
    return Status::OK();
  }
  return Status(LOTUS, FAIL, "Sync failed.");
}

Status CUDAExecutionProvider::CopyTensor(const Tensor& src, Tensor& dst) const {
if (src.Shape().Size() != dst.Shape().Size()) {
return Status(LOTUS, FAIL, "Tensor size mismatch");
Expand Down
2 changes: 2 additions & 0 deletions lotus/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
return LotusIR::kCudaExecutionProvider;
}

// Overrides IExecutionProvider::Sync(): calls cudaDeviceSynchronize() and returns a FAIL status
// ("Sync failed.") if that call does not succeed.
Common::Status Sync() override;

Status CopyTensor(const Tensor& src, Tensor& dst) const override;

virtual const void* GetExecutionHandle() const noexcept override {
Expand Down
Loading

0 comments on commit 37d10d4

Please sign in to comment.