Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python bindings for registering check dialect #2445

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions stablehlo/integrations/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

add_mlir_public_c_api_library(CheckCAPI
PARTIAL_SOURCES_INTENDED
CheckDialect.cpp

LINK_LIBS PUBLIC
CheckOps
)

add_mlir_public_c_api_library(ChloCAPI
PARTIAL_SOURCES_INTENDED
ChloAttributes.cpp
Expand Down
19 changes: 19 additions & 0 deletions stablehlo/integrations/c/CheckDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright 2024 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "stablehlo/integrations/c/CheckDialect.h"

#include "mlir/CAPI/Registration.h"
#include "stablehlo/tests/CheckOps.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Check, check,
mlir::stablehlo::check::CheckDialect)
28 changes: 28 additions & 0 deletions stablehlo/integrations/c/CheckDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2024 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H
#define STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H

#include "mlir-c/RegisterEverything.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Check, check);

#ifdef __cplusplus
}
#endif

#endif // STABLEHLO_INTEGRATIONS_C_CHECK_DIALECT_H
37 changes: 37 additions & 0 deletions stablehlo/integrations/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ include(AddMLIRPython)
# putting .td and .py files under . instead of mlir/python will break things,
# even if the build rules below are adjusted accordingly.

declare_mlir_python_sources(CheckPythonSources)
declare_mlir_python_sources(CheckPythonSources.Dialects
ADD_TO_PARENT CheckPythonSources
)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT CheckPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/CheckOps.td
SOURCES dialects/check.py
DIALECT_NAME check)

declare_mlir_python_sources(ChloPythonSources)
declare_mlir_python_sources(ChloPythonSources.Dialects
ADD_TO_PARENT ChloPythonSources
Expand Down Expand Up @@ -53,6 +65,15 @@ declare_mlir_python_sources(StablehloToSavedModelPythonSources
stablehlo/savedmodel/stablehlo_to_tf_saved_model.py
)

declare_mlir_python_sources(StablehloTestdataGeneratorPythonSources
ADD_TO_PARENT StablehloPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
SOURCES
stablehlo/testdata_generator/testdata_execution_utils.py
stablehlo/testdata_generator/testdata_generator_lib.py
stablehlo/testdata_generator/testdata_processor.py
)

declare_mlir_python_sources(VhloPythonSources)
declare_mlir_python_sources(VhloPythonSources.Dialects
ADD_TO_PARENT VhloPythonSources
Expand All @@ -69,6 +90,18 @@ declare_mlir_dialect_python_bindings(
# Extensions
################################################################################

declare_mlir_python_sources(CheckPythonExtensions)
declare_mlir_python_extension(CheckPythonExtensions.Main
MODULE_NAME _check
ADD_TO_PARENT CheckPythonExtensions
SOURCES
CheckModule.cpp
EMBED_CAPI_LINK_LIBS
CheckCAPI
PRIVATE_LINK_LIBS
LLVMSupport
)

declare_mlir_python_sources(ChloPythonExtensions)
declare_mlir_python_extension(ChloPythonExtensions.Main
MODULE_NAME _chlo
Expand Down Expand Up @@ -127,6 +160,8 @@ add_mlir_python_common_capi_library(StablehloUnifiedPythonCAPI
DECLARED_SOURCES
MLIRPythonSources
MLIRPythonExtension.RegisterEverything
CheckPythonSources
CheckPythonExtensions
ChloPythonSources
ChloPythonExtensions
StablehloPythonSources
Expand All @@ -141,6 +176,8 @@ add_mlir_python_modules(StablehloUnifiedPythonModules
DECLARED_SOURCES
MLIRPythonSources
MLIRPythonExtension.RegisterEverything
CheckPythonSources
CheckPythonExtensions
ChloPythonSources
ChloPythonExtensions
StablehloPythonSources
Expand Down
36 changes: 36 additions & 0 deletions stablehlo/integrations/python/CheckModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright 2024 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "stablehlo/integrations/c/CheckDialect.h"

namespace py = pybind11;

PYBIND11_MODULE(_check, m) {
m.doc() = "check main python extension";

//
// Dialects.
//

m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle dialect = mlirGetDialectHandle__check__();
mlirDialectHandleRegisterDialect(dialect, context);
if (load) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
}
21 changes: 21 additions & 0 deletions stablehlo/integrations/python/mlir/dialects/CheckOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/* Copyright 2024 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS
#define STABLEHLO_INTEGRATIONS_PYTHON_CHECK_OPS

include "stablehlo/tests/CheckOps.td"

#endif
18 changes: 18 additions & 0 deletions stablehlo/integrations/python/mlir/dialects/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=wildcard-import,relative-beyond-top-level,g-import-not-at-top
from ._check_ops_gen import *
from .._mlir_libs._check import *
1 change: 1 addition & 0 deletions stablehlo/integrations/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_custom_target(${test_name}
add_dependencies(check-stablehlo-python ${test_name})
endfunction()

add_stablehlo_python_test(stablehlo-python-check check.py)
add_stablehlo_python_test(stablehlo-python-chlo chlo.py)
add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py)
add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py)
Expand Down
44 changes: 44 additions & 0 deletions stablehlo/integrations/python/tests/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 The StableHLO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CHECK Python APIs."""

# pylint: disable=wildcard-import,undefined-variable

from mlir import ir
from mlir.dialects import check as check_dialect
from mlir.dialects import stablehlo as stablehlo_dialect


def run(f):
with ir.Context() as context:
check.register_dialect(context)
stablehlo_dialect.register_dialect(ctx)
f()
return f

@run
def test_parse():
asm = """
module {
func.func @main() {
%cst = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
%cst_0 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
%0 = stablehlo.add %cst, %cst_0 : tensor<2xf32>
check.expect_eq_const %0, dense<[4.0, 6.0]> : tensor<2xf32>
return
}
}
"""
ir.Module.parse(asm)
Loading