Initial implementation of association pipeline. #69

Merged: 5 commits, Apr 21, 2023
5 changes: 5 additions & 0 deletions src/hipscat_import/association/__init__.py
@@ -0,0 +1,5 @@
"""Modules for creating a new association table from an equijoin between two catalogs"""

from .arguments import AssociationArguments
from .map_reduce import map_association, reduce_association
from .run_association import run
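
For reference, the subpackage re-exports the arguments dataclass, the map/reduce helpers, and the pipeline runner; a minimal import sketch (not part of this diff) mirrors the names declared above:

# Sketch only: these names follow the exports in __init__.py above.
from hipscat_import.association import (
    AssociationArguments,
    map_association,
    reduce_association,
    run,
)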
74 changes: 74 additions & 0 deletions src/hipscat_import/association/arguments.py
@@ -0,0 +1,74 @@
"""Utility to hold all arguments required throughout association pipeline"""

from dataclasses import dataclass

from hipscat.catalog import CatalogParameters

from hipscat_import.runtime_arguments import RuntimeArguments

# pylint: disable=too-many-instance-attributes


@dataclass
class AssociationArguments(RuntimeArguments):
"""Data class for holding association arguments"""

## Input - Primary
primary_input_catalog_path: str = ""
primary_id_column: str = ""
primary_join_column: str = ""

## Input - Join
join_input_catalog_path: str = ""
join_id_column: str = ""
join_foreign_key: str = ""

compute_partition_size: int = 1_000_000_000

def __post_init__(self):
self._check_arguments()

def _check_arguments(self):
super()._check_arguments()
if not self.primary_input_catalog_path:
raise ValueError("primary_input_catalog_path is required")
if not self.primary_id_column:
raise ValueError("primary_id_column is required")
if not self.primary_join_column:
raise ValueError("primary_join_column is required")
if self.primary_id_column in ["primary_id", "join_id"]:
raise ValueError("primary_id_column uses a reserved column name")

if not self.join_input_catalog_path:
raise ValueError("join_input_catalog_path is required")
if not self.join_id_column:
raise ValueError("join_id_column is required")
if not self.join_foreign_key:
raise ValueError("join_foreign_key is required")
if self.join_id_column in ["primary_id", "join_id"]:
raise ValueError("join_id_column uses a reserved column name")

if self.compute_partition_size < 100_000:
raise ValueError("compute_partition_size must be at least 100_000")

def to_catalog_parameters(self) -> CatalogParameters:
"""Convert importing arguments into hipscat catalog parameters.

Returns:
CatalogParameters for catalog being created.
"""
return CatalogParameters(
catalog_name=self.output_catalog_name,
catalog_type="association",
output_path=self.output_path,
)

def additional_runtime_provenance_info(self):
return {
"primary_input_catalog_path": str(self.primary_input_catalog_path),
"primary_id_column": self.primary_id_column,
"primary_join_column": self.primary_join_column,
"join_input_catalog_path": str(self.join_input_catalog_path),
"join_id_column": self.join_id_column,
"join_foreign_key": self.join_foreign_key,
}
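
Because _check_arguments runs from __post_init__, bad inputs fail at construction time. A hedged sketch of that behavior (not part of this diff; all values are hypothetical, output_catalog_name and output_path are assumed to be inherited RuntimeArguments fields, and the inherited checks are assumed to pass for the given output settings):

from hipscat_import.association import AssociationArguments

# Hypothetical values; "primary_id" is one of the reserved output column names.
try:
    AssociationArguments(
        primary_input_catalog_path="/data/hipscat/object_catalog",
        primary_id_column="primary_id",  # reserved name, so this is rejected
        primary_join_column="object_id",
        join_input_catalog_path="/data/hipscat/source_catalog",
        join_id_column="source_id",
        join_foreign_key="object_id",
        output_catalog_name="object_to_source",
        output_path="/data/hipscat/",
    )
except ValueError as error:
    print(error)  # "primary_id_column uses a reserved column name"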
202 changes: 202 additions & 0 deletions src/hipscat_import/association/map_reduce.py
@@ -0,0 +1,202 @@
"""Create partitioned association table between two catalogs"""

import dask.dataframe as dd
import pyarrow.parquet as pq
from hipscat.io import file_io, paths


def map_association(args):
"""Using dask dataframes, create an association between two catalogs.
This will write out sharded parquet files to the temp (intermediate)
directory.

Implementation notes:

Because we may be joining to a column that is NOT the natural/primary key
(on either side of the join), we fetch both the identifying column and the
join predicate column, possibly duplicating one of the columns.

This way, when we drop the join predicate columns at the end of the process,
we will still have the identifying columns. However, it makes the loading
of each input catalog more verbose.
"""
## Read and massage primary input data
## NB: We may be joining on a column that is NOT the natural primary key.
single_primary_column = args.primary_id_column == args.primary_join_column
read_columns = [
"Norder",
"Dir",
"Npix",
]
if single_primary_column:
read_columns = [args.primary_id_column] + read_columns
else:
read_columns = [args.primary_join_column, args.primary_id_column] + read_columns

primary_index = dd.read_parquet(
path=args.primary_input_catalog_path,
columns=read_columns,
dataset={"partitioning": "hive"},
)
if single_primary_column:
## Duplicate the column to simplify later steps
primary_index["primary_id"] = primary_index[args.primary_join_column]
    rename_columns = {
        args.primary_join_column: "primary_join",
        "_hipscat_index": "primary_hipscat_index",
    }
    if not single_primary_column:
        rename_columns[args.primary_id_column] = "primary_id"
    primary_index = (
        primary_index.reset_index()
        .rename(columns=rename_columns)
        .set_index("primary_join")
    )

    ## Read and massage join input data
    single_join_column = args.join_id_column == args.join_foreign_key
    read_columns = [
        "Norder",
        "Dir",
        "Npix",
    ]
    if single_join_column:
        read_columns = [args.join_id_column] + read_columns
    else:
        read_columns = [args.join_id_column, args.join_foreign_key] + read_columns

    join_index = dd.read_parquet(
        path=args.join_input_catalog_path,
        columns=read_columns,
        dataset={"partitioning": "hive"},
    )
    if single_join_column:
        ## Duplicate the column to simplify later steps
        join_index["join_id"] = join_index[args.join_id_column]
    rename_columns = {
        args.join_foreign_key: "join_to_primary",
        "_hipscat_index": "join_hipscat_index",
        "Norder": "join_Norder",
        "Dir": "join_Dir",
        "Npix": "join_Npix",
    }
    if not single_join_column:
        rename_columns[args.join_id_column] = "join_id"
    join_index = (
        join_index.reset_index()
        .rename(columns=rename_columns)
        .set_index("join_to_primary")
    )

    ## Join the two data sets on the shared join predicate.
    join_data = primary_index.merge(
        join_index, how="inner", left_index=True, right_index=True
    )

    ## Write out a summary of each partition join
    groups = (
        join_data.groupby(
            ["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix"],
            group_keys=False,
        )["primary_hipscat_index"]
        .count()
        .compute()
    )
    intermediate_partitions_file = file_io.append_paths_to_pointer(
        args.tmp_path, "partitions.csv"
    )
    file_io.write_dataframe_to_csv(
        dataframe=groups, file_pointer=intermediate_partitions_file
    )

    ## Drop join predicate columns
    join_data = join_data[
        [
            "Norder",
            "Dir",
            "Npix",
            "join_Norder",
            "join_Dir",
            "join_Npix",
            "primary_id",
            "primary_hipscat_index",
            "join_id",
            "join_hipscat_index",
        ]
    ]

    ## Write out association table shards.
    join_data.to_parquet(
        path=args.tmp_path,
        engine="pyarrow",
        partition_on=["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix"],
        compute_kwargs={"partition_size": args.compute_partition_size},
        write_index=False,
    )


def reduce_association(input_path, output_path):
"""Collate sharded parquet files into a single parquet file per partition"""
intermediate_partitions_file = file_io.append_paths_to_pointer(
input_path, "partitions.csv"
)
data_frame = file_io.load_csv_to_pandas(intermediate_partitions_file)

## Clean up the dataframe and write out as our new partition join info file.
data_frame = data_frame[data_frame["primary_hipscat_index"] != 0]
data_frame["num_rows"] = data_frame["primary_hipscat_index"]
data_frame = data_frame[
["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix", "num_rows"]
]
data_frame = data_frame.sort_values(["Norder", "Npix", "join_Norder", "join_Npix"])
file_io.write_dataframe_to_csv(
dataframe=data_frame,
file_pointer=file_io.append_paths_to_pointer(
output_path, "partition_join_info.csv"
),
index=False,
)

## For each partition, join all parquet shards into single parquet file.
for _, partition in data_frame.iterrows():
input_dir = paths.create_hive_directory_name(
input_path,
["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix"],
[
partition["Norder"],
partition["Dir"],
partition["Npix"],
partition["join_Norder"],
partition["join_Dir"],
partition["join_Npix"],
],
)
output_dir = paths.pixel_association_directory(
output_path,
partition["Norder"],
partition["Npix"],
partition["join_Norder"],
partition["join_Npix"],
)
file_io.make_directory(output_dir, exist_ok=True)
output_file = paths.pixel_association_file(
output_path,
partition["Norder"],
partition["Npix"],
partition["join_Norder"],
partition["join_Npix"],
)
table = pq.read_table(input_dir)
rows_written = len(table)

if rows_written != partition["num_rows"]:
            raise ValueError(
                "Unexpected number of objects: "
                f"expected {partition['num_rows']}, wrote {rows_written}"
            )

        table.to_pandas().set_index("primary_hipscat_index").sort_index().to_parquet(
            output_file
        )

    return data_frame["num_rows"].sum()
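
The core of map_association is an equijoin expressed as an index-aligned merge: each side is re-indexed on the shared join predicate, so an inner merge on the indexes pairs every primary row with its matching join rows. A toy pandas sketch of that pattern (not part of this diff; the data is made up, and the column names follow the renames applied above):

import pandas as pd

# Toy stand-ins for the catalog reads above.
primary = pd.DataFrame(
    {"primary_id": ["a", "b"], "primary_join": [10, 20]}
).set_index("primary_join")
join = pd.DataFrame(
    {"join_id": ["x", "y", "z"], "join_to_primary": [10, 10, 20]}
).set_index("join_to_primary")

# Inner merge on the indexes: object "a" matches "x" and "y", object "b" matches "z".
matched = primary.merge(join, how="inner", left_index=True, right_index=True)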
50 changes: 50 additions & 0 deletions src/hipscat_import/association/run_association.py
@@ -0,0 +1,50 @@
"""Create partitioned association table between two catalogs
using dask dataframes for parallelization

Methods in this file set up a dask pipeline using dataframes.
The actual logic of the map reduce is in the `map_reduce.py` file.
"""

from hipscat.io import file_io, write_metadata
from tqdm import tqdm

from hipscat_import.association.arguments import AssociationArguments
from hipscat_import.association.map_reduce import map_association, reduce_association


def _validate_args(args):
    if not args:
        raise TypeError("args is required and should be type AssociationArguments")
    if not isinstance(args, AssociationArguments):
        raise TypeError("args must be type AssociationArguments")


def run(args):
"""Run the association pipeline"""
_validate_args(args)

with tqdm(total=1, desc="Mapping ", disable=not args.progress_bar) as step_progress:
map_association(args)
step_progress.update(1)

rows_written = 0
with tqdm(
total=1, desc="Reducing ", disable=not args.progress_bar
) as step_progress:
rows_written = reduce_association(args.tmp_path, args.catalog_path)
step_progress.update(1)

# All done - write out the metadata
with tqdm(
total=4, desc="Finishing", disable=not args.progress_bar
) as step_progress:
catalog_params = args.to_catalog_parameters()
catalog_params.total_rows = int(rows_written)
write_metadata.write_provenance_info(catalog_params, args.provenance_info())
step_progress.update(1)
write_metadata.write_catalog_info(catalog_params)
step_progress.update(1)
write_metadata.write_parquet_metadata(args.catalog_path)
step_progress.update(1)
file_io.remove_directory(args.tmp_path, ignore_errors=True)
step_progress.update(1)
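
Putting the pieces together, a hedged end-to-end sketch (not part of this diff; the catalog paths and column names are hypothetical, and a Dask scheduler is assumed to be available for the dask.dataframe work in map_association):

from hipscat_import.association import AssociationArguments, run

# Hypothetical catalogs and columns, as in the earlier arguments sketch.
args = AssociationArguments(
    primary_input_catalog_path="/data/hipscat/object_catalog",
    primary_id_column="object_id",
    primary_join_column="object_id",
    join_input_catalog_path="/data/hipscat/source_catalog",
    join_id_column="source_id",
    join_foreign_key="object_id",
    output_catalog_name="object_to_source",
    output_path="/data/hipscat/",
)

# Maps shard files, reduces them into one parquet file per partition pair,
# writes catalog metadata, then removes the temporary directory.
run(args)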