Skip to content

Commit

Permalink
Fix dtype casting issue with numpy in LlamaTune adapter (#945)
Browse files Browse the repository at this point in the history
# Pull Request

## Description

This PR addresses a breaking change in NumPy's [universal
function](https://numpy.org/doc/stable/reference/ufuncs.html) (ufunc)
behavior in version 2.0, specifically affecting `np.clip` in the
LlamaTune adapter code.

**Key Changes**:
- Resolves an issue where configuration retrieval fails due to type
casting differences
- Explicitly casts `np.clip` output from NumPy data type (e.g., `int64`)
to native Python type (e.g., `int`)

**Problem Context:**
In NumPy 2.0+, when ufuncs receive Python scalars as input, the output
is [no
longer](https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html#numpy.can_cast)
automatically cast to Python scalars. This prevents retrieving
previously suggested configurations stored in `self._suggested_configs`
dict during `self.inverse_transform` method calls.

**Technical Solution:**
Implement explicit type casting to ensure consistent configuration
representation in `self._suggested_configs` across `self.transform` and
`self.inverse_transform` method calls.

- **Closes #935** 

______________________________________________________________________

## Type of Change

- 🛠️ Bug fix

______________________________________________________________________

## Testing

- **Environment**: Ubuntu 22.04 with Python 3.13
- **NumPy Versions Tested**: 1.26.4 and 2.2.2
- **Verification**: Existing test suite passed across both versions

______________________________________________________________________

---------

Co-authored-by: Brian Kroth <bpkroth@users.noreply.github.com>
Co-authored-by: Brian Kroth <bpkroth@microsoft.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jan 30, 2025
1 parent 2295f3a commit e94989e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
11 changes: 11 additions & 0 deletions mlos_core/mlos_core/spaces/adapters/llamatune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<https://www.microsoft.com/en-us/research/publication/llamatune-sample-efficient-dbms-configuration-tuning>`_.
"""
import os
from importlib.metadata import version
from typing import Any
from warnings import warn

Expand All @@ -23,11 +24,14 @@
import numpy.typing as npt
import pandas as pd
from ConfigSpace.hyperparameters import NumericalHyperparameter
from packaging.version import Version
from sklearn.preprocessing import MinMaxScaler

from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter
from mlos_core.util import normalize_config

_NUMPY_VERS = Version(version("numpy"))


class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes
"""Implementation of LlamaTune, a set of parameter space transformation techniques,
Expand Down Expand Up @@ -385,6 +389,13 @@ def _transform(self, configuration: dict) -> dict:

orig_value = param.to_value(value)
orig_value = np.clip(orig_value, param.lower, param.upper)

if _NUMPY_VERS >= Version("2.0"):
# Convert numpy types to native Python types (e.g., np.int64 to int)
# This was performed automatically in NumPy<2.0, but not anymore.
# see, https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html
orig_value = orig_value.item()

else:
raise NotImplementedError(
"Only Categorical, Integer, and Float hyperparameters are currently supported."
Expand Down
5 changes: 3 additions & 2 deletions mlos_viz/mlos_viz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
import pandas
import seaborn as sns
from matplotlib import pyplot as plt
from packaging.version import Version
from pandas.api.types import is_numeric_dtype
from pandas.core.groupby.generic import SeriesGroupBy

from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_viz.util import expand_results_data_args

_SEABORN_VERS = version("seaborn")
_SEABORN_VERS = Version(version("seaborn"))


def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> dict[str, Any]:
Expand All @@ -40,7 +41,7 @@ def ignore_plotter_warnings() -> None:
adding them to the warnings filter.
"""
warnings.filterwarnings("ignore", category=FutureWarning)
if _SEABORN_VERS <= "0.13.1":
if _SEABORN_VERS <= Version("0.13.1"):
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
Expand Down

0 comments on commit e94989e

Please sign in to comment.