Skip to content

Commit

Permalink
Introduce DeprecatedTypeProperties class (#17991)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#17991

changes:
-Breaks bc: Tensor::type() now returns DeprecatedTypeProperties& rather than Type&.
-Added DeprecatedTypeProperties, it serves as a temporary replacement for Type as the return value of Tensor::type(). This contributes to making Type just for dispatch purposes so that we can make it dtype agnostic.
-Tensor::dispatch_type() now returns Type& like Tensor::type() used to do.
-Changed callsites of Tensor::type() appropriately.

Reviewed By: ezyang

Differential Revision: D14443117

fbshipit-source-id: 239ccb7a09626279a71d1a37f8f82e7f57bf7d9e
  • Loading branch information
Roy Li authored and facebook-github-bot committed Apr 4, 2019
1 parent 2fa8477 commit 5166578
Show file tree
Hide file tree
Showing 36 changed files with 637 additions and 501 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/DLConvertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ static DLDataType getDLDataType(const Tensor& t) {
return dtype;
}

static DLContext getDLContext(const Type& type, const int64_t& device_id) {
static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
DLContext ctx;
ctx.device_id = device_id;
if (type.is_cuda()) {
if (tensor.is_cuda()) {
ctx.device_type = DLDeviceType::kDLGPU;
} else {
ctx.device_type = DLDeviceType::kDLCPU;
Expand Down Expand Up @@ -161,7 +161,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
if (src.is_cuda()) {
device_id = src.get_device();
}
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id);
atDLMTensor->tensor.dl_tensor.ndim = src.dim();
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape =
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/Dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}

C10_DEPRECATED_MESSAGE("passing at::Type to an AT_DISPATCH macro is deprecated, " \
C10_DEPRECATED_MESSAGE("passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, " \
"pass an at::ScalarType instead")
inline at::ScalarType scalar_type(const at::Type &t) {
inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties &t) {
return t.scalarType();
}

Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/SparseTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons
AT_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
AT_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout ", values.layout());

AT_CHECK(values.type().toSparse() == legacyTensorType(*this), "values type must match sparse tensor type");
AT_CHECK(values.device().type() == device().type(), "device type of values (", values.device().type(), ") must match device type of device().type()", device().type(), ")");
AT_CHECK(values.scalar_type() == typeMetaToScalarType(dtype()), "dtype of values (", values.scalar_type(), ") must match dtype of sparse tensor (", typeMetaToScalarType(dtype()), ")");
AT_CHECK(indices.scalar_type() == kLong, "indices must be an int64 tensor");
AT_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")");
AT_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")");
Expand Down
7 changes: 5 additions & 2 deletions aten/src/ATen/SparseTensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indice
// Take indices and values and makes a (data) copy of them to put into the sparse
// indices/values. This used to be called THSTensor_(_set)
inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
alias_into_sparse(self, self._indices().type().copy(indices, non_blocking), self._values().type().copy(values, non_blocking));
alias_into_sparse(
self,
self._indices().dispatch_type().copy(indices, non_blocking),
self._values().dispatch_type().copy(values, non_blocking));
}

// TODO: put this into the public API
Expand Down Expand Up @@ -82,7 +85,7 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size,
indices_mult_cpu_vec[i] = mult;
mult *= full_size[i];
}
auto indices_mult_cpu = indices.type().cpu()
auto indices_mult_cpu = indices.dispatch_type().cpu()
.tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
// NB: must be blocking because this blob may be freed after this closure,
// and non_blocking copy will see garbage.
Expand Down
67 changes: 67 additions & 0 deletions aten/src/ATen/core/DeprecatedTypeProperties.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#pragma once

#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Layout.h>



namespace at {

// This class specifies a Backend and a ScalarType. Currently, it primarily
// serves as a replacement return value for Tensor::type(). Previously,
// Tensor::type() returned Type&, but we are changing Type to not be
// dtype-specific.
class DeprecatedTypeProperties {
public:
DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
: backend_(backend), scalar_type_(scalar_type) {}

Backend backend() const {
return backend_;
}

bool is_sparse() const {
return layout_from_backend(backend()) == kSparse;
}

DeviceType device_type() const {
return backendToDeviceType(backend_);
}

bool is_cuda() const {
return backendToDeviceType(backend_) == kCUDA;
}

ScalarType scalarType() const {
return scalar_type_;
}

caffe2::TypeMeta typeMeta() const {
return scalarTypeToTypeMeta(scalar_type_);
}

bool is_defined() const {
return backend_ != Backend::Undefined && scalar_type_ != ScalarType::Undefined;
}

bool operator==(const DeprecatedTypeProperties& other) const {
return backend_ == other.backend() && scalar_type_ == other.scalarType();
}

bool operator!=(const DeprecatedTypeProperties& other) const {
return !(*this == other);
}

std::string toString() const {
std::stringstream ss;
ss << at::toString(backend()) << at::toString(scalarType()) << "Type";
return ss.str();
}

private:
Backend backend_;
ScalarType scalar_type_;
};

} // namespace at
12 changes: 12 additions & 0 deletions aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>

