
feat: enable composable and customizable sampler in PyTorch data loader #1900

Merged
merged 19 commits into main Feb 2, 2024

Conversation

eddyxu
Contributor

@eddyxu eddyxu commented Feb 2, 2024

  • Provide a set of composable Samplers that work with Lance datasets
  • The new ruff release made a number of formatting changes

@eddyxu eddyxu changed the title feat: enable customized sampler in PyTorch feat: enable composable and customizable sampler in PyTorch data loader Feb 2, 2024
@eddyxu eddyxu requested review from westonpace, wjones127 and chebbyChefNEQ and removed request for westonpace and wjones127 February 2, 2024 01:00
@eddyxu eddyxu force-pushed the lei/pytorch_example branch from 89faedd to 48eb133 Compare February 2, 2024 01:04
@@ -184,3 +186,130 @@ def reservoir_sampling(stream: Iterable[T], k: int) -> list[T]:
samples = [i.item for i in heap]
del heap
return samples


class Sampler(ABC):
Contributor

Right now the implementations just scan in order; they don't randomize (which almost makes them not really meet the definition of "sampling"). Users could do shuffling / reservoir sampling on the batches, but it would be much more efficient to do it on fragment_ids and batch indices. Do you have any plans to integrate that with this API?
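To illustrate the suggestion above: shuffling at the fragment-id level keeps each fragment's reads sequential while still randomizing the visit order. A minimal sketch, where `shuffled_fragment_order` is a hypothetical helper and not part of this PR:

```python
import numpy as np


def shuffled_fragment_order(num_fragments: int, seed=None) -> list:
    # Permute fragment ids up front: I/O within each fragment stays
    # sequential, but fragments are visited in random order.
    rng = np.random.default_rng(seed)
    return rng.permutation(num_fragments).tolist()
```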

Contributor Author

@eddyxu eddyxu Feb 2, 2024

Reservoir shuffling is I/O friendly with sequential reads (which is NFS friendly), while yielding random batches with a uniform distribution. The time complexity is O(k + log(n/k)), where n is the number of batches and k is small, with an O(k) memory footprint, and it amortizes the file path lookup and metadata overhead across the scan. Within Lance itself, read_batch is more performant than take. In many cases, reservoir shuffling can provide pretty decent performance. We need more performance numbers for sure.

I might need to get another PR out to add np.random.select(fragments) and reservoir_shuffle(batches), though. This one establishes the APIs.

Contributor Author

@eddyxu eddyxu Feb 2, 2024

That being said, reservoir sampling needs to be changed to:

def reservoir_sampling(
    stream: Iterable[T], k: int, rank: int, world_size: int
) -> Iterable[T]:
    rng = np.random.default_rng()
    heap = []
    for idx, item in enumerate(stream):
        entry = PrioritizedItem(rng.integers(0, k * 2), item)
        if len(heap) < k:
            heappush(heap, entry)
        else:
            vic = heappushpop(heap, entry)
            if idx % world_size == rank:  # <<<<< CHANGE TO YIELD HERE
                yield vic
            del vic
        if idx % 10240 == 0:
            logging.info("Force Python GC")
            gc.collect()
    # Note: inside a generator, this `return` value is delivered via
    # StopIteration.value rather than as an ordinary return value.
    samples = [i.item for i in heap]
    del heap
    return samples

Run this with n=1M, k=8, world_size=1
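One subtlety of the suggested change: once the function yields, it becomes a generator, so the final `return samples` is delivered via `StopIteration.value`, not as an ordinary return. A toy sketch of consuming such a generator (`drain` and `toy_sampler` are illustrative helpers, not from this PR):

```python
def toy_sampler(n: int, k: int):
    # Mimics the yield-plus-return shape of the modified reservoir sampler:
    # fill a buffer of size k, yield everything evicted, return the buffer.
    buf = []
    for i in range(n):
        if len(buf) < k:
            buf.append(i)
        else:
            yield i  # evicted item, streamed to the caller
    return buf  # final reservoir, carried on StopIteration.value


def drain(gen):
    # Collect all yielded items and capture the generator's return value.
    yielded = []
    try:
        while True:
            yielded.append(next(gen))
    except StopIteration as stop:
        return yielded, stop.value
```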

Each rank / process will process a subset of the batches.
"""

def __init__(self, rank: int, world_size: int):
Contributor

nit: add a from_torch method

Contributor Author

the one that reads from torch distributed?

Contributor

yeah

Each rank / process will process a subset of the fragments.
"""

def __init__(self, rank: int, world_size: int):
Contributor

ditto

@eddyxu eddyxu merged commit 5407db8 into main Feb 2, 2024
9 checks passed
@eddyxu eddyxu deleted the lei/pytorch_example branch February 2, 2024 02:56
westonpace added a commit that referenced this pull request Feb 9, 2024
When support was added for sampling in #1900, it broke support for
filtering on full scans (combining sampling and filtering is not yet
supported). This PR repairs support for filtering on full scans.

Closes #1932
3 participants