Pull resume logic into helper class. #95

Merged: 15 commits, Jul 12, 2023
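
In brief: per-file existence checks move out of the map stage and into the InputReader subclasses; the separate mapping start/done logs collapse into a single mapping_log.txt; and each mapping task persists a partial histogram into a histograms/ directory, which ResumePlan.read_histogram later combines (caching the combined mapping_histogram.binary) instead of shipping histograms back through the Dask scheduler.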
18 changes: 18 additions & 0 deletions src/hipscat_import/catalog/file_readers.py
@@ -6,6 +6,7 @@
import pyarrow as pa
import pyarrow.parquet as pq
from astropy.table import Table
from hipscat.io import file_io

# pylint: disable=too-few-public-methods,too-many-arguments

@@ -87,6 +88,19 @@ def provenance_info(self) -> dict:
dictionary with all argument_name -> argument_value as key -> value pairs.
"""

def regular_file_exists(self, input_file):
"""Check that the `input_file` points to a single regular file

        Raises:
            FileNotFoundError: if nothing exists at the path, or a directory is found there.
"""
if not file_io.does_file_or_directory_exist(input_file):
raise FileNotFoundError(f"File not found at path: {input_file}")
if not file_io.is_regular_file(input_file):
raise FileNotFoundError(
f"Directory found at path - requires regular file: {input_file}"
)
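
Both failure modes surface as FileNotFoundError, so callers can treat a missing file and a directory-at-path uniformly. A quick illustration against a concrete reader such as CsvReader below (the paths here are hypothetical):

    reader = CsvReader()
    reader.regular_file_exists("/no/such/file.csv")  # File not found at path: ...
    reader.regular_file_exists("/tmp")               # Directory found at path - requires regular file: ...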


class CsvReader(InputReader):
"""CSV reader for the most common CSV reading arguments.
@@ -125,6 +139,8 @@ def __init__(
self.kwargs = kwargs

def read(self, input_file):
self.regular_file_exists(input_file)

if self.schema_file:
schema_parquet = pd.read_parquet(
self.schema_file, dtype_backend="numpy_nullable"
@@ -206,6 +222,7 @@ def __init__(
self.kwargs = kwargs

def read(self, input_file):
self.regular_file_exists(input_file)
table = Table.read(input_file, memmap=True, **self.kwargs)
if self.column_names:
table.keep_columns(self.column_names)
@@ -243,6 +260,7 @@ def __init__(self, chunksize=500_000, **kwargs):
self.kwargs = kwargs

def read(self, input_file):
self.regular_file_exists(input_file)
parquet_file = pq.read_table(input_file, **self.kwargs)
for smaller_table in parquet_file.to_batches(max_chunksize=self.chunksize):
yield pa.Table.from_batches([smaller_table]).to_pandas()
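Each reader's read() now validates its input before any rows are parsed, so a bad path fails fast rather than partway through an import. A minimal usage sketch, assuming an existing CSV file (the path and chunk size are illustrative):

    from hipscat_import.catalog.file_readers import CsvReader

    reader = CsvReader(chunksize=50_000)
    # Raises FileNotFoundError immediately for a missing path or a directory;
    # otherwise yields the file as pandas DataFrame chunks.
    for chunk in reader.read("survey_part_000.csv"):
        print(len(chunk))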
13 changes: 6 additions & 7 deletions src/hipscat_import/catalog/map_reduce.py
@@ -8,6 +8,7 @@
from hipscat.io import FilePointer, file_io, paths

from hipscat_import.catalog.file_readers import InputReader
from hipscat_import.catalog.resume_plan import ResumePlan

# pylint: disable=too-many-locals,too-many-arguments

@@ -56,12 +57,6 @@ def _iterate_input_file(
dec_column,
):
"""Helper function to handle input file reading and healpix pixel calculation"""
if not file_io.does_file_or_directory_exist(input_file):
raise FileNotFoundError(f"File not found at path: {input_file}")
if not file_io.is_regular_file(input_file):
raise FileNotFoundError(
f"Directory found at path - requires regular file: {input_file}"
)
if not file_reader:
raise NotImplementedError("No file reader implemented")

@@ -86,6 +81,8 @@ def map_to_pixels(
def map_to_pixels(
input_file: FilePointer,
file_reader: InputReader,
cache_path: FilePointer,
mapping_key,
highest_order,
ra_column,
dec_column,
@@ -115,7 +112,9 @@
):
mapped_pixel, count_at_pixel = np.unique(mapped_pixels, return_counts=True)
histo[mapped_pixel] += count_at_pixel.astype(np.int64)
ResumePlan.write_partial_histogram(
tmp_path=cache_path, mapping_key=mapping_key, histogram=histo
)


def split_pixels(
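With this refactor map_to_pixels returns nothing: each task persists its partial histogram under its mapping_key via ResumePlan.write_partial_histogram instead of shipping the array back to the scheduler. A sketch of reading one partial back, mirroring the reader in resume_plan.py below (the cache path is hypothetical):

    import numpy as np

    # Each completed mapping task leaves a flat int64 histogram on disk.
    with open("/tmp/import_cache/histograms/map_0.binary", "rb") as file_handle:
        partial = np.frombuffer(file_handle.read(), dtype=np.int64)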
96 changes: 59 additions & 37 deletions src/hipscat_import/catalog/resume_plan.py
@@ -29,17 +29,18 @@ class ResumePlan:
feedback of planning progress"""
input_paths: List[FilePointer] = field(default_factory=list)
"""resolved list of all files that will be used in the importer"""
map_files: List[Tuple[str, str]] = field(default_factory=list)
"""list of files (and job keys) that have yet to be mapped"""
split_keys: List[Tuple[str, str]] = field(default_factory=list)
"""set of files (and job keys) that have yet to be split"""

    MAPPING_LOG_FILE = "mapping_log.txt"
    SPLITTING_LOG_FILE = "splitting_log.txt"
    REDUCING_LOG_FILE = "reducing_log.txt"

    HISTOGRAM_BINARY_FILE = "mapping_histogram.binary"
    HISTOGRAMS_DIR = "histograms"

MAPPING_DONE_FILE = "mapping_done"
SPLITTING_DONE_FILE = "splitting_done"
REDUCING_DONE_FILE = "reducing_done"
@@ -79,59 +80,84 @@ def gather_plan(self):

## Gather keys for execution.
self.input_paths.sort()
for test_path in self.input_paths:
if not file_io.does_file_or_directory_exist(test_path):
raise FileNotFoundError(f"{test_path} not found on local storage")
step_progress.update(1)
if not mapping_done:
            mapped_keys = set(self._read_log_keys(self.MAPPING_LOG_FILE))
            self.map_files = [
                (f"map_{i}", file_path)
                for i, file_path in enumerate(self.input_paths)
                if f"map_{i}" not in mapped_keys
            ]
if not splitting_done:
split_keys = set(self._read_log_keys(self.SPLITTING_LOG_FILE))
            self.split_keys = [
                (f"split_{i}", file_path)
                for i, file_path in enumerate(self.input_paths)
                if f"split_{i}" not in split_keys
            ]
## We don't pre-gather the plan for the reducing keys.
## It requires the full destination pixel map.
step_progress.update(1)

    def read_histogram(self, healpix_order):
        """Return a histogram of the appropriate shape for the given healpix order.

        - Try to find a combined histogram
        - Otherwise, combine histograms from partials
        - Otherwise, return an empty histogram
        """
full_histogram = pixel_math.empty_histogram(healpix_order)

## Look for the single combined histogram file.
file_name = file_io.append_paths_to_pointer(
self.tmp_path, self.HISTOGRAM_BINARY_FILE
)
if file_io.does_file_or_directory_exist(file_name):
with open(file_name, "rb") as file_handle:
return frombuffer(file_handle.read(), dtype=np.int64)

## Otherwise:
# - read all the partial histograms
# - combine into a single histogram
# - write out as a single histogram for future reads
# - remove all partial histograms
histogram_files = file_io.find_files_matching_path(
self.tmp_path, self.HISTOGRAMS_DIR, "**.binary"
)
for file_name in histogram_files:
with open(file_name, "rb") as file_handle:
full_histogram = np.add(
full_histogram, frombuffer(file_handle.read(), dtype=np.int64)
)

file_name = file_io.append_paths_to_pointer(
self.tmp_path, self.HISTOGRAM_BINARY_FILE
)
with open(file_name, "wb+") as file_handle:
file_handle.write(full_histogram.data)
file_io.remove_directory(
file_io.append_paths_to_pointer(self.tmp_path, self.HISTOGRAMS_DIR),
ignore_errors=True,
)
return full_histogram

@classmethod
def write_partial_histogram(cls, tmp_path, mapping_key: str, histogram):
"""Write partial histogram to a special intermediate directory"""
file_io.make_directory(
file_io.append_paths_to_pointer(tmp_path, cls.HISTOGRAMS_DIR),
exist_ok=True,
)

file_name = file_io.append_paths_to_pointer(
tmp_path, cls.HISTOGRAMS_DIR, f"{mapping_key}.binary"
)
with open(file_name, "wb+") as file_handle:
file_handle.write(histogram.data)

    def mark_mapping_done(self, mapping_key: str):
        """Add mapping key to done list"""
        self._write_log_key(self.MAPPING_LOG_FILE, mapping_key)

def is_mapping_done(self) -> bool:
"""Are there files left to map?"""
@@ -166,11 +192,7 @@ def get_reduce_items(self, destination_pixel_map):
        reduce_items = [
            (hp_pixel, source_pixels, f"{hp_pixel.order}_{hp_pixel.pixel}")
            for hp_pixel, source_pixels in destination_pixel_map.items()
            if f"{hp_pixel.order}_{hp_pixel.pixel}" not in reduced_keys
        ]
return reduce_items

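A minimal sketch of the write/combine round trip, using only the pieces shown above (the cache path and mapping key are illustrative; an order-10 histogram has 12 * 4**10 bins):

    import numpy as np
    from hipscat_import.catalog.resume_plan import ResumePlan

    # One worker's partial result.
    histogram = np.zeros(12 * 4**10, dtype=np.int64)
    histogram[42] = 100

    # Lands in /tmp/import_cache/histograms/map_0.binary.
    ResumePlan.write_partial_histogram(
        tmp_path="/tmp/import_cache", mapping_key="map_0", histogram=histogram
    )

    # A later read_histogram(10) np.add's every partial together, writes the
    # combined mapping_histogram.binary, and removes the partials directory.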
24 changes: 12 additions & 12 deletions src/hipscat_import/catalog/run_import.py
@@ -18,42 +18,40 @@
def _map_pixels(args, client):
"""Generate a raw histogram of object counts in each healpix pixel"""

    if args.resume_plan.is_mapping_done():
        return

reader_future = client.scatter(args.file_reader)
futures = []
for key, file_path in args.resume_plan.map_files:
futures.append(
client.submit(
mr.map_to_pixels,
key=key,
input_file=file_path,
cache_path=args.tmp_path,
file_reader=reader_future,
mapping_key=key,
highest_order=args.mapping_healpix_order,
ra_column=args.ra_column,
dec_column=args.dec_column,
)
)

    some_error = False
    for future in tqdm(
        as_completed(futures),
        desc="Mapping ",
        total=len(futures),
        disable=(not args.progress_bar),
    ):
        if future.status == "error":  # pragma: no cover
            some_error = True
        args.resume_plan.mark_mapping_done(future.key)
if some_error: # pragma: no cover
raise RuntimeError("Some mapping stages failed. See logs for details.")
args.resume_plan.set_mapping_done()


def _split_pixels(args, alignment_future, client):
@@ -147,11 +145,13 @@ def run(args, client):
raise ValueError("args is required and should be type ImportArguments")
if not isinstance(args, ImportArguments):
raise ValueError("args must be type ImportArguments")
    _map_pixels(args, client)

with tqdm(
        total=2, desc="Binning ", disable=not args.progress_bar
) as step_progress:
raw_histogram = args.resume_plan.read_histogram(args.mapping_healpix_order)
step_progress.update(1)
if args.constant_healpix_order >= 0:
alignment = np.full(len(raw_histogram), None)
for pixel_num, pixel_sum in enumerate(raw_histogram):
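Resuming an interrupted import now relies entirely on this on-disk state; re-running the pipeline re-reads the mapping log and only re-maps unfinished files. The temp-directory layout implied by this diff:

    <tmp_path>/mapping_log.txt              # one "map_<i>" key per completed mapping task
    <tmp_path>/histograms/map_<i>.binary    # that task's partial histogram
    <tmp_path>/mapping_histogram.binary     # combined histogram, cached on first read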
10 changes: 0 additions & 10 deletions tests/hipscat_import/catalog/test_argument_validation.py
@@ -66,16 +66,6 @@ def test_invalid_paths(blank_data_dir, tmp_path):
input_format="parquet",
)

## Bad input file
with pytest.raises(FileNotFoundError):
ImportArguments(
output_catalog_name="catalog",
input_file_list=["/foo/path"],
overwrite=True,
output_path=tmp_path,
input_format="csv",
)


def test_good_paths(blank_data_dir, blank_data_file, tmp_path):
"""Required arguments are provided, and paths are found."""