Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: arrow join methods #558

Merged
merged 5 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import get_pyarrow_parquet
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -208,7 +209,7 @@ def join(
self,
other: Self,
*,
how: Literal["inner"] = "inner",
how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner",
left_on: str | list[str] | None,
right_on: str | list[str] | None,
) -> Self:
Expand All @@ -217,24 +218,35 @@ def join(
if isinstance(right_on, str):
right_on = [right_on]

if how == "cross": # type: ignore[comparison-overlap]
raise NotImplementedError

if how == "anti": # type: ignore[comparison-overlap]
raise NotImplementedError
how_to_join_map = {
"anti": "left anti",
"semi": "left semi",
"inner": "inner",
"left": "left outer",
}

if how == "semi": # type: ignore[comparison-overlap]
raise NotImplementedError
if how == "cross":
plx = self.__narwhals_namespace__()
key_token = generate_unique_token(
n_bytes=8, columns=[*self.columns, *other.columns]
)

if how == "left": # type: ignore[comparison-overlap]
raise NotImplementedError
return self._from_native_dataframe(
self.with_columns(**{key_token: plx.lit(0, None)})._native_dataframe.join(
other.with_columns(**{key_token: plx.lit(0, None)})._native_dataframe,
Comment on lines +235 to +236
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

truly wild

and i love it

keys=key_token,
right_keys=key_token,
join_type="inner",
right_suffix="_right",
),
).drop(key_token)

return self._from_native_dataframe(
self._native_dataframe.join(
other._native_dataframe,
keys=left_on,
right_keys=right_on,
join_type=how,
join_type=how_to_join_map[how],
right_suffix="_right",
),
)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def sample(
self, "sample", n=n, fraction=fraction, with_replacement=with_replacement
)

def fill_null(self: Self, value: Any) -> Self:
    """Fill null values in each series produced by this expression with `value`.

    Delegates to the series-level `fill_null` implementation via
    `reuse_series_implementation`, so the expression stays lazy.
    """
    return reuse_series_implementation(self, "fill_null", value=value)

@property
def dt(self) -> ArrowExprDateTimeNamespace:
return ArrowExprDateTimeNamespace(self)
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ def sample(
mask = np.random.choice(idx, size=n, replace=with_replacement)
return self._from_native_series(pc.take(ser, mask))

def fill_null(self: Self, value: Any) -> Self:
    """Return a new series where null entries are replaced by `value`.

    The fill value is wrapped in a pyarrow scalar of the column's own
    type so pyarrow does not have to infer (and possibly mismatch) the
    dtype of the replacement.
    """
    pa = get_pyarrow()
    pc = get_pyarrow_compute()
    native = self._native_series
    fill_scalar = pa.scalar(value, native.type)
    filled = pc.fill_null(native, fill_scalar)
    return self._from_native_series(filled)

@property
def shape(self) -> tuple[int]:
return (len(self._native_series),)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.utils import create_native_series
from narwhals._pandas_like.utils import generate_unique_token
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_dataframe_comparand
Expand All @@ -23,6 +22,7 @@
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
29 changes: 0 additions & 29 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import secrets
from enum import Enum
from enum import auto
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -458,31 +457,3 @@ def int_dtype_mapper(dtype: Any) -> str:
if str(dtype).lower() != str(dtype): # pragma: no cover
return "Int64"
return "int64"


def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: no cover
"""Generates a unique token of specified n_bytes that is not present in the given list of columns.

Arguments:
n_bytes : The number of bytes to generate for the token.
columns : The list of columns to check for uniqueness.

Returns:
A unique token that is not present in the given list of columns.

Raises:
AssertionError: If a unique token cannot be generated after 100 attempts.
"""
counter = 0
while True:
token = secrets.token_hex(n_bytes)
if token not in columns:
return token

counter += 1
if counter > 100:
msg = (
"Internal Error: Narwhals was not able to generate a column name to perform cross "
"join operation"
)
raise AssertionError(msg)
29 changes: 29 additions & 0 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
import secrets
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
Expand Down Expand Up @@ -325,3 +326,31 @@ def is_ordered_categorical(series: Series) -> bool:
return native_series.type.ordered # type: ignore[no-any-return]
# If it doesn't match any of the above, let's just play it safe and return False.
return False # pragma: no cover


def generate_unique_token(n_bytes: int, columns: list[str]) -> str:  # pragma: no cover
    """Return a random hex token of `n_bytes` bytes that is absent from `columns`.

    Used to name a temporary join-key column that must not collide with
    any existing column of either dataframe.

    Arguments:
        n_bytes : Number of random bytes used to build the token.
        columns : Existing column names the token must not collide with.

    Returns:
        A hex string not present in `columns`.

    Raises:
        AssertionError: If no unique token was found within 101 attempts
            (vanishingly unlikely for any realistic `n_bytes`).
    """
    # 101 attempts matches the original counter logic (raise once counter > 100).
    for _ in range(101):
        candidate = secrets.token_hex(n_bytes)
        if candidate not in columns:
            return candidate
    msg = (
        "Internal Error: Narwhals was not able to generate a column name to perform given "
        "join operation"
    )
    raise AssertionError(msg)
7 changes: 1 addition & 6 deletions tests/expr_and_series/fill_null_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

Expand All @@ -12,10 +10,7 @@
}


def test_fill_null(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_fill_null(constructor: Any) -> None:
df = nw.from_native(constructor(data), eager_only=True)

result = df.with_columns(nw.col("a", "b", "c").fill_null(99))
Expand Down
37 changes: 8 additions & 29 deletions tests/frame/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,11 @@ def test_inner_join_single_key(constructor: Any) -> None:
compare_dicts(result, expected)


def test_cross_join(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_cross_join(constructor: Any) -> None:
data = {"a": [1, 3, 2]}
df = nw.from_native(constructor(data))
result = df.join(df, how="cross") # type: ignore[arg-type]

expected = {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]}
result = df.join(df, how="cross").sort("a", "a_right") # type: ignore[arg-type]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without sorting, the result is flaky as order is not guaranteed

expected = {"a": [1, 1, 1, 2, 2, 2, 3, 3, 3], "a_right": [1, 2, 3, 1, 2, 3, 1, 2, 3]}
compare_dicts(result, expected)

with pytest.raises(ValueError, match="Can not pass left_on, right_on for cross join"):
Expand All @@ -71,15 +67,11 @@ def test_cross_join_non_pandas() -> None:
],
)
def test_anti_join(
request: Any,
constructor: Any,
join_key: list[str],
filter_expr: nw.Expr,
expected: dict[str, list[Any]],
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data))
other = df.filter(filter_expr)
Expand All @@ -96,15 +88,11 @@ def test_anti_join(
],
)
def test_semi_join(
request: Any,
constructor: Any,
join_key: list[str],
filter_expr: nw.Expr,
expected: dict[str, list[Any]],
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data))
other = df.filter(filter_expr)
Expand All @@ -127,10 +115,7 @@ def test_join_not_implemented(constructor: Any, how: str) -> None:


