Skip to content

Commit

Permalink
Add bazel targets for TorchOnnxToTorch conversion passes (#2596)
Browse files Browse the repository at this point in the history
Adapts to the TorchOnnxToTorch changes from
#2585.
Also restores bazel builds in post-merge CI that was disabled in
2148c4c.

Bazel workflow:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7023912962
  • Loading branch information
sjain-stanford authored Nov 28, 2023
1 parent dc9ea08 commit 49fdc1a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/bazelBuildAndTest.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
name: Bazel Build and Test

on:
# TODO: Fix bazel build after US holidays of 2023-Nov-23 and re-enable.
# push:
# branches: [ main ]
push:
branches: [ main ]
workflow_dispatch:

# Ensure that only a single job or workflow using the same
Expand Down
42 changes: 42 additions & 0 deletions utils/bazel/torch-mlir-overlay/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,31 @@ gentbl_cc_library(
],
)

td_library(
name = "TorchMLIRConversionTorchOnnxToTorchPassTdFiles",
srcs = [
"include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td",
],
includes = ["include"],
)

gentbl_cc_library(
name = "TorchMLIRConversionTorchOnnxToTorchPassIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-pass-decls"],
"include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td",
deps = [
":TorchMLIRConversionTorchOnnxToTorchPassTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

# TorchConversion transforms
td_library(
name = "TorchMLIRTorchConversionPassesTdFiles",
Expand Down Expand Up @@ -454,6 +479,22 @@ cc_library(
],
)

cc_library(
name = "TorchMLIRTorchOnnxToTorch",
srcs = glob([
"lib/Conversion/TorchOnnxToTorch/*.h",
"lib/Conversion/TorchOnnxToTorch/*.cpp",
]),
hdrs = glob(["include/torch-mlir/Conversion/TorchOnnxToTorch/*.h"]),
strip_include_prefix = "include",
deps = [
":TorchMLIRConversionTorchOnnxToTorchPassIncGen",
":TorchMLIRTorchDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

cc_library(
name = "TorchMLIRConversionPasses",
srcs = [
Expand All @@ -468,6 +509,7 @@ cc_library(
strip_include_prefix = "include",
deps = [
":TorchMLIRTorchConversionToMLProgram",
":TorchMLIRTorchOnnxToTorch",
":TorchMLIRTorchToArith",
":TorchMLIRTorchToLinalg",
":TorchMLIRTorchToSCF",
Expand Down

0 comments on commit 49fdc1a

Please sign in to comment.