namespace at {

// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
DeprecatedTypePropertiesRegistry & globalDeprecatedTypePropertiesRegistry() {
static DeprecatedTypePropertiesRegistry singleton;
return singleton;
}

}
46 changes: 46 additions & 0 deletions aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

// In order to preserve bc, we make DeprecatedTypeProperties instances unique
// just like they are for Type.

#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <ATen/core/DeprecatedTypeProperties.h>

namespace at {

struct CAFFE2_API DeprecatedTypePropertiesDeleter {
void operator()(DeprecatedTypeProperties * ptr) {
delete ptr;
}
};

class CAFFE2_API DeprecatedTypePropertiesRegistry {
public:
using DeprecatedTypePropertiesUniquePtr =
std::unique_ptr<DeprecatedTypeProperties, DeprecatedTypePropertiesDeleter>;

DeprecatedTypePropertiesRegistry() {
for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) {
for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
registry[b][s] = DeprecatedTypePropertiesUniquePtr{
new DeprecatedTypeProperties(static_cast<Backend>(b), static_cast<ScalarType>(s)),
DeprecatedTypePropertiesDeleter()
};
}
}
}

DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) {
return *registry[static_cast<int>(p)][static_cast<int>(s)];
}

private:
DeprecatedTypePropertiesUniquePtr registry
[static_cast<int>(Backend::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)];
};

CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();

} // namespace at
7 changes: 5 additions & 2 deletions aten/src/ATen/core/Formatting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ std::ostream& operator<<(std::ostream & out, const Type& t) {
return out << t.toString();
}

std::ostream& operator<<(std::ostream & out, const DeprecatedTypeProperties& t) {
return out << t.toString();
}

static std::tuple<double, int64_t> __printFormat(std::ostream& stream, const Tensor& self) {
auto size = self.numel();
if(size == 0) {
Expand Down Expand Up @@ -238,8 +242,7 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
stream << "size:\n" << tensor_.sizes() << "\n";
stream << "]";
} else {
Type& cpudouble = tensor_.type().toBackend(Backend::CPU).toScalarType(kDouble);
Tensor tensor = tensor_.toType(cpudouble).contiguous();
Tensor tensor = tensor_.to(kCPU, kDouble).contiguous();
if(tensor.ndimension() == 0) {
stream << defaultfloat << tensor.data<double>()[0] << std::endl;
stream << "[ " << tensor_.toString() << "{} ]";
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/Formatting.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CAFFE2_API std::ostream& operator<<(std::ostream& out, Backend b);
namespace at {

CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
CAFFE2_API std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t);
CAFFE2_API std::ostream& print(
std::ostream& stream,
const Tensor& tensor,
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/core/LegacyTypeDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ namespace at {
/// Previously, in VariableType_*.cpp (generated by gen_variable_type.py), when
/// a function is using the 'use_derived' strategy, we call its implementation
/// on the base non-Variable type (`baseType`), passing unwrapped tensors to the
/// call so that any `.type()` calls in the implementation can treat the passed
/// call so that any `.dispatch_type()` calls in the implementation can treat the passed
/// tensors as non-Variables and won't dispatch back to functions in VariableType.
///
/// However, after the Variable/Tensor merge, there is no concept of unwrapping
/// a tensor anymore, and directly passing variables to the base type calls will
/// cause the `.type()` dispatch in the implementation to treat the tensor as a
/// variable, and any function dispatch based on `.type()` will dispatch back to
/// cause the `.dispatch_type()` dispatch in the implementation to treat the tensor as a
/// variable, and any function dispatch based on `.dispatch_type()` will dispatch back to
/// VariableType, which is not what we want.
///
/// The solution to the above problem is to add `at::NonVariableTypeMode`, which
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/core/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ void Tensor::enforce_invariants() {

void Tensor::print() const {
if (defined()) {
std::cerr << "[" << type().toString() << " " << sizes() << "]" << std::endl;
std::cerr << "[" << dispatch_type().toString() << " " << sizes() << "]" << std::endl;
} else {
std::cerr << "[UndefinedTensor]" << std::endl;
}
}

const char * Tensor::toString() const {
return type().toString();
return dispatch_type().toString();
}

} // namespace at
7 changes: 6 additions & 1 deletion aten/src/ATen/core/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <c10/util/Optional.h>
#include <c10/core/Tensor.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>

namespace c10{
struct TensorOptions;
Expand Down Expand Up @@ -196,7 +197,11 @@ class CAFFE2_API Tensor {
return impl_->itemsize();
}

Type & type() const {
DeprecatedTypeProperties & type() const {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
tensorTypeIdToBackend(type_id()), scalar_type());
}
Type & dispatch_type() const {
return legacyTensorType(*impl_);
}
TensorTypeId type_id() const {
Expand Down
Loading

0 comments on commit 5166578

Please sign in to comment.