Skip to content

Commit

Permalink
Unsafe fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Apr 9, 2024
1 parent 05eb39f commit 1aa4cc7
Show file tree
Hide file tree
Showing 38 changed files with 870 additions and 560 deletions.
45 changes: 27 additions & 18 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from __future__ import annotations

from math import sqrt
from typing import Callable
from typing import TYPE_CHECKING, Callable

import cloudpickle
import numpy as np

from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.types import Float, Int, Real
from adaptive.utils import (
assign_defaults,
cache_latest,
partial_function_from_dataframe,
)

if TYPE_CHECKING:
from adaptive.types import Float, Int, Real

try:
import pandas
import pandas as pd

with_pandas = True

Expand Down Expand Up @@ -47,6 +49,7 @@ class AverageLearner(BaseLearner):
Points that still have to be evaluated.
npoints : int
Number of evaluated points.
"""

def __init__(
Expand All @@ -57,7 +60,8 @@ def __init__(
min_npoints: int = 2,
) -> None:
if atol is None and rtol is None:
raise Exception("At least one of `atol` and `rtol` should be set.")
msg = "At least one of `atol` and `rtol` should be set."
raise Exception(msg)
if atol is None:
atol = np.inf
if rtol is None:
Expand Down Expand Up @@ -92,7 +96,7 @@ def to_dataframe( # type: ignore[override]
function_prefix: str = "function.",
seed_name: str = "seed",
y_name: str = "y",
) -> pandas.DataFrame:
) -> pd.DataFrame:
"""Return the data as a `pandas.DataFrame`.
Parameters
Expand All @@ -116,10 +120,12 @@ def to_dataframe( # type: ignore[override]
------
ImportError
If `pandas` is not installed.
"""
if not with_pandas:
raise ImportError("pandas is not installed.")
df = pandas.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
msg = "pandas is not installed."
raise ImportError(msg)
df = pd.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
df.attrs["inputs"] = [seed_name]
df.attrs["output"] = y_name
if with_default_function_args:
Expand All @@ -128,12 +134,12 @@ def to_dataframe( # type: ignore[override]

def load_dataframe( # type: ignore[override]
self,
df: pandas.DataFrame,
df: pd.DataFrame,
with_default_function_args: bool = True,
function_prefix: str = "function.",
seed_name: str = "seed",
y_name: str = "y",
):
) -> None:
"""Load data from a `pandas.DataFrame`.
If ``with_default_function_args`` is True, then ``learner.function``'s
Expand All @@ -153,11 +159,14 @@ def load_dataframe( # type: ignore[override]
The ``seed_name`` used in ``to_dataframe``, by default "seed"
y_name : str, optional
The ``y_name`` used in ``to_dataframe``, by default "y"
"""
self.tell_many(df[seed_name].values, df[y_name].values)
if with_default_function_args:
self.function = partial_function_from_dataframe(
self.function, df, function_prefix
self.function,
df,
function_prefix,
)

def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]]:
Expand All @@ -168,7 +177,7 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]
points = list(
set(range(self.n_requested + n))
- set(self.data)
- set(self.pending_points)
- set(self.pending_points),
)[:n]

loss_improvements = [self._loss_improvement(n) / n] * n
Expand Down Expand Up @@ -199,7 +208,8 @@ def mean(self) -> Float:
@property
def std(self) -> Float:
"""The corrected sample standard deviation of the values
in `data`."""
in `data`.
"""
n = self.npoints
if n < self.min_npoints:
return np.inf
Expand All @@ -211,10 +221,7 @@ def std(self) -> Float:

@cache_latest
def loss(self, real: bool = True, *, n=None) -> Float:
if n is None:
n = self.npoints if real else self.n_requested
else:
n = n
n = (self.npoints if real else self.n_requested) if n is None else n
if n < self.min_npoints:
return np.inf
standard_error = self.std / sqrt(n)
Expand All @@ -232,7 +239,7 @@ def _loss_improvement(self, n: int) -> Float:
else:
return np.inf

def remove_unfinished(self):
def remove_unfinished(self) -> None:
"""Remove uncomputed data from the learner."""
self.pending_points = set()

Expand All @@ -242,7 +249,9 @@ def plot(self):
Returns
-------
holoviews.element.Histogram
A histogram of the evaluated data."""
A histogram of the evaluated data.
"""
hv = ensure_holoviews()
vals = [v for v in self.data.values() if v is not None]
if not vals:
Expand Down
Loading

0 comments on commit 1aa4cc7

Please sign in to comment.