Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix np.float_ type #171

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.5.0
hooks:
- id: ruff
args: ["--fix", "--show-source"]
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=100"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
rev: v1.10.1
hooks:
- id: mypy
args: [--ignore-missing-imports]
Expand Down
4 changes: 2 additions & 2 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __new__(
alpha: float = 0.95,
beta: float = 2.0,
response: str = "constant",
split_prior: Optional[npt.NDArray[np.float_]] = None,
split_prior: Optional[npt.NDArray[np.float64]] = None,
split_rules: Optional[List[SplitRule]] = None,
separate_trees: Optional[bool] = False,
**kwargs,
Expand Down Expand Up @@ -198,7 +198,7 @@ def get_moment(cls, rv, size, *rv_inputs):

def preprocess_xy(
X: TensorLike, Y: TensorLike
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_]]:
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
if isinstance(Y, (Series, DataFrame)):
Y = Y.to_numpy()
if isinstance(X, (Series, DataFrame)):
Expand Down
46 changes: 23 additions & 23 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def normalize(self, particles: List[ParticleTree]) -> float:
return wei / wei.sum()

def resample(
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> List[ParticleTree]:
"""
Use systematic resample for all but the first particle
Expand All @@ -335,7 +335,7 @@ def resample(
return particles

def get_particle_tree(
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
) -> Tuple[ParticleTree, Tree]:
"""
Sample a new particle and associated tree
Expand All @@ -347,7 +347,7 @@ def get_particle_tree(

return new_particle, new_particle.tree

def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[np.int_]:
def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
"""
Systematic resampling.

Expand Down Expand Up @@ -399,7 +399,7 @@ def __init__(self, shape: tuple) -> None:
self.mean = np.zeros(shape) # running mean
self.m_2 = np.zeros(shape) # running second moment

def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
self.count = self.count + 1
self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
return fast_mean(std)
Expand All @@ -408,10 +408,10 @@ def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[
@njit
def _update(
count: int,
mean: npt.NDArray[np.float_],
m_2: npt.NDArray[np.float_],
new_value: npt.NDArray[np.float_],
) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
mean: npt.NDArray[np.float64],
m_2: npt.NDArray[np.float64],
new_value: npt.NDArray[np.float64],
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
delta = new_value - mean
mean += delta / count
delta2 = new_value - mean
Expand All @@ -422,7 +422,7 @@ def _update(


class SampleSplittingVariable:
def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
"""
Sample splitting variables proportional to `alpha_vec`.

Expand Down Expand Up @@ -535,13 +535,13 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d


def draw_leaf_value(
y_mu_pred: npt.NDArray[np.float_],
x_mu: npt.NDArray[np.float_],
y_mu_pred: npt.NDArray[np.float64],
x_mu: npt.NDArray[np.float64],
m: int,
norm: npt.NDArray[np.float_],
norm: npt.NDArray[np.float64],
shape: int,
response: str,
) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
"""Draw Gaussian distributed leaf values."""
linear_params = None
mu_mean = np.empty(shape)
Expand All @@ -559,7 +559,7 @@ def draw_leaf_value(


@njit
def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
"""Use Numba to speed up the computation of the mean."""
if ari.ndim == 1:
count = ari.shape[0]
Expand All @@ -578,11 +578,11 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_

@njit
def fast_linear_fit(
x: npt.NDArray[np.float_],
y: npt.NDArray[np.float_],
x: npt.NDArray[np.float64],
y: npt.NDArray[np.float64],
m: int,
norm: npt.NDArray[np.float_],
) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
norm: npt.NDArray[np.float64],
) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
n = len(x)
y = y / m + np.expand_dims(norm, axis=1)

Expand Down Expand Up @@ -666,17 +666,17 @@ def update(self):

@njit
def inverse_cdf(
single_uniform: npt.NDArray[np.float_], normalized_weights: npt.NDArray[np.float_]
single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
) -> npt.NDArray[np.int_]:
"""
Inverse CDF algorithm for a finite distribution.

Parameters
----------
single_uniform: npt.NDArray[np.float_]
single_uniform: npt.NDArray[np.float64]
Ordered points in [0,1]

normalized_weights: npt.NDArray[np.float_])
normalized_weights: npt.NDArray[np.float64]
Normalized weights

Returns
Expand All @@ -699,7 +699,7 @@ def inverse_cdf(


@njit
def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[np.float_]:
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
"""
Jitter duplicated values.
"""
Expand All @@ -715,7 +715,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[


@njit
def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
"""Check if all values in array are whole numbers"""
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)

Expand Down
40 changes: 20 additions & 20 deletions pymc_bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Node:

Attributes
----------
value : npt.NDArray[np.float_]
value : npt.NDArray[np.float64]
idx_data_points : Optional[npt.NDArray[np.int_]]
idx_split_variable : int
linear_params: Optional[List[float]] = None
Expand All @@ -37,11 +37,11 @@ class Node:

def __init__(
self,
value: npt.NDArray[np.float_] = np.array([-1.0]),
value: npt.NDArray[np.float64] = np.array([-1.0]),
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
) -> None:
self.value = value
self.nvalue = nvalue
Expand All @@ -52,11 +52,11 @@ def __init__(
@classmethod
def new_leaf_node(
cls,
value: npt.NDArray[np.float_],
value: npt.NDArray[np.float64],
nvalue: int = 0,
idx_data_points: Optional[npt.NDArray[np.int_]] = None,
idx_split_variable: int = -1,
linear_params: Optional[List[npt.NDArray[np.float_]]] = None,
linear_params: Optional[List[npt.NDArray[np.float64]]] = None,
) -> "Node":
return cls(
value=value,
Expand Down Expand Up @@ -100,7 +100,7 @@ class Tree:
The dictionary's keys are integers that represent the nodes position.
The dictionary's values are objects of type Node that represent the split and leaf nodes
of the tree itself.
output: Optional[npt.NDArray[np.float_]]
output: Optional[npt.NDArray[np.float64]]
Array of shape number of observations, shape
split_rules : List[SplitRule]
List of SplitRule objects, one per column in input data.
Expand All @@ -121,7 +121,7 @@ class Tree:
def __init__(
self,
tree_structure: Dict[int, Node],
output: npt.NDArray[np.float_],
output: npt.NDArray[np.float64],
split_rules: List[SplitRule],
idx_leaf_nodes: Optional[List[int]] = None,
) -> None:
Expand All @@ -133,7 +133,7 @@ def __init__(
@classmethod
def new_tree(
cls,
leaf_node_value: npt.NDArray[np.float_],
leaf_node_value: npt.NDArray[np.float64],
idx_data_points: Optional[npt.NDArray[np.int_]],
num_observations: int,
shape: int,
Expand Down Expand Up @@ -189,7 +189,7 @@ def grow_leaf_node(
self,
current_node: Node,
selected_predictor: int,
split_value: npt.NDArray[np.float_],
split_value: npt.NDArray[np.float64],
index_leaf_node: int,
) -> None:
current_node.value = split_value
Expand Down Expand Up @@ -221,7 +221,7 @@ def get_split_variables(self) -> Generator[int, None, None]:
if node.is_split_node():
yield node.idx_split_variable

def _predict(self) -> npt.NDArray[np.float_]:
def _predict(self) -> npt.NDArray[np.float64]:
output = self.output

if self.idx_leaf_nodes is not None:
Expand All @@ -232,23 +232,23 @@ def _predict(self) -> npt.NDArray[np.float_]:

def predict(
self,
x: npt.NDArray[np.float_],
x: npt.NDArray[np.float64],
excluded: Optional[List[int]] = None,
shape: int = 1,
) -> npt.NDArray[np.float_]:
) -> npt.NDArray[np.float64]:
"""
Predict output of tree for an (un)observed point x.

Parameters
----------
x : npt.NDArray[np.float_]
x : npt.NDArray[np.float64]
Unobserved point
excluded: Optional[List[int]]
Indexes of the variables to exclude when computing predictions

Returns
-------
npt.NDArray[np.float_]
npt.NDArray[np.float64]
Value of the leaf value where the unobserved point lies.
"""
if excluded is None:
Expand All @@ -258,16 +258,16 @@ def predict(

def _traverse_tree(
self,
X: npt.NDArray[np.float_],
X: npt.NDArray[np.float64],
excluded: Optional[List[int]] = None,
shape: Union[int, Tuple[int, ...]] = 1,
) -> npt.NDArray[np.float_]:
) -> npt.NDArray[np.float64]:
"""
Traverse the tree starting from the root node given an (un)observed point.

Parameters
----------
X : npt.NDArray[np.float_]
X : npt.NDArray[np.float64]
(Un)observed point(s)
node_index : int
Index of the node to start the traversal from
Expand All @@ -278,7 +278,7 @@ def _traverse_tree(

Returns
-------
npt.NDArray[np.float_]
npt.NDArray[np.float64]
Leaf node value or mean of leaf node values
"""

Expand Down Expand Up @@ -327,14 +327,14 @@ def _traverse_tree(
return p_d

def _traverse_leaf_values(
self, leaf_values: List[npt.NDArray[np.float_]], leaf_n_values: List[int], node_index: int
self, leaf_values: List[npt.NDArray[np.float64]], leaf_n_values: List[int], node_index: int
) -> None:
"""
Traverse the tree appending leaf values starting from a particular node.

Parameters
----------
leaf_values : List[npt.NDArray[np.float_]]
leaf_values : List[npt.NDArray[np.float64]]
node_index : int
"""
node = self.get_node(node_index)
Expand Down
Loading
Loading