Skip to content

Commit

Permalink
Check the total number of rows throughout the catalog import pipeline (
Browse files Browse the repository at this point in the history
…#345)

* Check the total number of rows throughout the catalog import pipeline

* Use returned total rows.

* Black formatting.

* Use hipscat from main for docs, as well as tests.
  • Loading branch information
delucchi-cmu authored Jul 23, 2024
1 parent 6c4451f commit 314a79a
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
run: |
sudo apt-get update
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f docs/requirements.txt ]; then pip install -r docs/requirements.txt; fi
pip install .
- name: Install notebook requirements
Expand Down
3 changes: 3 additions & 0 deletions src/hipscat_import/catalog/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class ImportArguments(RuntimeArguments):
use_schema_file: str | None = None
"""path to a parquet file with schema metadata. this will be used for column
metadata when writing the files, if specified"""
expected_total_rows: int = 0
"""number of expected rows found in the dataset. if non-zero, and we find we have
a different number of rows, the pipeline will exit."""
constant_healpix_order: int = -1
"""healpix order to use when mapping. if this is
a positive number, this will be the order of all final pixels and we
Expand Down
7 changes: 7 additions & 0 deletions src/hipscat_import/catalog/resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def get_alignment_file(
highest_healpix_order,
lowest_healpix_order,
pixel_threshold,
expected_total_rows,
) -> str:
"""Get a pointer to the existing alignment file for the pipeline, or
generate a new alignment using provided arguments.
Expand All @@ -217,6 +218,7 @@ def get_alignment_file(
lowest_healpix_order (int): the lowest healpix order (e.g. 1-5). specifying a lowest order
constrains the partitioning to prevent spatially large pixels.
threshold (int): the maximum number of objects allowed in a single pixel
expected_total_rows (int): number of expected rows found in the dataset.
Returns:
path to cached alignment file.
Expand Down Expand Up @@ -249,6 +251,11 @@ def get_alignment_file(
self.destination_pixel_map = [
(order, pix, count) for (order, pix, count) in self.destination_pixel_map if int(count) > 0
]
total_rows = sum(count for (_, _, count) in self.destination_pixel_map)
if total_rows != expected_total_rows:
raise ValueError(
f"Number of rows ({total_rows}) does not match expectation ({expected_total_rows})"
)

return file_name

Expand Down
18 changes: 16 additions & 2 deletions src/hipscat_import/catalog/run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,20 @@ def run(args, client):

with args.resume_plan.print_progress(total=2, stage_name="Binning") as step_progress:
raw_histogram = args.resume_plan.read_histogram(args.mapping_healpix_order)
total_rows = int(raw_histogram.sum())
if args.expected_total_rows > 0 and args.expected_total_rows != total_rows:
raise ValueError(
f"Number of rows ({total_rows}) does not match expectation ({args.expected_total_rows})"
)

step_progress.update(1)
alignment_file = args.resume_plan.get_alignment_file(
raw_histogram,
args.constant_healpix_order,
args.highest_healpix_order,
args.lowest_healpix_order,
args.pixel_threshold,
total_rows,
)

step_progress.update(1)
Expand Down Expand Up @@ -112,7 +119,7 @@ def run(args, client):

# All done - write out the metadata
with args.resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress:
catalog_info = args.to_catalog_info(int(raw_histogram.sum()))
catalog_info = args.to_catalog_info(total_rows)
io.write_provenance_info(
catalog_base_dir=args.catalog_path,
dataset_info=catalog_info,
Expand All @@ -131,7 +138,14 @@ def run(args, client):
partition_info_file = paths.get_partition_info_pointer(args.catalog_path)
partition_info.write_to_file(partition_info_file, storage_options=args.output_storage_options)
if not args.debug_stats_only:
write_parquet_metadata(args.catalog_path, storage_options=args.output_storage_options)
parquet_rows = write_parquet_metadata(
args.catalog_path, storage_options=args.output_storage_options
)
if total_rows > 0 and parquet_rows != total_rows:
raise ValueError(
f"Number of rows in parquet ({parquet_rows}) does not match expectation ({total_rows})"
)

else:
partition_info.write_to_metadata_files(
args.catalog_path, storage_options=args.output_storage_options
Expand Down
2 changes: 1 addition & 1 deletion tests/hipscat_import/catalog/test_map_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def test_split_pixels_headers(formats_headers_csv, assert_parquet_file_ids, tmp_
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["foo1"])
raw_histogram = np.full(12, 0)
raw_histogram[11] = 131
alignment_file = plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000)
alignment_file = plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000, 131)
mr.split_pixels(
input_file=formats_headers_csv,
pickled_reader_file=pickle_file_reader(tmp_path, get_file_reader("csv")),
Expand Down
15 changes: 15 additions & 0 deletions tests/hipscat_import/catalog/test_resume_plan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test catalog resume logic"""

import numpy as np
import numpy.testing as npt
import pytest

Expand Down Expand Up @@ -110,6 +111,20 @@ def test_read_write_histogram(tmp_path):
npt.assert_array_equal(result, histogram.to_array())


def test_get_alignment_file(tmp_path):
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["foo1"])
raw_histogram = np.full(12, 0)
raw_histogram[11] = 131
alignment_file = plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000, 131)

alignment_file2 = plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000, 131)

assert alignment_file == alignment_file2

with pytest.raises(ValueError, match="does not match expectation"):
plan.get_alignment_file(raw_histogram, -1, 0, 0, 1_000, 130)


def never_fails():
"""Method never fails, but never marks intermediate success file."""
return
Expand Down
23 changes: 23 additions & 0 deletions tests/hipscat_import/catalog/test_run_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,26 @@ def test_dask_runner_stats_only(dask_client, small_sky_parts_dir, tmp_path):
assert catalog.catalog_info.ra_column == "ra"
assert catalog.catalog_info.dec_column == "dec"
assert len(catalog.get_healpix_pixels()) == 1


@pytest.mark.dask
def test_import_mismatch_expectation(
dask_client,
small_sky_parts_dir,
tmp_path,
):
"""Test that the pipeline execution fails if the number of rows does not match
explicit (but wrong) expectation."""
args = ImportArguments(
output_artifact_name="small_sky",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
highest_healpix_order=1,
progress_bar=False,
expected_total_rows=1_000,
)

with pytest.raises(ValueError, match="does not match expectation"):
runner.run(args, dask_client)

0 comments on commit 314a79a

Please sign in to comment.