-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce DeprecatedTypeProperties class (#17991)
Summary: Pull Request resolved: pytorch/pytorch#17991 changes: -Breaks bc: Tensor::type() now returns DeprecatedTypeProperties& rather than Type&. -Added DeprecatedTypeProperties, it serves as a temporary replacement for Type as the return value of Tensor::type(). This contributes to making Type just for dispatch purposes so that we can make it dtype agnostic. -Tensor::dispatch_type() now returns Type& like Tensor::type() used to do. -Changed callsites of Tensor::type() appropriately. Reviewed By: ezyang Differential Revision: D14443117 fbshipit-source-id: 239ccb7a09626279a71d1a37f8f82e7f57bf7d9e
- Loading branch information
1 parent
2fa8477
commit 5166578
Showing
36 changed files
with
637 additions
and
501 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#pragma once | ||
|
||
#include <c10/core/Backend.h> | ||
#include <c10/core/ScalarType.h> | ||
#include <c10/core/Layout.h> | ||
|
||
|
||
|
||
namespace at { | ||
|
||
// This class specifies a Backend and a ScalarType. Currently, it primarily | ||
// serves as a replacement return value for Tensor::type(). Previously, | ||
// Tensor::type() returned Type&, but we are changing Type to not be | ||
// dtype-specific. | ||
class DeprecatedTypeProperties { | ||
public: | ||
DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) | ||
: backend_(backend), scalar_type_(scalar_type) {} | ||
|
||
Backend backend() const { | ||
return backend_; | ||
} | ||
|
||
bool is_sparse() const { | ||
return layout_from_backend(backend()) == kSparse; | ||
} | ||
|
||
DeviceType device_type() const { | ||
return backendToDeviceType(backend_); | ||
} | ||
|
||
bool is_cuda() const { | ||
return backendToDeviceType(backend_) == kCUDA; | ||
} | ||
|
||
ScalarType scalarType() const { | ||
return scalar_type_; | ||
} | ||
|
||
caffe2::TypeMeta typeMeta() const { | ||
return scalarTypeToTypeMeta(scalar_type_); | ||
} | ||
|
||
bool is_defined() const { | ||
return backend_ != Backend::Undefined && scalar_type_ != ScalarType::Undefined; | ||
} | ||
|
||
bool operator==(const DeprecatedTypeProperties& other) const { | ||
return backend_ == other.backend() && scalar_type_ == other.scalarType(); | ||
} | ||
|
||
bool operator!=(const DeprecatedTypeProperties& other) const { | ||
return !(*this == other); | ||
} | ||
|
||
std::string toString() const { | ||
std::stringstream ss; | ||
ss << at::toString(backend()) << at::toString(scalarType()) << "Type"; | ||
return ss.str(); | ||
} | ||
|
||
private: | ||
Backend backend_; | ||
ScalarType scalar_type_; | ||
}; | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#include <ATen/core/DeprecatedTypePropertiesRegistry.h> | ||
|
||
namespace at { | ||
|
||
// TODO: This could be bad juju if someone calls globalContext() in the | ||
// destructor of an object with static lifetime. | ||
DeprecatedTypePropertiesRegistry & globalDeprecatedTypePropertiesRegistry() { | ||
static DeprecatedTypePropertiesRegistry singleton; | ||
return singleton; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#pragma once | ||
|
||
// In order to preserve bc, we make DeprecatedTypeProperties instances unique | ||
// just like they are for Type. | ||
|
||
#include <c10/core/Backend.h> | ||
#include <c10/core/ScalarType.h> | ||
#include <ATen/core/DeprecatedTypeProperties.h> | ||
|
||
namespace at { | ||
|
||
struct CAFFE2_API DeprecatedTypePropertiesDeleter { | ||
void operator()(DeprecatedTypeProperties * ptr) { | ||
delete ptr; | ||
} | ||
}; | ||
|
||
class CAFFE2_API DeprecatedTypePropertiesRegistry { | ||
public: | ||
using DeprecatedTypePropertiesUniquePtr = | ||
std::unique_ptr<DeprecatedTypeProperties, DeprecatedTypePropertiesDeleter>; | ||
|
||
DeprecatedTypePropertiesRegistry() { | ||
for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) { | ||
for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) { | ||
registry[b][s] = DeprecatedTypePropertiesUniquePtr{ | ||
new DeprecatedTypeProperties(static_cast<Backend>(b), static_cast<ScalarType>(s)), | ||
DeprecatedTypePropertiesDeleter() | ||
}; | ||
} | ||
} | ||
} | ||
|
||
DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) { | ||
return *registry[static_cast<int>(p)][static_cast<int>(s)]; | ||
} | ||
|
||
private: | ||
DeprecatedTypePropertiesUniquePtr registry | ||
[static_cast<int>(Backend::NumOptions)] | ||
[static_cast<int>(ScalarType::NumOptions)]; | ||
}; | ||
|
||
CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry(); | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.