Skip to content

Commit

Permalink
Backport PR #361: TST: benchmarks (#362)
Browse files Browse the repository at this point in the history
Co-authored-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and nstarman authored Jan 8, 2025
1 parent 5e85c6a commit b89ded5
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ jobs:
if:
github.event_name == 'workflow_dispatch' || (github.event_name ==
'pull_request' && contains(github.event.pull_request.labels.*.name,
'run-benchmarks'))
'run-benchmarks')) || (github.event_name == 'push' && github.ref ==
'refs/heads/main')
steps:
- uses: actions/checkout@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ UNXT_ENABLE_RUNTIME_TYPECHECKING = "beartype.beartype"
"PLR2004", # Magic value used in comparison
"PYI041", # Use `complex` instead of `int | complex` <- plum is more strict
"RUF022", # `__all__` is not sorted
"TC003", # typing-only-standard-library-import
# "TC003", # typing-only-standard-library-import
"TD002", # Missing author in TODO
"TD003", # Missing issue link on the line following this TODO
]
Expand Down
94 changes: 94 additions & 0 deletions tests/benchmark/test_dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Benchmark tests for quaxified jax."""

import jax
import pytest

import unxt as u

LENGTH = u.dimension("length")


@pytest.fixture
def func_dimension_is_length():
return lambda x: u.dimension(x) == LENGTH


@pytest.fixture
def func_dimension_of_length():
return lambda x: u.dimension_of(x) == LENGTH


#####################################################################
# `dimension`


@pytest.mark.parametrize(
"args",
[
(u.dimension("length"),), # -> Dimension('length')
("length",), # -> Dimension('length')
],
)
@pytest.mark.benchmark(group="dimensions", warmup=False)
def test_dimension(args):
"""Test calling `unxt.dimension`."""
_ = u.dimension(*args)


@pytest.mark.parametrize(
"args",
[
(u.dimension("length"),), # -> Dimension('length')
("length",), # -> Dimension('length')
],
)
@pytest.mark.benchmark(group="dimensions", warmup=True)
def test_dimension_execute(func_dimension_is_length, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_dimension_is_length(*args))


#####################################################################
# `dimension_of`


@pytest.mark.parametrize(
"args",
[
(u.dimension("length"),), # -> Dimension('length')
(u.unit("m"),), # -> Dimension('length')
(u.Quantity(1, "m"),), # -> Dimension('length')
(2,), # -> None
],
)
@pytest.mark.benchmark(group="dimensions", warmup=False)
def test_dimension_of(args):
"""Test calling `unxt.dimension_of`."""
_ = u.dimension_of(*args)


@pytest.mark.parametrize(
"args",
[
(u.Quantity(1, "m"),), # -> Dimension('length')
],
)
@pytest.mark.benchmark(group="dimensions", warmup=False)
def test_dimension_of_jit_compile(func_dimension_of_length, args):
"""Test the speed of jitting."""
_ = jax.jit(func_dimension_of_length).lower(*args).compile()


@pytest.mark.parametrize(
"args",
[
(u.dimension("length"),), # -> Dimension('length')
(u.unit("m"),), # -> Dimension('length')
(u.Quantity(1, "m"),), # -> Dimension('length')
(2,), # -> None
],
)
@pytest.mark.benchmark(group="dimensions", warmup=True)
def test_dimension_of_execute(func_dimension_of_length, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_dimension_of_length(*args))
10 changes: 5 additions & 5 deletions tests/benchmark/test_quaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

import quaxed.numpy as jnp

from unxt import Quantity, dimension_of
import unxt as u

x_nodim = Quantity(jnp.linspace(0, 1, 1000), "")
x_length = Quantity(jnp.linspace(0, 1, 1000), "m")
x_angle = Quantity(jnp.linspace(0, 1, 1000), "rad")
x_nodim = u.Quantity(jnp.linspace(0, 1, 1000), "")
x_length = u.Quantity(jnp.linspace(0, 1, 1000), "m")
x_angle = u.Quantity(jnp.linspace(0, 1, 1000), "rad")


Args: TypeAlias = tuple[Any, ...]
Expand Down Expand Up @@ -44,7 +44,7 @@ def process_pytest_argvalues(
) -> ParameterizationKWArgs:
"""Process the argvalues."""
# Get the ID for each parameterization
get_dims = lambda args: tuple(str(dimension_of(a)) for a in args)
get_dims = lambda args: tuple(str(u.dimension_of(a)) for a in args)
ids: list[str] = []
processed_argvalues: list[tuple[Compiled, Args]] = []

Expand Down
78 changes: 78 additions & 0 deletions tests/benchmark/test_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Benchmark tests for `unxt.units`."""

import jax
import pytest

import unxt as u

METER = u.unit("m")


@pytest.fixture
def func_unit_is_length():
return lambda x: u.unit(x) == METER


@pytest.fixture
def func_unit_of_length():
return lambda x: u.unit_of(x) == METER


#####################################################################
# `unit`


@pytest.mark.parametrize(
"args",
[
(u.unit("meter"),), # -> Unit('meter')
("meter",), # -> Unit('meter')
],
)
@pytest.mark.benchmark(group="units", warmup=False)
def test_unit(args):
"""Test calling `unxt.unit`."""
_ = u.unit(*args)


@pytest.mark.parametrize(
"args",
[
(u.unit("meter"),), # -> Unit('meter')
("meter",), # -> Unit('meter')
],
)
@pytest.mark.benchmark(group="units", warmup=True)
def test_unit_execute(func_unit_is_length, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_unit_is_length(*args))


#####################################################################
# `unit_of`


@pytest.mark.parametrize(
"args",
[
(u.unit("meter"),), # -> Unit('meter')
(u.Quantity(1, "m"),), # -> Unit('meter')
(2,),
],
)
@pytest.mark.benchmark(group="units", warmup=False)
def test_unit_of(args):
"""Test calling `unxt.unit_of`."""
_ = u.unit_of(*args)


@pytest.mark.parametrize(
"args",
[
(u.Quantity(1, "m"),), # -> Unit('meter')
],
)
@pytest.mark.benchmark(group="units", warmup=False)
def test_unit_of_jit_compile(func_unit_of_length, args):
"""Test the speed of jitting a function."""
_ = jax.jit(func_unit_of_length).lower(*args).compile()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b89ded5

Please sign in to comment.