Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check the total number of rows throughout the catalog import pipeline #345

Merged
merged 5 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/hipscat_import/catalog/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class ImportArguments(RuntimeArguments):
use_schema_file: str | None = None
"""path to a parquet file with schema metadata. this will be used for column
metadata when writing the files, if specified"""
expected_total_rows: int = 0
"""number of expected rows found in the dataset. if non-zero, and we find we have
a different number of rows, the pipeline will exit."""
constant_healpix_order: int = -1
"""healpix order to use when mapping. if this is
a positive number, this will be the order of all final pixels and we
Expand Down
7 changes: 7 additions & 0 deletions src/hipscat_import/catalog/resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def get_alignment_file(
highest_healpix_order,
lowest_healpix_order,
pixel_threshold,
expected_total_rows,
) -> str:
"""Get a pointer to the existing alignment file for the pipeline, or
generate a new alignment using provided arguments.
Expand All @@ -217,6 +218,7 @@ def get_alignment_file(
lowest_healpix_order (int): the lowest healpix order (e.g. 1-5). specifying a lowest order
constrains the partitioning to prevent spatially large pixels.
pixel_threshold (int): the maximum number of objects allowed in a single pixel
expected_total_rows (int): number of expected rows found in the dataset.

Returns:
path to cached alignment file.
Expand Down Expand Up @@ -249,6 +251,11 @@ def get_alignment_file(
self.destination_pixel_map = [
(order, pix, count) for (order, pix, count) in self.destination_pixel_map if int(count) > 0
]
total_rows = sum(count for (_, _, count) in self.destination_pixel_map)
if total_rows != expected_total_rows:
hombit marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"Number of rows ({total_rows}) does not match expectation ({expected_total_rows})"
)

return file_name

Expand Down
9 changes: 8 additions & 1 deletion src/hipscat_import/catalog/run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,20 @@ def run(args, client):

with args.resume_plan.print_progress(total=2, stage_name="Binning") as step_progress:
raw_histogram = args.resume_plan.read_histogram(args.mapping_healpix_order)
total_rows = int(raw_histogram.sum())
if args.expected_total_rows > 0 and args.expected_total_rows != total_rows:
raise ValueError(
f"Number of rows ({total_rows}) does not match expectation ({args.expected_total_rows})"
)

step_progress.update(1)
alignment_file = args.resume_plan.get_alignment_file(
raw_histogram,
args.constant_healpix_order,
args.highest_healpix_order,
args.lowest_healpix_order,
args.pixel_threshold,
total_rows,
)

step_progress.update(1)
Expand Down Expand Up @@ -112,7 +119,7 @@ def run(args, client):

# All done - write out the metadata
with args.resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress:
catalog_info = args.to_catalog_info(int(raw_histogram.sum()))
catalog_info = args.to_catalog_info(total_rows)
io.write_provenance_info(
catalog_base_dir=args.catalog_path,
dataset_info=catalog_info,
Expand Down
2 changes: 1 addition & 1 deletion tests/hipscat_import/catalog/test_map_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def test_split_pixels_headers(formats_headers_csv, assert_parquet_file_ids, tmp_
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["foo1"])
raw_histogram = np.full(12, 0)
raw_histogram[11] = 131
alignment_file = plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000)
alignment_file = plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000, 131)
mr.split_pixels(
input_file=formats_headers_csv,
pickled_reader_file=pickle_file_reader(tmp_path, get_file_reader("csv")),
Expand Down
15 changes: 15 additions & 0 deletions tests/hipscat_import/catalog/test_resume_plan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test catalog resume logic"""

import numpy as np
import numpy.testing as npt
import pytest

Expand Down Expand Up @@ -110,6 +111,20 @@ def test_read_write_histogram(tmp_path):
npt.assert_array_equal(result, histogram.to_array())


def test_get_alignment_file(tmp_path):
    """Requesting the alignment file twice yields the same result, and a wrong
    expected row count is rejected with a ValueError."""
    resume_plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["foo1"])

    # 12-bin histogram with all 131 rows placed in the last bin.
    histogram = np.full(12, 0)
    histogram[11] = 131

    def align(expected_rows):
        # Only the trailing expected-row-count argument varies between calls.
        return resume_plan.get_alignment_file(histogram, -1, 0, 0, 1_000, expected_rows)

    first_result = align(131)
    second_result = align(131)
    assert first_result == second_result

    # One row short of the histogram total must trip the row-count check.
    with pytest.raises(ValueError, match="does not match expectation"):
        align(130)


def never_fails():
    """Do nothing and return None: this never raises, and it never marks an
    intermediate success file."""
    return None
Expand Down
23 changes: 23 additions & 0 deletions tests/hipscat_import/catalog/test_run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,26 @@ def test_dask_runner_stats_only(dask_client, small_sky_parts_dir, tmp_path):
assert catalog.catalog_info.ra_column == "ra"
assert catalog.catalog_info.dec_column == "dec"
assert len(catalog.get_healpix_pixels()) == 1


@pytest.mark.dask
def test_import_mismatch_expectation(dask_client, small_sky_parts_dir, tmp_path):
    """Running the import with an explicit ``expected_total_rows`` that the test
    expects to be wrong must fail with a ValueError."""
    # Collect the settings first so the deliberately wrong expectation stands out.
    import_settings = {
        "output_artifact_name": "small_sky",
        "input_path": small_sky_parts_dir,
        "file_reader": "csv",
        "output_path": tmp_path,
        "dask_tmp": tmp_path,
        "highest_healpix_order": 1,
        "progress_bar": False,
        "expected_total_rows": 1_000,
    }
    import_args = ImportArguments(**import_settings)

    with pytest.raises(ValueError, match="does not match expectation"):
        runner.run(import_args, dask_client)