Skip to content

Commit

Permalink
Add some LTC changes so tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman committed Jan 24, 2024
1 parent 2036fc4 commit e967a91
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
57 changes: 57 additions & 0 deletions projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,63 @@ std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
return {Shape(at::kBool, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape__make_per_channel_quantized_tensor(const at::Tensor & self,
const at::Tensor & scale,
const at::Tensor & zero_point,
int64_t axis) {
at::ScalarType scalar = self.scalar_type();
if (scalar == at::ScalarType::Byte)
scalar = at::ScalarType::QUInt8;
if (scalar == at::ScalarType::Char)
scalar = at::ScalarType::QInt8;
if (scalar == at::ScalarType::Int)
scalar = at::ScalarType::QInt32;
return {Shape(scalar, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(const at::Tensor & self,
double scale,
int64_t zero_point) {
at::ScalarType scalar = self.scalar_type();
if (scalar == at::ScalarType::Byte)
scalar = at::ScalarType::QUInt8;
if (scalar == at::ScalarType::Char)
scalar = at::ScalarType::QInt8;
if (scalar == at::ScalarType::Int)
scalar = at::ScalarType::QInt32;
return {Shape(scalar, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_dequantize(const at::Tensor & self) {
return {Shape(at::ScalarType::Float, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(const at::Tensor & self,
const at::Tensor & scales,
const at::Tensor & zero_points,
int64_t axis,
at::ScalarType dtype) {
return {Shape(dtype, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_quantize_per_tensor(const at::Tensor & self,
double scale,
int64_t zero_point,
at::ScalarType dtype) {
return {Shape(dtype, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_int_repr(const at::Tensor & self) {
at::ScalarType scalar = self.scalar_type();
if (scalar == at::ScalarType::QUInt8)
scalar = at::ScalarType::Byte;
if (scalar == at::ScalarType::QInt8)
scalar = at::ScalarType::Char;
if (scalar == at::ScalarType::QInt32)
scalar = at::ScalarType::Int;
return {Shape(scalar, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
Expand Down
2 changes: 0 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,4 @@
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseDequantizePerTensorModule_basic"
}

0 comments on commit e967a91

Please sign in to comment.