Initial implementation of association pipeline. #69

Merged · 5 commits · Apr 21, 2023
Changes from 3 commits
17 changes: 16 additions & 1 deletion src/hipscat_import/association/arguments.py
@@ -6,6 +6,8 @@

from hipscat_import.runtime_arguments import RuntimeArguments

# pylint: disable=too-many-instance-attributes


@dataclass
class AssociationArguments(RuntimeArguments):
@@ -21,24 +23,37 @@ class AssociationArguments(RuntimeArguments):
join_id_column: str = ""
join_foreign_key: str = ""

join_how: str = "inner"
compute_partition_size: int = 1_000_000_000

def __post_init__(self):
RuntimeArguments._check_arguments(self)
self._check_arguments()

def _check_arguments(self):
super()._check_arguments()
if not self.primary_input_catalog_path:
raise ValueError("primary_input_catalog_path is required")
if not self.primary_id_column:
raise ValueError("primary_id_column is required")
if not self.primary_join_column:
raise ValueError("primary_join_column is required")
if self.primary_id_column in ["primary_id", "join_id"]:
raise ValueError("primary_id_column uses a reserved column name")

if not self.join_input_catalog_path:
raise ValueError("join_input_catalog_path is required")
if not self.join_id_column:
raise ValueError("join_id_column is required")
if not self.join_foreign_key:
raise ValueError("join_foreign_key is required")
if self.join_id_column in ["primary_id", "join_id"]:
raise ValueError("join_id_column uses a reserved column name")

if self.join_how not in ["left", "right", "outer", "inner"]:
raise ValueError("join_how must be one of left, right, outer, or inner")

if self.compute_partition_size < 100_000:
raise ValueError("compute_partition_size must be at least 100_000")

def to_catalog_parameters(self) -> CatalogParameters:
"""Convert importing arguments into hipscat catalog parameters.
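For reference, a minimal usage sketch (not part of this diff) showing how the two new arguments could be set. The catalog paths, column names, and catalog name below are placeholders; real hipscat catalogs would be needed, since the base RuntimeArguments validation also runs at construction time.

from hipscat_import.association.arguments import AssociationArguments

args = AssociationArguments(
    primary_input_catalog_path="data/object_catalog",  ## placeholder path
    primary_id_column="id",
    primary_join_column="id",
    join_input_catalog_path="data/source_catalog",  ## placeholder path
    join_id_column="source_id",
    join_foreign_key="object_id",
    output_path="data/output",
    output_catalog_name="object_to_source",
    join_how="left",  ## must be one of left, right, outer, inner (default "inner")
    compute_partition_size=250_000_000,  ## must be at least 100_000
)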
10 changes: 6 additions & 4 deletions src/hipscat_import/association/map_reduce.py
@@ -13,7 +13,7 @@ def map_association(args):
Implementation notes:

Because we may be joining to a column that is NOT the natural/primary key
(on either side of the join), we fetch both the identifying column and the
(on either side of the join), we fetch both the identifying column and the
join predicate column, possibly duplicating one of the columns.

This way, when we drop the join predicate columns at the end of the process,
@@ -86,7 +86,7 @@ def map_association(args):

## Join the two data sets on the shared join predicate.
join_data = primary_index.merge(
join_index, how="inner", left_index=True, right_index=True
join_index, how=args.join_how, left_index=True, right_index=True
).reset_index()

## Write out a summary of each partition join
@@ -134,7 +134,7 @@ def map_association(args):
path=args.tmp_path,
engine="pyarrow",
partition_on=["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix"],
compute_kwargs={"partition_size": 1_000_000_000},
compute_kwargs={"partition_size": args.compute_partition_size},
write_index=False,
)

@@ -199,6 +199,8 @@ def reduce_association(input_path, output_path):
f" Expected {partition['num_rows']}, wrote {rows_written}",
)

pq.write_table(table, where=output_file)
table.to_pandas().set_index("primary_hipscat_index").sort_index().to_parquet(
output_file
)

return data_frame["num_rows"].sum()
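Two behaviors in this file are worth a standalone sketch. It is illustrative only: the dataframe contents and the index/column names "primary_join" and "join_to_primary" are invented here, not the module's internals. The map step merges the primary and join partitions on a shared join-predicate index using the new args.join_how, and the reduce step now writes each partition with primary_hipscat_index as a sorted pandas index.

import pandas as pd
import pyarrow as pa

## Map step: both sides are indexed on the join predicate; the identifying
## columns survive the merge even when the predicate duplicates one of them.
primary_index = pd.DataFrame(
    {"primary_id": [700, 701], "primary_join": [700, 701]}
).set_index("primary_join")
join_index = pd.DataFrame(
    {"join_id": [800, 801, 802], "join_to_primary": [700, 700, 702]}
).set_index("join_to_primary")

## how="inner" keeps only matched rows; "left", "right", and "outer" behave as
## in pandas. With the data above, primary_id 700 matches join_id 800 and 801.
join_data = primary_index.merge(
    join_index, how="inner", left_index=True, right_index=True
).reset_index(drop=True)

## Reduce step: instead of pq.write_table, the combined table is written as an
## index-sorted parquet file keyed on primary_hipscat_index.
table = pa.Table.from_pydict(
    {
        "primary_id": [701, 700],
        "primary_hipscat_index": [7_000_100, 7_000_000],
        "join_id": [801, 800],
        "join_hipscat_index": [8_001_000, 8_000_000],
    }
)
table.to_pandas().set_index("primary_hipscat_index").sort_index().to_parquet(
    "part0.parquet"
)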
7 changes: 3 additions & 4 deletions src/hipscat_import/association/run_association.py
@@ -9,15 +9,14 @@
from tqdm import tqdm

from hipscat_import.association.arguments import AssociationArguments
from hipscat_import.association.map_reduce import (map_association,
reduce_association)
from hipscat_import.association.map_reduce import map_association, reduce_association


def _validate_args(args):
if not args:
raise ValueError("args is required and should be type AssociationArguments")
raise TypeError("args is required and should be type AssociationArguments")
if not isinstance(args, AssociationArguments):
raise ValueError("args must be type AssociationArguments")
raise TypeError("args must be type AssociationArguments")


def run(args):
103 changes: 103 additions & 0 deletions tests/hipscat_import/association/test_association_argument.py
@@ -94,6 +94,109 @@ def test_empty_required(tmp_path, small_sky_object_catalog):
overwrite=True,
)

with pytest.raises(ValueError, match="output_path"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="id",
join_foreign_key="id",
output_path="", ## empty
output_catalog_name="small_sky_self_join",
overwrite=True,
)

with pytest.raises(ValueError, match="output_catalog_name"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="id",
join_foreign_key="id",
output_path=tmp_path,
output_catalog_name="", ## empty
overwrite=True,
)

with pytest.raises(ValueError, match="join_how"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="id",
join_foreign_key="id",
output_path=tmp_path,
output_catalog_name="small_sky_self_join",
join_how="", ## empty
overwrite=True,
)


def test_column_names(tmp_path, small_sky_object_catalog):
"""Test validation of column names."""
with pytest.raises(ValueError, match="primary_id_column"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="primary_id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="id",
join_foreign_key="id",
output_path=tmp_path,
output_catalog_name="bad_columns", ## empty
overwrite=True,
)

with pytest.raises(ValueError, match="join_id_column"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="primary_id",
join_foreign_key="id",
output_path=tmp_path,
output_catalog_name="bad_columns", ## empty
overwrite=True,
)


def test_join_how(tmp_path, small_sky_object_catalog):
"""Test validation of join how."""
with pytest.raises(ValueError, match="join_how"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="id",
join_foreign_key="id",
output_path=tmp_path,
output_catalog_name="bad_columns",
join_how="middle", ## not a valid join option
overwrite=True,
)


def test_compute_partition_size(tmp_path, small_sky_object_catalog):
"""Test validation of compute_partition_size."""
with pytest.raises(ValueError, match="compute_partition_size"):
AssociationArguments(
primary_input_catalog_path=small_sky_object_catalog,
primary_id_column="id",
primary_join_column="id",
join_input_catalog_path=small_sky_object_catalog,
join_id_column="id",
join_foreign_key="id",
output_path=tmp_path,
output_catalog_name="bad_columns",
compute_partition_size=10, ## below the minimum partition size
overwrite=True,
)


def test_all_required_args(tmp_path, small_sky_object_catalog):
"""Required arguments are provided."""
121 changes: 121 additions & 0 deletions tests/hipscat_import/association/test_association_map_reduce.py
@@ -0,0 +1,121 @@
"""Test behavior of map reduce methods in association."""

import os

import pandas as pd
import pytest

from hipscat_import.association.map_reduce import reduce_association


def test_reduce_bad_inputs(tmp_path, assert_text_file_matches):
"""Test reducing with corrupted input."""

input_path = os.path.join(tmp_path, "incomplete_inputs")
os.makedirs(input_path, exist_ok=True)

output_path = os.path.join(tmp_path, "output")
os.makedirs(output_path, exist_ok=True)

## We don't even have a partitions file
with pytest.raises(FileNotFoundError):
reduce_association(input_path=input_path, output_path=output_path)

## Create a partitions file, but it doesn't have the right columns
partitions_data = pd.DataFrame(
data=[[700, 282.5, -58.5], [701, 299.5, -48.5]],
columns=["id", "ra", "dec"],
)
partitions_csv_file = os.path.join(input_path, "partitions.csv")
partitions_data.to_csv(partitions_csv_file, index=False)

with pytest.raises(KeyError, match="primary_hipscat_index"):
reduce_association(input_path=input_path, output_path=output_path)

## Create a partitions file, but it doesn't have corresponding parquet data.
partitions_data = pd.DataFrame(
data=[[0, 0, 11, 0, 0, 11, 131]],
columns=[
"Norder",
"Dir",
"Npix",
"join_Norder",
"join_Dir",
"join_Npix",
"primary_hipscat_index",
],
)
partitions_data.to_csv(partitions_csv_file, index=False)

with pytest.raises(FileNotFoundError):
reduce_association(input_path=input_path, output_path=output_path)

## We still wrote out the partition info file, though!
expected_lines = [
"Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows",
"0,0,11,0,0,11,131",
]
metadata_filename = os.path.join(output_path, "partition_join_info.csv")
assert_text_file_matches(expected_lines, metadata_filename)


def test_reduce_bad_expectation(tmp_path):
"""Test reducing with corrupted input."""
input_path = os.path.join(tmp_path, "incomplete_inputs")
os.makedirs(input_path, exist_ok=True)

output_path = os.path.join(tmp_path, "output")
os.makedirs(output_path, exist_ok=True)

## Create a partitions file, and a parquet file with not-enough rows.
partitions_data = pd.DataFrame(
data=[[0, 0, 11, 0, 0, 11, 3]],
columns=[
"Norder",
"Dir",
"Npix",
"join_Norder",
"join_Dir",
"join_Npix",
"primary_hipscat_index",
],
)
partitions_csv_file = os.path.join(input_path, "partitions.csv")
partitions_data.to_csv(partitions_csv_file, index=False)

parquet_dir = os.path.join(
input_path,
"Norder=0",
"Dir=0",
"Npix=11",
"join_Norder=0",
"join_Dir=0",
"join_Npix=11",
)
os.makedirs(parquet_dir, exist_ok=True)

parquet_data = pd.DataFrame(
data=[[700, 7_000_000, 800, 8_000_000], [701, 7_000_100, 801, 8_001_000]],
columns=[
"primary_id",
"primary_hipscat_index",
"join_id",
"join_hipscat_index",
],
)
parquet_data.to_parquet(os.path.join(parquet_dir, "part0.parquet"))
with pytest.raises(ValueError, match="Unexpected"):
reduce_association(input_path=input_path, output_path=output_path)

## Add one more row in another file, and the expectation is met.
parquet_data = pd.DataFrame(
data=[[702, 7_002_000, 802, 8_002_000]],
columns=[
"primary_id",
"primary_hipscat_index",
"join_id",
"join_hipscat_index",
],
)
parquet_data.to_parquet(os.path.join(parquet_dir, "part1.parquet"))
reduce_association(input_path=input_path, output_path=output_path)
13 changes: 8 additions & 5 deletions tests/hipscat_import/association/test_run_association.py
@@ -13,14 +13,14 @@

def test_empty_args():
"""Runner should fail with empty arguments"""
with pytest.raises(ValueError):
with pytest.raises(TypeError):
runner.run(None)


def test_bad_args():
"""Runner should fail with mis-typed arguments"""
args = {"output_catalog_name": "bad_arg_type"}
with pytest.raises(ValueError):
with pytest.raises(TypeError):
runner.run(args)


@@ -97,8 +97,9 @@ def test_object_to_source(
data_frame = pd.read_parquet(output_file, engine="pyarrow")
npt.assert_array_equal(
data_frame.columns,
["primary_id", "primary_hipscat_index", "join_id", "join_hipscat_index"],
["primary_id", "join_id", "join_hipscat_index"],
)
assert data_frame.index.name == "primary_hipscat_index"
assert len(data_frame) == 50
ids = data_frame["primary_id"]
assert np.logical_and(ids >= 700, ids < 832).all()
@@ -179,8 +180,9 @@ def test_source_to_object(
data_frame = pd.read_parquet(output_file, engine="pyarrow")
npt.assert_array_equal(
data_frame.columns,
["primary_id", "primary_hipscat_index", "join_id", "join_hipscat_index"],
["primary_id", "join_id", "join_hipscat_index"],
)
assert data_frame.index.name == "primary_hipscat_index"
assert len(data_frame) == 50
ids = data_frame["primary_id"]
assert np.logical_and(ids >= 70_000, ids < 87161).all()
@@ -247,8 +249,9 @@ def test_self_join(
data_frame = pd.read_parquet(output_file, engine="pyarrow")
npt.assert_array_equal(
data_frame.columns,
["primary_id", "primary_hipscat_index", "join_id", "join_hipscat_index"],
["primary_id", "join_id", "join_hipscat_index"],
)
assert data_frame.index.name == "primary_hipscat_index"
assert len(data_frame) == 131
ids = data_frame["primary_id"]
assert np.logical_and(ids >= 700, ids < 832).all()
1 change: 1 addition & 0 deletions tests/hipscat_import/conftest.py
@@ -209,6 +209,7 @@ def assert_parquet_file_index(file_name, expected_values):
data_frame = pd.read_parquet(file_name, engine="pyarrow")
values = data_frame.index.values.tolist()
values.sort()
expected_values.sort()

assert len(values) == len(
expected_values