Out of Memory when using Streaming Dataloader #652

Open
VikaasVarma opened this issue Apr 13, 2024 · 15 comments
Labels: bug (Something isn't working)

Comments

@VikaasVarma

Environment

  • OS: Ubuntu 22.04

To reproduce

Steps to reproduce the behavior:

When using the StreamingDataLoader (or the vanilla PyTorch DataLoader) with num_workers > 0, the worker processes slowly take up more and more memory until the CPU RAM is exhausted.

Expected behavior

The dataloader should be able to provide samples indefinitely without using a significant portion of available RAM.

Additional context

Below is the dataset and dataloader implementation. Each sample is roughly 10 MB. With 16 workers, a prefetch factor of 4, and a batch size of 32, the total memory usage should be, at max, 20 GB. The dataset is made up of around 1.3 million shards.
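
(Working that estimate out: 16 workers × 4 prefetched batches per worker × 32 samples per batch × ~10 MB per sample ≈ 20,480 MB ≈ 20 GB.)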

A similar problem seems to be documented in an issue and a blog post. I have recreated the graphs found in the blog post below.

import torch
from streaming import StreamingDataLoader, StreamingDataset
from torchvision.transforms import v2 as T
from tqdm import tqdm

# `Sample` (return type alias) and `remote_path` are defined elsewhere.


class ImageTokenDataset(StreamingDataset):
    def __init__(
        self,
        remote: str,
        batch_size: int,
        shuffle: bool = False,
        local: str | None = None,
        split: str | None = None,
        transforms: T.Compose = T.Compose([T.ToImage(), T.ToDtype(torch.float32)]),
        input_key: str = "jpg",
        cond_key: str = "cond",
        cond_dropout_rate: float = 0.5,
        predownload: int | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            local=local,
            remote=remote,
            shuffle=shuffle,
            batch_size=batch_size,
            split=split,
            predownload=predownload,
            **kwargs,
        )

        self.batch_size = batch_size
        self.transforms = transforms
        self.input_key = input_key
        self.cond_key = cond_key
        self.cond_dropout_rate = cond_dropout_rate

    def __getitem__(self, at: int) -> Sample:
        obj = super().__getitem__(at)

        _input = self.transforms(obj[self.input_key])
        cond = torch.tensor(obj[self.cond_key])

        if torch.rand(1) < self.cond_dropout_rate:
            cond = torch.zeros_like(cond)

        return inputs, cond

    def to_dataloader(
        self,
        num_workers: int = 8,
        prefetch_factor: int | None = None,
        persistent_workers: bool = True,
        pin_memory: bool = True,
        drop_last: bool = True,
        batch_size: int | None = None,
    ):
        return StreamingDataLoader(
            self,
            batch_size=batch_size or self.batch_size,
            drop_last=drop_last,
            prefetch_factor=prefetch_factor,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory,
        )


if __name__ == "__main__":
    dataset = ImageTokenDataset(
        remote=remote_path,
        batch_size=32,
        local="/tmp/dataset/train",
        split="train",
        input_key="jpg",
        cond_key="t5",
        cond_dropout_rate=0.5,
    )
    dataloader = dataset.to_dataloader(
        num_workers=16, persistent_workers=True, pin_memory=False, prefetch_factor=4
    )

    for _ in tqdm(dataloader):
        pass

[image: recreated memory-usage graphs (pss / uss / shared)]
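
A minimal sketch of collecting such per-process stats with psutil (illustrative names; not necessarily how the plot above was produced):

import os

import psutil


def log_memory(tag: str) -> None:
    info = psutil.Process(os.getpid()).memory_full_info()
    # uss: memory unique to this process; pss: uss plus this process's share of
    # shared pages; shared: pages shared with other processes (e.g. via fork).
    print(f"[{tag} pid={os.getpid()}] "
          f"uss={info.uss / 1e6:.0f} MB, pss={info.pss / 1e6:.0f} MB, "
          f"shared={info.shared / 1e6:.0f} MB")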

VikaasVarma added the bug label Apr 13, 2024
@miguelalba96

miguelalba96 commented Apr 15, 2024

I'm experiencing similar issues when loading image/text pairs (local). The RAM usage increases non-stop (GPU memory is stable). I managed to "solve" it partially by decreasing the number of workers in the data loader (my max is 16 vCPUs per node; each node has 2 GPUs), so I set num_workers=8 and disabled persistent_workers. I guess that if I left persistent_workers=True, the training would eventually crash.

(training happens after downloading the data locally to the nodes)
[image: RAM usage plot]

Maybe Streaming is not cleaning up some state, leading to that memory accumulation?
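
For reference, the workaround above boils down to something like this (just a sketch with the values from this comment; dataset and batch_size come from the rest of my setup):

# Fewer workers and no persistent workers kept RAM growth manageable for me.
dataloader = StreamingDataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=8,             # half of the 16 vCPUs available per node
    persistent_workers=False,  # workers are torn down between epochs
)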

@snarayan21
Collaborator

Hey y’all, thanks for bringing this issue to our attention. We’re looking into this and will get back to you soon.

@snarayan21
Collaborator

I skimmed through the blog and the PyTorch issue. Is this an issue particular to Streaming, or is it on the PyTorch side? StreamingDataLoader is a simple (stateful) subclass of PyTorch's DataLoader. Does this also happen with other Datasets? @VikaasVarma @miguelalba96

@snarayan21
Collaborator

So Streaming is designed for fast random sample access from shards that live on disk. Samples, outside of dataloader prefetching, are never kept in memory. We do use RAM for other things, like sample partitioning and shuffling, but that happens at the start of training. So I'm inclined to think that this is a PyTorch DataLoader issue, given the links you sent as well.

To track memory usage, maybe you could call gc.get_referents() in a loop and log memory statistics? You might be able to pinpoint the leak by looking at gc.
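
Something along these lines (a rough sketch; log_gc_stats is just an illustrative name, and it counts gc-tracked objects by type alongside process RSS from psutil):

import gc
from collections import Counter

import psutil


def log_gc_stats(top: int = 5) -> None:
    # Resident set size of the current process, in MB.
    rss_mb = psutil.Process().memory_info().rss / 1e6
    objs = gc.get_objects()
    counts = Counter(type(o).__name__ for o in objs)
    print(f"RSS={rss_mb:.0f} MB, gc-tracked objects={len(objs)}, "
          f"top types={counts.most_common(top)}")


# for step, batch in enumerate(dataloader):
#     if step % 1000 == 0:
#         log_gc_stats()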

@VikaasVarma
Author

This does not happen outside of Mosaic or with other datasets. Using PyTorch's DataLoader instead of the StreamingDataLoader also leads to the memory leak. When pulling the data with PyTorch's torchdata to construct the dataset, there is no significant memory overhead.

I don't think the problem lies within the StreamingDataLoader. The links seem to point towards large lists of Python objects generally causing this issue. There are a few instances of this in the StreamingDataset (the stored shards, the spanners, the stream filepaths, etc.).
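
The usual mitigation in those links is to pack such lists into flat numpy buffers, so forked workers can read them without touching Python object refcounts. A minimal sketch (PackedStringList is my own naming, not anything in Streaming):

import numpy as np


class PackedStringList:
    """Store many strings as two numpy arrays: a flat byte buffer plus offsets."""

    def __init__(self, strings: list) -> None:
        encoded = [s.encode("utf-8") for s in strings]
        # int64 offsets into the concatenated byte buffer.
        self._offsets = np.cumsum([0] + [len(b) for b in encoded]).astype(np.int64)
        self._buffer = np.frombuffer(b"".join(encoded), dtype=np.uint8)

    def __len__(self) -> int:
        return len(self._offsets) - 1

    def __getitem__(self, i: int) -> str:
        start, end = self._offsets[i], self._offsets[i + 1]
        return self._buffer[start:end].tobytes().decode("utf-8")


# e.g. filepaths = PackedStringList([f"shard-{i:07d}.mds" for i in range(1_300_000)])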

@tonyf

tonyf commented Apr 19, 2024

To echo @VikaasVarma's point here: the copy-on-read issue with the torch dataloader comes back to the dataset object storing a large number of naive Python objects that can't use shared memory. I noticed that most of the dataset metadata is in fact in shared memory, except for self.shards, which is a list of Reader objects.

With smaller datasets, we never ran into this issue (or it never came up through the training lifecycle). We're only running into it now with a dataset that is a few orders of magnitude larger (more rows + larger row size and thus more shards).

If this is a copy-on-read issue, the memory growth wouldn't scale with row size, only with the number of shards, which I think is the case here.
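
A standalone way to see that effect (a sketch, not from Streaming; psutil assumed): fork a worker that only reads a big list of small Python objects and watch its unique memory grow as refcount updates dirty the copy-on-write pages.

import multiprocessing as mp
import os

import psutil


def uss_mb() -> float:
    # Unique Set Size: memory private to the current process.
    return psutil.Process(os.getpid()).memory_full_info().uss / 1e6


def worker(shards: list) -> None:
    # Iterating only *reads* the strings, but binding each one to `s` bumps its
    # refcount, dirtying the page its object header lives on.
    total = sum(len(s) for s in shards)
    print(f"worker: touched {total} chars, USS={uss_mb():.1f} MB")


if __name__ == "__main__":
    # Stand-in for a large shard list (e.g. self.shards / stream filepaths).
    shards = [f"shard-{i:07d}.mds" for i in range(2_000_000)]
    print(f"parent before fork: USS={uss_mb():.1f} MB")

    p = mp.get_context("fork").Process(target=worker, args=(shards,))
    p.start()
    p.join()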

@XiaohanZhangCMU
Collaborator

XiaohanZhangCMU commented Apr 22, 2024

@VikaasVarma Is this a typo in your repro script?

def __getitem__(self, at: int) -> Sample:
    obj = super().__getitem__(at)

    _input = self.transforms(obj[self.input_key])
    cond = torch.tensor(obj[self.cond_key])

    if torch.rand(1) < self.cond_dropout_rate:
        cond = torch.zeros_like(cond) 

    return inputs, cond  # should be _input?

To clarify, you expected "cond" to be garbage collected by StreamingDataset?

@XiaohanZhangCMU
Collaborator

@VikaasVarma can you clarify your plot a bit? e.g., what do pss, uss, and shared mean, and what are the x and y axes?
Can you also provide a sample dataset so I can reproduce the plot? Thanks!

@miguelalba96

miguelalba96 commented May 1, 2024

I tested again, training for a longer period:

[image: RAM usage plot over a longer training run]

my implementation:

import io
from ast import literal_eval
from typing import Any, Callable, List

import torch

from PIL import Image
from streaming import StreamingDataset, StreamingDataLoader

import utils.visual_attribution  # <- one of my modules for normal string manipulation

class ImageCaptionDataset(StreamingDataset):
    def __init__(
            self,
            local: str,
            shuffle: bool,
            batch_size: int,
            transformations: Callable,
    ) -> None:
        super().__init__(
            local=local, shuffle=shuffle, batch_size=batch_size,
        )
        self.transformations = transformations

    @staticmethod
    def get_zero_shot_one_hot(zero_shot_attributes: List[int]):
        one_hot_encoded = torch.zeros(len(utils.visual_attribution.VISUAL_CLASSES), dtype=torch.float)
        one_hot_encoded[zero_shot_attributes] = 1.0
        return one_hot_encoded

    def __getitem__(self, idx: int) -> Any:
        obj = super().__getitem__(idx)
        image = Image.open(io.BytesIO(obj["image"]))
        caption = utils.visual_attribution.replace_article_type(obj["caption_simple"]) # does string replacement
        zero_shot_attr = self.get_zero_shot_one_hot(literal_eval(obj["zero_shot_attributes"]))
        return self.transformations(image), caption, zero_shot_attr

I call it like this:

import os
from functools import partial

from torchvision import transforms

# `configs` (providing ExperimentConfig) is imported from my project elsewhere.


def get_image_transformation_func(split: str):
    transformations = []
    if split == "train":
        transformations += [
            transforms.RandomHorizontalFlip(p=0.5),
            # v2.RandomVerticalFlip(p=0.5)
        ]
    transformations += [
        transforms.Lambda(lambda x: x)
    ]
    return transforms.Compose(transformations)


def collate_fn(batch, processor, tokenizer, max_length: int = 77):
    # samples come from the dataset as CxHxW
    images = processor(
        images=[ex[0] for ex in batch],
        return_tensors="pt"
    )
    captions = tokenizer(
            [ex[1] for ex in batch],
            padding="max_length",
            max_length=max_length,
            return_tensors="pt"
        )
    return {
            # "pixel_values": torch.stack([ex[0] for ex in batch]),
            "pixel_values": images["pixel_values"],
            "input_ids": captions["input_ids"],
            "attention_mask": captions["attention_mask"],
            "labels": torch.stack([ex[2] for ex in batch])
    }


def get_dataloader(split: str, config: configs.ExperimentConfig):
    transform_func = get_image_transformation_func(split)
    dataset = ImageCaptionDataset(
        local=os.path.join(config.local_data_path, split),
        shuffle=True if split == "train" else False,
        batch_size=config.dataset_config.batch_size,
        transformations=transform_func,
    )
    return StreamingDataLoader(
        dataset,
        batch_size=config.dataset_config.batch_size,
        num_workers=config.dataset_config.num_workers,
        collate_fn=partial(
            collate_fn,
            processor=config.dataset_config.processor,
            tokenizer=config.dataset_config.tokenizer),
        drop_last=True,
        pin_memory=config.dataset_config.pin_memory,
        prefetch_factor=config.dataset_config.prefetch_factor,
        # persistent_workers=config.dataset_config.persistent_workers
    )

@huxuan
Contributor

huxuan commented Jul 24, 2024

Encountered a similar issue; the CPU memory usage keeps increasing until OOM in about two hours.

@nagadit

nagadit commented Sep 8, 2024

Check this issue: #758

I found a memory leak problem; it comes from the boto3 library.

@wanghao14

Encountered a similar CPU memory leak issue when training on H800.

@snarayan21
Collaborator

Hey @huxuan @wanghao14 @miguelalba96,
As @nagadit mentioned, there seems to be a memory leak with boto3 as detailed in boto/boto3#1670. If you are using boto3 / s3, can you verify if this is causing your problems?
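
If it helps, a quick way to check (a sketch; assuming that an s3:// remote means downloads go through boto3):

import boto3

remote = "s3://..."  # whatever you pass as `remote=` to StreamingDataset
print("boto3 version:", boto3.__version__)
print("downloads likely go through boto3:", remote.startswith("s3://"))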

@wanghao14

@snarayan21 There is no boto3 in my code.

@miguelalba96

@snarayan21 I am using 1.5 TB of images stored in shards locally on 4 nodes, each with an entire copy of the data, so technically I am not streaming.
