[aot.export] Potential Memory Leak #281

Open
egebeysel opened this issue Nov 20, 2024 · 11 comments

@egebeysel

egebeysel commented Nov 20, 2024

Hi,

I was trying to run a benchmark suite that involves exporting multiple torch.nn.Modules and realised that the aot.export() function might be causing a memory leak: the state_dict of the nn.Module and the ExportedProgram are not released even though they should no longer be referenced.

A concrete and minimal reproducer of the problem:

import time

import torch
from transformers import AutoModelForImageClassification
import iree.turbine.aot as aot


def create_random_inputs():
    return (torch.randn(1, 3, 224, 224, dtype=torch.float32),)

def run_reproducer():
    model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
    inputs = create_random_inputs()
    # Export the model.
    # This is where the leak occurs; without the following line, the model state_dict is released.
    exported_model = aot.export(model, args=inputs, dynamic_shapes=None, strict_export=False)
    print(exported_model)


if __name__ == "__main__":
    for _ in range(2):
        # Model should be released from memory on the second iteration
        run_reproducer()
    print("end loop")
    # To have a more observable flamegraph in the end
    time.sleep(5)

In the second iteration of the loop, one would expect the first model and its exported program to be released from memory, but running the program with memray shows otherwise:

[screenshot: memray temporal flamegraph of the reproducer]

and the functions that allocate memory that hasn't been released within the time frame:

[screenshot: functions with allocations not yet released]

There are two copies of the state_dict and the ExportedProgram being kept in memory above. To observe the (de)allocations more closely, the memray graphs can be reproduced as follows (with the dependencies of iree-turbine plus memray installed):

$ PYTHONMALLOC=malloc python3 -m memray run -o reproducer.bin path/to/reproducer.py
$ memray flamegraph --leaks --temporal reproducer.bin

I would like to take on the issue myself, but before I dive into it, I wanted to ask whether there are any useful pointers, or whether I'm missing something. Any hint on where the problem might be is much appreciated.

P.S.: Any nn.Module can be used; I picked resnet-50 because it is large enough to make the leak visible in the memray graph.

@stellaraccident
Collaborator

Oops. Thanks for the analysis. I suspect what is going on is that internally, the mechanism creates a new class object that is not being immediately collected, and that class object may be inadvertently holding on to the state dict.

As further evidence, we may want to insert an explicit call to gc.collect() to force a full collection and see if that releases the memory. That isn't a good approach for production, but it would tell us whether that is what we are dealing with. If so, we need to break the cycle somehow so that garbage that only the generational collector can collect does not include such heavyweight references.

We're also slowly moving to get rid of that lower-level class mechanism, which was needed to bridge certain programming-model issues in the early days. This analysis may raise the priority of that -- but I also expect there is an easier hack that breaks the cycle once export is done. We just need to find the cycle.
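
A minimal sketch of that diagnostic, reusing run_reproducer() from the reproducer above (the gc.collect() call is for diagnosis only, not something to keep in production code):

import gc
import time

if __name__ == "__main__":
    for _ in range(2):
        run_reproducer()
        # Force a full collection so that anything reachable only through
        # reference cycles is reclaimed before the next iteration.
        gc.collect()
    print("end loop")
    time.sleep(5)  # keep the process alive so memray can capture the final state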

@egebeysel
Author

As further evidence, we may want to insert an explicit call to gc.collect() to force a full collection and see if that releases the memory. That isn't a good approach for production, but it would tell us whether that is what we are dealing with. If so, we need to break the cycle somehow so that garbage that only the generational collector can collect does not include such heavyweight references.

Thanks for the reply! I think so too: the LambdaCompiledModule class object that gets created and put into the _all_compiled_module_class_infos weak dict in iree/turbine/aot/compiled_module.py is still accessible after the outer loop of the reproducer. Therefore, the corresponding CompiledModuleClassInfo value in the dict holds on to its all_exports, which in turn holds onto the state dict.

As for gc.collect(), it doesn't really help if I place it right before the sleep, so I'm guessing there is indeed a cycle.
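
A hypothetical way to check that, assuming _all_compiled_module_class_infos is the module-level weak dictionary mentioned above (it is a private detail of iree/turbine/aot/compiled_module.py, so the name and layout may differ between versions):

import gc

# Private name taken from the comment above; treat this import as an assumption.
from iree.turbine.aot.compiled_module import _all_compiled_module_class_infos


def dump_live_compiled_module_classes():
    # Run a full collection first so purely cyclic garbage disappears, then
    # list the CompiledModule subclasses the weak dict still considers alive.
    gc.collect()
    for cls in list(_all_compiled_module_class_infos.keys()):
        print(cls.__name__, "still alive,", len(gc.get_referrers(cls)), "referrers")

Calling dump_live_compiled_module_classes() after the loop in the reproducer should show whether any LambdaCompiledModule class is still reachable.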

@stellaraccident
Collaborator

Something nefarious is going on if gc.collect() doesn't get it, because that should be handling cycles. This likely means that there is an unintended strong reference, not just a cycle.

@egebeysel
Author

Something nefarious is going on if gc.collect() doesn't get it, because that should be handling cycles. This likely means that there is an unintended strong reference, not just a cycle.

looking into it

@egebeysel
Author

I think so too: the LambdaCompiledModule class object that gets created and put into the _all_compiled_module_class_infos weak dict in iree/turbine/aot/compiled_module.py is still accessible after the outer loop of the reproducer. Therefore, the corresponding CompiledModuleClassInfo value in the dict holds on to its all_exports, which in turn holds onto the state dict.

@stellaraccident
On a second look, I think I'm mistaken: the LambdaCompiledModules are not accessible and the entry is deleted from the weak dictionary. However, the memory situation is exactly the same even with gc.collect() - at least for the state dict.
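
One way to narrow this down is a generic PyTorch debugging pattern rather than anything turbine-specific: after a full collection, scan the GC-tracked objects for large tensors and print the types of whatever still refers to them (report_live_tensors is a made-up helper name):

import gc

import torch


def report_live_tensors(min_numel=1_000_000):
    # Print large tensors that are still reachable after a full collection,
    # together with the types of the objects holding them.
    gc.collect()
    for obj in gc.get_objects():
        try:
            if isinstance(obj, torch.Tensor) and obj.numel() >= min_numel:
                holders = sorted({type(r).__name__ for r in gc.get_referrers(obj)})
                print(tuple(obj.shape), obj.dtype, "held by:", holders)
        except Exception:
            # Some gc-tracked objects misbehave when inspected; skip them.
            continue

Running this right after the loop in the reproducer should show whether the resnet-50 weights are among the survivors and which container types hold them.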

@stellaraccident
Collaborator

Well, this is a mystery indeed. I'm afraid I don't have another theory without putting hands on it. But if I were fishing, I would look at class garbage collection.

@egebeysel
Author

egebeysel commented Dec 4, 2024

So I've had the time to do a little more digging, and I believe the issue is caused by the RefMapper and RefTracker classes. ModuleBuilder has these RefTrackers, but I believe the global_ref_tracker is the one of interest in this case.

If we run the above reproducer, the RefTracker and the corresponding RefMappings are kept between and after iterations and are not garbage collected. Therefore, the value attributes of the mappings keep the torch.Tensor and the np.ndarray (the backing buffer of the DenseResourceElementsAttr that is constructed here) alive.

Now, I'm not entirely sure why this happens, but the weakref.finalize objects that are registered in RefTracker here are still alive after the loop in the reproducer. I couldn't really see the finalizer function being bound to or owning the referent (tensor).
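
For reference, the documented pitfall with weakref.finalize is that the finalize object stays alive until it fires and keeps strong references to its callback and arguments; if those reach the tracked object (even indirectly), the object can never become unreachable and the finalizer never runs. A minimal standalone illustration of that shape (not turbine code; Tracker and Tracked are made-up names):

import weakref


class Tracked:
    pass


class Tracker:
    def __init__(self):
        self.live = {}

    def track(self, key, obj):
        self.live[key] = obj  # the tracker holds a strong reference
        # The callback closes over `self`, so the registered finalizer keeps
        # the tracker alive, the tracker keeps `obj` alive, and `obj` can
        # never be collected -- gc.collect() cannot help because the chain is
        # rooted in weakref's internal registry, not in a cycle.
        weakref.finalize(obj, lambda: self.live.pop(key, None))


tracker = Tracker()
tracker.track("a", Tracked())
del tracker  # the object and the tracker remain reachable via the finalizer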

However, if I run the reproducer without the above 2 lines that register the finalizer, the tests still pass, and the RefTracker and the corresponding RefMappings - and therefore the tensors and buffers - are actually released between iterations.

Here is a memory usage graph of that; the last dip corresponds to a gc.collect() call right after the loop in the reproducer. Before it is called, there's only one model (its corresponding state dict as np.arrays) in memory and the previous ones are released. (Also, the loop has 3 iterations instead of 2.) I believe there exists (?) a further circular dependency that gc.collect() resolves, but it only affects the last model. After that, no model-relevant tensors or arrays are in memory.

[screenshot: memory usage graph of the reproducer with the finalizer removed]

Now, I don't know why that weakref.finalize is there in the first place, but is there a case where the torch.Tensor would be released from memory - and therefore the finalizer be called - without the RefTrackers themselves being released? Wouldn't the state dict of the model always be released after a model is imported? Would leaving this part out have any semantic impact?

@maxbartel
Contributor

maxbartel commented Dec 9, 2024

@stellaraccident Could you maybe help out here? We are not fully sure whether this will still work as you intended with this change. Thanks!

@stellaraccident
Collaborator

Thanks for identifying the smoking gun. I'm on vacation for the next two weeks but will definitely take the time to follow your analysis and get a solution landed when I am able. I recall there being trickiness with that finalizer but I will need to refresh state to think through it properly.

@egebeysel
Author

Hi @stellaraccident, did you have time to look at this? Is there indeed some trickiness with the finalizer? If not, I would raise a PR to torch-mlir removing the 2 lines and the corresponding _ref_finalizer method.

@stellaraccident
Collaborator

Thanks for the reminder: I remembered this was outstanding but got buried after vacation. I'll have a look this weekend or first thing next week.
