Skip to content

Commit

Permalink
Upgrade bazel to 8.0.0
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719452855
  • Loading branch information
Grain Team authored and copybara-github committed Jan 29, 2025
1 parent c42ce44 commit e096e19
Show file tree
Hide file tree
Showing 25 changed files with 1,128 additions and 149 deletions.
8 changes: 8 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
load("@rules_python//python:pip.bzl", "compile_pip_requirements")

py_library(
name = "setup",
srcs = ["setup.py"],
srcs_version = "PY3",
)

compile_pip_requirements(
name = "requirements",
requirements_in = "requirements.in",
requirements_txt = "requirements_lock.txt",
)
43 changes: 43 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
module(
name = "grain",
version = "0.2.3",
repo_name = "com_google_grain",
)

http_file = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file")

bazel_dep(name = "bazel_skylib", version = "1.2.1")
bazel_dep(name = "platforms", version = "0.0.5")
bazel_dep(name = "rules_python", version = "0.34.0")
bazel_dep(name = "pybind11_bazel", version = "2.13.6")
bazel_dep(name = "abseil-py", version = "2.1.0")
bazel_dep(name = "abseil-cpp", version = "20240722.0")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")

PYTHON_VERSION = "3.10"

python.toolchain(
configure_coverage_tool = False,
ignore_root_user_error = True,
is_default = True,
python_version = PYTHON_VERSION,
)

pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")

# requirements_lock.txt is generated by
# bazel run //:requirements.update
pip.parse(
hub_name = "pypi",
python_version = PYTHON_VERSION,
requirements_lock = "//:requirements_lock.txt",
)
use_repo(pip, "pypi")

http_file(
name = "pybind11",
build_file = "@pybind11_bazel//:pybind11.BUILD",
sha256 = "201966a61dc826f1b1879a24a3317a1ec9214a918c8eb035be2f30c3e9cfbdcb",
urls = ["https://github.com/pybind/pybind11/archive/refs/tags/v2.12.1.zip"],
)
34 changes: 0 additions & 34 deletions WORKSPACE

This file was deleted.

63 changes: 51 additions & 12 deletions grain/_src/core/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
load("@pypi//:requirements.bzl", "requirement")

package(default_visibility = ["//grain:__subpackages__"])

licenses(["notice"])
Expand All @@ -8,6 +10,7 @@ py_library(
srcs_version = "PY3",
deps = [
":monitoring",
"@abseil-py//absl/flags",
],
)

Expand Down Expand Up @@ -43,26 +46,42 @@ py_test(
name = "parallel_test",
srcs = ["parallel_test.py"],
srcs_version = "PY3",
deps = [":parallel"],
deps = [
":parallel",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
],
)

py_library(
name = "grain_random",
srcs = ["grain_random.py"],
srcs_version = "PY3",
deps = [
"@abseil-py//absl/logging",
"@pypi//jax:pkg",
"@pypi//numpy:pkg",
],
)

py_library(
name = "sharding",
srcs = ["sharding.py"],
srcs_version = "PY3",
deps = [
"@abseil-py//absl/logging",
],
)

py_test(
name = "sharding_test",
srcs = ["sharding_test.py"],
srcs_version = "PY3",
deps = [":sharding"],
deps = [
":sharding",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
],
)

py_library(
Expand All @@ -75,46 +94,66 @@ py_library(
name = "transforms",
srcs = ["transforms.py"],
srcs_version = "PY3",
deps = [
"@pypi//numpy:pkg",
],
)

py_test(
name = "transforms_test",
srcs = ["transforms_test.py"],
srcs_version = "PY3",
deps = [":transforms"],
deps = [
":transforms",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
],
)

py_library(
name = "tree",
name = "treelib",
srcs = [
"tree.py",
"treelib.py",
],
srcs_version = "PY3",
deps = [
"@pypi//dm_tree:pkg",
"@pypi//numpy:pkg",
],
)

py_library(
name = "tree_test_lib",
testonly = 1,
srcs = ["tree_test.py"],
srcs = ["treelib_test.py"],
srcs_version = "PY3",
deps = [":tree"],
deps = [
":treelib",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@pypi//numpy:pkg",
],
)

py_test(
name = "tree_test",
srcs = ["tree_test.py"],
name = "treelib_test",
srcs = ["treelib_test.py"],
srcs_version = "PY3",
deps = [
":tree_test_lib",
],
)

py_test(
name = "tree_jax_test",
srcs = ["tree_jax_test.py"],
name = "treelib_jax_test",
srcs = ["treelib_jax_test.py"],
srcs_version = "PY3",
deps = [
":tree",
":tree_test_lib",
":treelib",
"@abseil-py//absl/testing:absltest",
"@pypi//attrs:pkg",
"@pypi//jax:pkg",
"@pypi//numpy:pkg",
],
)
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testes for tree.py with JAX dependency present."""
"""Testes for treelib.py with JAX dependency present."""

from absl.testing import absltest
import attrs
from grain._src.core import tree
from grain._src.core import tree_test
from grain._src.core import treelib
from grain._src.core import treelib_test
import jax
import numpy as np

Expand All @@ -43,19 +43,19 @@ class MyAttrs:
e: str


class TreeJaxTest(tree_test.TreeTest):
class TreeJaxTest(treelib_test.TreeTest):

def test_map_custom_tree(self):
jax.tree_util.register_pytree_node(
MyTree, lambda t: ((t.a, t.b), None), lambda _, args: MyTree(*args)
)
self.assertEqual(
tree.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3)
treelib.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3)
)

def test_spec_like_with_class(self):
self.assertEqual(
tree.spec_like({"B": 1232.4, "C": MyClass(1)}),
treelib.spec_like({"B": 1232.4, "C": MyClass(1)}),
{
"B": "<class 'float'>[]",
"C": "<class '__main__.MyClass'>[]",
Expand All @@ -64,22 +64,22 @@ def test_spec_like_with_class(self):

def test_spec_like_with_list(self):
self.assertEqual(
tree.spec_like({
treelib.spec_like({
"B": 1232.4,
"C": [
tree_test.TestClass(a=1, b="v2"),
tree_test.TestClass(a=2, b="v2"),
treelib_test.TestClass(a=1, b="v2"),
treelib_test.TestClass(a=2, b="v2"),
],
}),
{
"B": "<class 'float'>[]",
"C": "list<grain._src.core.tree_test.TestClass>[2]",
"C": "list<grain._src.core.treelib_test.TestClass>[2]",
},
)

def test_spec_like_with_unknown_shape(self):
self.assertEqual(
tree.spec_like({
treelib.spec_like({
"B": [np.zeros([2]), np.zeros([1])],
"C": [],
}),
Expand All @@ -88,14 +88,14 @@ def test_spec_like_with_unknown_shape(self):

def test_spec_like_with_dataclass(self):
self.assertEqual(
tree.spec_like(tree_test.TestClass(a=1, b="v2")),
"<class 'grain._src.core.tree_test.TestClass'>\n"
treelib.spec_like(treelib_test.TestClass(a=1, b="v2")),
"<class 'grain._src.core.treelib_test.TestClass'>\n"
"{'a': \"<class 'int'>[]\", 'b': \"<class 'str'>[]\"}[]",
)

def test_spec_like_with_attrs(self):
self.assertEqual(
tree.spec_like(MyAttrs(d=1, e="v2")),
treelib.spec_like(MyAttrs(d=1, e="v2")),
"<class '__main__.MyAttrs'>\n"
"{'d': \"<class 'int'>[]\", 'e': \"<class 'str'>[]\"}[]",
)
Expand Down
Loading

0 comments on commit e096e19

Please sign in to comment.