Skip to content

Commit

Permalink
Merge branch 'main' into raghavanr/torch-mlir-upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
navahgar committed Oct 18, 2023
2 parents bb5357f + 52abae1 commit 86cf909
Show file tree
Hide file tree
Showing 54 changed files with 2,559 additions and 818 deletions.
19 changes: 15 additions & 4 deletions .github/workflows/RollPyTorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,21 @@ jobs:
- name: Get torch-mlir
uses: actions/checkout@v3
with:
submodules: 'true'
submodules: 'false'
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

- name: Get LLVM and StableHlo submodules
run: |
set -eo pipefail
cd ${GITHUB_WORKSPACE}
# Fetching the submodules concurrently may cause problems, so we fetch
# them one after another.
rm -f .git/modules/externals/llvm-project/index.lock
rm -f .git/modules/externals/stablehlo/index.lock
git submodule update --init --recursive externals/llvm-project
git submodule update --init --recursive externals/stablehlo
- name: Setup ccache
uses: ./.github/actions/setup-build
with:
Expand Down Expand Up @@ -71,15 +83,14 @@ jobs:
echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV}
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}
- name: Build and test (in-tree), also update ODS and abstract interpretation library
- name: Build and test (out-of-tree), also update ODS and abstract interpretation library
if: env.PT_HASH_CHANGED != '0'
run: |
cd ${GITHUB_WORKSPACE}
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \
TM_PYTHON_VERSIONS="cp311-cp311" \
./build_tools/python_deploy/build_linux_packages.sh
- name: Post issue comment on build failure
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/buildRelease.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ jobs:
cd $GITHUB_WORKSPACE
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TM_TORCH_VERSION="stable" ./build_tools/python_deploy/build_linux_packages.sh
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh
# If we were given a release_id, then upload the package we just built
# to the github releases page.
Expand Down
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,17 @@ We have few paths to lower down to the Torch MLIR Dialect.
- `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel
- Github issues [here](https://github.com/llvm/torch-mlir/issues)
- [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse
- Weekly meetings on Mondays 9AM PST. See [here](https://discourse.llvm.org/t/community-meeting-developer-hour-refactoring-recurring-meetings/62575) for more information.
- Weekly op office hours on Thursdays 8:30-9:30AM PST. See [here](https://discourse.llvm.org/t/announcing-torch-mlir-office-hours/63973/2) for more information.

### Meetings

Community Meeting / Developer Hour:
- 1st and 3rd Monday of the month at 9 am PST
- 2nd and 4th Monday of the month at 5 pm PST

Office Hours:
- Every Thursday at 8:30 am PST

Meeting links can be found [here](https://discourse.llvm.org/t/new-community-meeting-developer-hour-schedule/73868).

## Install torch-mlir snapshot

Expand All @@ -61,7 +70,7 @@ python -m pip install --upgrade pip
Then, we can install torch-mlir with the corresponding torch and torchvision nightlies.
```
pip install --pre torch-mlir torchvision \
-f https://llvm.github.io/torch-mlir/package-index/
-f https://llvm.github.io/torch-mlir/package-index/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
```

Expand Down
3 changes: 2 additions & 1 deletion build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ def gen_fallback_code(*args, **kwargs):
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
tensor_class=self.tensor_class,
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
lazy_ir_generator=GenMlirLazyIr,
)
Expand Down
35 changes: 7 additions & 28 deletions build_tools/autogen_ltc_backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@ blacklist:
# It also doesn't have confusing `unsafe` argument.
- _index_put_impl

# Ops with list of tensors output
- split.Tensor
- split_with_sizes
- unbind.int
- chunk

# Additional ops which autogen is supported for but don't compile yet
- _convolution
- detach
Expand All @@ -18,42 +12,28 @@ blacklist:

# Disabled for consistency with TS backend
- lift_fresh_copy
- new_empty
- rsub
- slice.Tensor # Disabled in favour of slice_copy.Tensor
- zeros
- ones
- arange
- arange.start
- arange.start_step
- fill.Scalar
- scalar_tensor

# Disabled in favour of functionalized alternatives
- _reshape_alias
- expand
- permute
- select.int
- squeeze
- squeeze.dim
- t
- transpose.int
- expand
- squeeze
- unsqueeze
- view
- slice.Tensor
- split.Tensor
- split_with_sizes
- unbind.int

whitelist:
# Enabled for consistency with TS backend
- arange.start_out

# List of supported ops that we don't want to do the full codegen for
supported:
# - bernoulli
# - bernoulli_
- _to_copy
- clone
- empty.memory_format
- empty_strided
- fill_.Scalar
- _unsafe_view
- unbind_copy.int
- split_copy.Tensor
Expand All @@ -80,18 +60,17 @@ supported:
- _trilinear
- linalg_pinv.atol_rtol_tensor
- logsumexp.out
- t

# List of ops that will take in symints for the size instead of ints
symint:
- empty.memory_format
- new_empty_strided
- expand_copy
- narrow_copy
- slice_backward
- slice_copy.Tensor
- split_copy.Tensor
- slice_scatter
- view
- view_copy
- as_strided_copy
- as_strided_scatter
Expand Down
6 changes: 6 additions & 0 deletions build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ function run_in_docker() {
out-of-tree)
setup_venv "$python_version" "$TM_TORCH_VERSION"
build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" "$TM_TORCH_VERSION"
if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then
pushd /main_checkout/torch-mlir
TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_torch_ods.sh
TORCH_MLIR_BUILD_DIR=/main_checkout/torch-mlir/build_oot ./build_tools/update_abstract_interp_lib.sh
popd
fi
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
test_out_of_tree
fi
Expand Down
24 changes: 17 additions & 7 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -288,6 +291,12 @@

# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
"AtenEmbeddingBagStaticModule_basic",

# Lowering not present for this case
"ElementwiseToDtypeI64ToUI8Module_basic",

# torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
"ElementwiseAddScalarInt8Module_basic",
}

if torch_version_for_comparison() < version.parse("2.1.0.dev"):
Expand Down Expand Up @@ -827,7 +836,6 @@
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"RollModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
Expand Down Expand Up @@ -1046,6 +1054,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
Expand Down Expand Up @@ -1175,6 +1185,7 @@
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"FlattenStaticModule_basic",
"UnflattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
Expand Down Expand Up @@ -1383,6 +1394,8 @@
"SoftmaxIntNegDimModule_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseSubTensorInt8Module_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
Expand Down Expand Up @@ -1441,10 +1454,6 @@
"_ConvolutionDeprecated2DBenchmarkModule_basic",
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic",
Expand Down Expand Up @@ -1480,7 +1489,6 @@
"NeFloatIntModule_basic",
"NeIntModule_basic",
"QuantizedMLP_basic",
"RollModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SliceEndSleStartModule_basic",
Expand Down Expand Up @@ -1512,7 +1520,6 @@
"ConvolutionBackwardModule2DPadded_basic",
"VarMeanCorrectionModule_basic",
"VarMeanCorrectionNoneModule_basic",
"PrimsConvertElementTypeModule_basic",
"ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",
Expand Down Expand Up @@ -1547,4 +1554,7 @@
"UniformStaticShapeModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"EmptyStridedModule_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
loc, init,
[&](OpBuilder &b, Location loc, Value elem, Value acc) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value max = b.create<arith::MaxFOp>(loc, x, acc);
Value max = b.create<arith::MaximumFOp>(loc, x, acc);
b.create<scf::ReduceReturnOp>(loc, max);
});
})
Expand Down
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 12654 files
3 changes: 2 additions & 1 deletion include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt);
std::optional<Type> srcOriginalDtype = std::nullopt,
std::optional<Type> dstOriginalDtype = std::nullopt);

Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Expand Down
Loading

0 comments on commit 86cf909

Please sign in to comment.