@pytest.mark.filterwarnings("ignore:the default coalesce behavior")
def test_left_join(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_left_join(constructor: Any) -> None:
data_left = {"a": [1.0, 2, 3], "b": [4.0, 5, 6]}
data_right = {"a": [1.0, 2, 3], "c": [4.0, 5, 7]}
df_left = nw.from_native(constructor(data_left), eager_only=True)
Expand All @@ -143,10 +128,7 @@ def test_left_join(request: Any, constructor: Any) -> None:


@pytest.mark.filterwarnings("ignore: the default coalesce behavior")
def test_left_join_multiple_column(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_left_join_multiple_column(constructor: Any) -> None:
data_left = {"a": [1, 2, 3], "b": [4, 5, 6]}
data_right = {"a": [1, 2, 3], "c": [4, 5, 6]}
df_left = nw.from_native(constructor(data_left), eager_only=True)
Expand All @@ -157,12 +139,9 @@ def test_left_join_multiple_column(request: Any, constructor: Any) -> None:


@pytest.mark.filterwarnings("ignore: the default coalesce behavior")
def test_left_join_overlapping_column(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

data_left = {"a": [1, 2, 3], "b": [4, 5, 6], "d": [1, 4, 2]}
data_right = {"a": [1, 2, 3], "c": [4, 5, 6], "d": [1, 4, 2]}
def test_left_join_overlapping_column(constructor: Any) -> None:
data_left = {"a": [1.0, 2, 3], "b": [4.0, 5, 6], "d": [1.0, 4, 2]}
data_right = {"a": [1.0, 2, 3], "c": [4.0, 5, 6], "d": [1.0, 4, 2]}
df_left = nw.from_native(constructor(data_left), eager_only=True)
df_right = nw.from_native(constructor(data_right), eager_only=True)
result = df_left.join(df_right, left_on="b", right_on="c", how="left")
Expand Down
16 changes: 16 additions & 0 deletions tests/hypothesis/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from hypothesis import assume
from hypothesis import given
Expand Down Expand Up @@ -163,3 +164,18 @@ def test_left_join( # pragma: no cover
)
).select(pl.all().fill_null(float("nan")))
compare_dicts(result_pd.to_dict(as_series=False), result_pl.to_dict(as_series=False))
# For PyArrow, insert an extra sort, as the order of rows isn't guaranteed
result_pa = (
nw.from_native(pa.table(data_left), eager_only=True)
.join(
nw.from_native(pa.table(data_right), eager_only=True),
how="left",
left_on=left_key,
right_on=right_key,
)
.select(nw.all().cast(nw.Float64).fill_null(float("nan")))
.pipe(lambda df: df.sort(df.columns))
)
compare_dicts(
result_pa, result_pd.pipe(lambda df: df.sort(df.columns)).to_dict(as_series=False)
)
1 change: 0 additions & 1 deletion utils/check_backend_completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"DataFrame.pipe",
"DataFrame.unique",
"Series.drop_nulls",
"Series.fill_null",
"Series.from_iterable",
"Series.is_between",
"Series.is_duplicated",
Expand Down
Loading