Move existing catalog check. (#466)
delucchi-cmu authored Jan 6, 2025
1 parent ca4700f commit bacabec
Showing 8 changed files with 48 additions and 71 deletions.
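In short: the per-pipeline guard that refused to overwrite an existing catalog is deleted from run_import.py and margin_cache.py, and a single equivalent check moves into RuntimeArguments._check_arguments, which every pipeline's arguments pass through. A minimal sketch of the relocated guard, using only calls that appear in the diffs below (the wrapper function is hypothetical; the real check is inline):

from hats.io import file_io
from hats.io.validation import is_valid_catalog


def check_output_is_free(output_path, output_artifact_name):
    """Refuse to plan any pipeline whose destination already holds a valid catalog."""
    catalog_path = file_io.get_upath(output_path) / output_artifact_name
    if is_valid_catalog(catalog_path):
        raise ValueError(f"Output path {catalog_path} already contains a valid catalog")
    return catalog_path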
27 changes: 7 additions & 20 deletions src/hats_import/catalog/run_import.py
@@ -6,7 +6,6 @@
 
 import os
 import pickle
-from pathlib import Path
 
 import hats.io.file_io as io
 from hats.catalog import PartitionInfo
@@ -19,27 +18,13 @@
 from hats_import.catalog.resume_plan import ResumePlan
 
 
-def _validate_arguments(args):
-    """
-    Verify that the args for run are valid: they exist, are of the appropriate type,
-    and do not specify an output which is a valid catalog.
-
-    Raises ValueError if they are invalid.
-    """
+def run(args, client):
+    """Run catalog creation pipeline."""
     if not args:
         raise ValueError("args is required and should be type ImportArguments")
     if not isinstance(args, ImportArguments):
         raise ValueError("args must be type ImportArguments")
 
-    potential_path = Path(args.output_path) / args.output_artifact_name
-    if is_valid_catalog(potential_path):
-        raise ValueError(f"Output path {potential_path} already contains a valid catalog")
-
-
-def run(args, client):
-    """Run catalog creation pipeline."""
-    _validate_arguments(args)
-
     resume_plan = ResumePlan(import_args=args)
 
     pickled_reader_file = os.path.join(resume_plan.tmp_path, "reader.pickle")
@@ -137,7 +122,7 @@ def run(args, client):
 
     # All done - write out the metadata
     if resume_plan.should_run_finishing:
-        with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
+        with resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress:
             partition_info = PartitionInfo.from_healpix(resume_plan.get_destination_pixels())
             partition_info_file = paths.get_partition_info_pointer(args.catalog_path)
             partition_info.write_to_file(partition_info_file)
@@ -151,12 +136,14 @@ def run(args, client):
             else:
                 partition_info.write_to_metadata_files(args.catalog_path)
             step_progress.update(1)
+            io.write_fits_image(raw_histogram, paths.get_point_map_file_pointer(args.catalog_path))
+            step_progress.update(1)
             catalog_info = args.to_table_properties(
                 total_rows, partition_info.get_highest_order(), partition_info.calculate_fractional_coverage()
             )
             catalog_info.to_properties_file(args.catalog_path)
             step_progress.update(1)
-            io.write_fits_image(raw_histogram, paths.get_point_map_file_pointer(args.catalog_path))
-            step_progress.update(1)
             resume_plan.clean_resume_files()
             step_progress.update(1)
+            assert is_valid_catalog(args.catalog_path)
+            step_progress.update(1)
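Beyond the moved check, the finishing stage above changes in two small ways: the point-map FITS image is now written before the properties file, and a fifth progress step asserts that the finished output actually validates. A standalone sketch of that last step (the wrapper function is ours, for illustration):

from hats.io.validation import is_valid_catalog


def verify_finished_catalog(catalog_path):
    # Mirrors the new final step of the "Finishing" stage: once partition info,
    # the point map, properties, and resume-file cleanup are done, the
    # destination should read back as a valid HATS catalog.
    assert is_valid_catalog(catalog_path)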
8 changes: 0 additions & 8 deletions src/hats_import/margin_cache/margin_cache.py
@@ -1,8 +1,5 @@
-from pathlib import Path
-
 from hats.catalog import PartitionInfo
 from hats.io import file_io, parquet_metadata, paths
-from hats.io.validation import is_valid_catalog
 
 import hats_import.margin_cache.margin_cache_map_reduce as mcmr
 from hats_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
@@ -18,11 +15,6 @@ def generate_margin_cache(args, client):
         args (MarginCacheArguments): A valid `MarginCacheArguments` object.
         client (dask.distributed.Client): A dask distributed client object.
     """
-    potential_path = Path(args.output_path) / args.output_artifact_name
-    # Verify that the planned output path is not occupied by a valid catalog
-    if is_valid_catalog(potential_path):
-        raise ValueError(f"Output path {potential_path} already contains a valid catalog")
-
     resume_plan = MarginCachePlan(args)
     original_catalog_metadata = paths.get_common_metadata_pointer(args.input_catalog_path)
 
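The deleted guard is not lost: MarginCacheArguments inherits it from RuntimeArguments (see the runtime_arguments.py diff below), so the same error now surfaces when the arguments object is constructed. A usage sketch with placeholder paths:

from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments

try:
    MarginCacheArguments(
        input_catalog_path="catalogs/small_sky_object_catalog",  # hypothetical existing catalog
        output_path="catalogs",
        margin_threshold=10.0,
        output_artifact_name="small_sky_object_catalog",
    )
except ValueError as error:
    print(error)  # Output path ... already contains a valid catalog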
3 changes: 3 additions & 0 deletions src/hats_import/runtime_arguments.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 
 from hats.io import file_io
+from hats.io.validation import is_valid_catalog
 from upath import UPath
 
 # pylint: disable=too-many-instance-attributes
@@ -89,6 +90,8 @@ def _check_arguments(self):
             raise ValueError("dask_threads_per_worker should be greater than 0")
 
         self.catalog_path = file_io.get_upath(self.output_path) / self.output_artifact_name
+        if is_valid_catalog(self.catalog_path):
+            raise ValueError(f"Output path {self.catalog_path} already contains a valid catalog")
         if not self.resume:
             file_io.remove_directory(self.catalog_path, ignore_errors=True)
         file_io.make_directory(self.catalog_path, exist_ok=True)
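With the check living in _check_arguments, constructing ImportArguments (or any other RuntimeArguments subclass) over an existing catalog now fails immediately, before any Dask work is scheduled; the relocated tests below exercise exactly this. A sketch with placeholder paths:

from hats_import.catalog.arguments import ImportArguments

try:
    ImportArguments(
        input_path="data/parquet_shards",  # hypothetical input shards
        output_path="catalogs",  # parent directory of an existing catalog
        output_artifact_name="small_sky_object_catalog",  # name of that catalog
        file_reader="parquet",
    )
except ValueError as error:
    print(error)  # Output path ... already contains a valid catalog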
13 changes: 13 additions & 0 deletions tests/hats_import/catalog/test_argument_validation.py
@@ -241,3 +241,16 @@ def test_check_healpix_order_range():
         check_healpix_order_range("two", "order_field")
     with pytest.raises(TypeError, match="not supported"):
         check_healpix_order_range(5, "order_field", upper_bound="ten")
+
+
+def test_no_import_overwrite(small_sky_object_catalog, parquet_shards_dir):
+    """Runner should refuse to overwrite a valid catalog"""
+    catalog_dir = small_sky_object_catalog.parent
+    catalog_name = small_sky_object_catalog.name
+    with pytest.raises(ValueError, match="already contains a valid catalog"):
+        ImportArguments(
+            input_path=parquet_shards_dir,
+            output_path=catalog_dir,
+            output_artifact_name=catalog_name,
+            file_reader="parquet",
+        )
30 changes: 0 additions & 30 deletions tests/hats_import/catalog/test_run_import.py
@@ -13,11 +13,9 @@
 from hats.pixel_math.sparse_histogram import SparseHistogram
 
 import hats_import.catalog.run_import as runner
-import hats_import.margin_cache.margin_cache as margin_runner
 from hats_import.catalog.arguments import ImportArguments
 from hats_import.catalog.file_readers import CsvReader
 from hats_import.catalog.resume_plan import ResumePlan
-from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments
 
 
 def test_empty_args():
@@ -33,34 +31,6 @@ def test_bad_args():
         runner.run(args, None)
 
 
-def test_no_import_overwrite(small_sky_object_catalog, parquet_shards_dir):
-    """Runner should refuse to overwrite a valid catalog"""
-    catalog_dir = small_sky_object_catalog.parent
-    catalog_name = small_sky_object_catalog.name
-    args = ImportArguments(
-        input_path=parquet_shards_dir,
-        output_path=catalog_dir,
-        output_artifact_name=catalog_name,
-        file_reader="parquet",
-    )
-    with pytest.raises(ValueError, match="already contains a valid catalog"):
-        runner.run(args, None)
-
-
-def test_no_margin_cache_overwrite(small_sky_object_catalog):
-    """Runner should refuse to generate margin cache which overwrites valid catalog"""
-    catalog_dir = small_sky_object_catalog.parent
-    catalog_name = small_sky_object_catalog.name
-    args = MarginCacheArguments(
-        input_catalog_path=small_sky_object_catalog,
-        output_path=catalog_dir,
-        margin_threshold=10.0,
-        output_artifact_name=catalog_name,
-    )
-    with pytest.raises(ValueError, match="already contains a valid catalog"):
-        margin_runner.generate_margin_cache(args, None)
-
-
 @pytest.mark.dask
 def test_resume_dask_runner(
     dask_client,
15 changes: 7 additions & 8 deletions tests/hats_import/catalog/test_run_round_trip.py
@@ -287,15 +287,14 @@ def test_import_delete_provided_temp_directory(
     """Test that ALL intermediate files (and temporary base directory) are deleted
     after successful import, when both `delete_intermediate_parquet_files` and
     `delete_resume_log_files` are set to True."""
-    output_dir = tmp_path_factory.mktemp("small_sky_object_catalog")
+    output_dir = tmp_path_factory.mktemp("catalogs")
     # Provided temporary directory, outside `output_dir`
     temp = tmp_path_factory.mktemp("intermediate_files")
-    base_intermediate_dir = temp / "small_sky_object_catalog" / "intermediate"
 
     # When at least one of the delete flags is set to False we do
     # not delete the provided temporary base directory.
     args = ImportArguments(
-        output_artifact_name="small_sky_object_catalog",
+        output_artifact_name="keep_log_files",
         input_path=small_sky_parts_dir,
         file_reader="csv",
         output_path=output_dir,
@@ -307,10 +306,10 @@ def test_import_delete_provided_temp_directory(
         delete_resume_log_files=False,
     )
     runner.run(args, dask_client)
-    assert_stage_level_files_exist(base_intermediate_dir)
+    assert_stage_level_files_exist(temp / "keep_log_files" / "intermediate")
 
     args = ImportArguments(
-        output_artifact_name="small_sky_object_catalog",
+        output_artifact_name="keep_parquet_intermediate",
         input_path=small_sky_parts_dir,
         file_reader="csv",
         output_path=output_dir,
@@ -323,11 +322,11 @@ def test_import_delete_provided_temp_directory(
         resume=False,
     )
     runner.run(args, dask_client)
-    assert_intermediate_parquet_files_exist(base_intermediate_dir)
+    assert_intermediate_parquet_files_exist(temp / "keep_parquet_intermediate" / "intermediate")
 
     # The temporary directory is deleted.
     args = ImportArguments(
-        output_artifact_name="small_sky_object_catalog",
+        output_artifact_name="remove_all_intermediate",
         input_path=small_sky_parts_dir,
         file_reader="csv",
         output_path=output_dir,
@@ -340,7 +339,7 @@ def test_import_delete_provided_temp_directory(
         resume=False,
    )
     runner.run(args, dask_client)
-    assert not os.path.exists(temp)
+    assert not os.path.exists(temp / "remove_all_intermediate")
 
 
 def assert_stage_level_files_exist(base_intermediate_dir):
13 changes: 13 additions & 0 deletions tests/hats_import/margin_cache/test_arguments_margin_cache.py
@@ -126,3 +126,16 @@ def test_to_table_properties(small_sky_source_catalog, tmp_path):
     assert catalog_info.total_rows == 10
     assert catalog_info.ra_column == "source_ra"
     assert catalog_info.dec_column == "source_dec"
+
+
+def test_no_margin_cache_overwrite(small_sky_object_catalog):
+    """Runner should refuse to generate margin cache which overwrites valid catalog"""
+    catalog_dir = small_sky_object_catalog.parent
+    catalog_name = small_sky_object_catalog.name
+    with pytest.raises(ValueError, match="already contains a valid catalog"):
+        MarginCacheArguments(
+            input_catalog_path=small_sky_object_catalog,
+            output_path=catalog_dir,
+            margin_threshold=10.0,
+            output_artifact_name=catalog_name,
+        )
10 changes: 5 additions & 5 deletions tests/hats_import/test_runtime_arguments.py
@@ -123,10 +123,10 @@ def test_dask_args(tmp_path):
     )
 
 
-def test_extra_property_dict(test_data_dir):
+def test_extra_property_dict(tmp_path):
     args = RuntimeArguments(
         output_artifact_name="small_sky_source_catalog",
-        output_path=test_data_dir,
+        output_path=tmp_path,
     )
 
     properties = args.extra_property_dict()
@@ -141,13 +141,13 @@ def test_extra_property_dict(test_data_dir):
     # Most values are dynamic, but these are some safe assumptions.
     assert properties["hats_builder"].startswith("hats")
     assert properties["hats_creation_date"].startswith("20")
-    assert properties["hats_estsize"] > 1_000
+    assert properties["hats_estsize"] >= 0
     assert properties["hats_release_date"].startswith("20")
     assert properties["hats_version"].startswith("v")
 
     args = RuntimeArguments(
         output_artifact_name="small_sky_source_catalog",
-        output_path=test_data_dir,
+        output_path=tmp_path,
         addl_hats_properties={"foo": "bar"},
     )
 
@@ -164,7 +164,7 @@ def test_extra_property_dict(test_data_dir):
     # Most values are dynamic, but these are some safe assumptions.
     assert properties["hats_builder"].startswith("hats")
     assert properties["hats_creation_date"].startswith("20")
-    assert properties["hats_estsize"] > 1_000
+    assert properties["hats_estsize"] >= 0
     assert properties["hats_release_date"].startswith("20")
     assert properties["hats_version"].startswith("v")
     assert properties["foo"] == "bar"
