-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(pt): detach coord when building nlist in DPA-2 #4387
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
📝 WalkthroughWalkthroughThe pull request introduces modifications to the Changes
Possibly related PRs
Suggested labels
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/model/descriptor/dpa2.py (2)
760-760
: Good optimization: Detaching coordinates for neighbor list computationThe addition of
.detach()
onextended_coord
when building neighbor lists is a good optimization. This prevents unnecessary gradient computation during the neighbor list construction phase, which doesn't require gradients, leading to better memory efficiency.This change aligns with best practices for PyTorch memory optimization by removing unnecessary computational graph nodes.
Line range hint
758-764
: Add type hints for better code maintainabilityConsider adding explicit type hints to the tensor operations for better code maintainability and IDE support.
- extended_coord.detach(), - nlist, - self.rcut_list, - self.nsel_list, + extended_coord: torch.Tensor = extended_coord.detach(), + nlist: torch.Tensor = nlist, + rcut_list: list[float] = self.rcut_list, + nsel_list: list[int] = self.nsel_list,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/model/descriptor/dpa2.py
(1 hunks)
🔇 Additional comments (1)
deepmd/pt/model/descriptor/dpa2.py (1)
Line range hint 760-764
: Verify gradient flow for coordinate processing
While detaching coordinates for neighbor list building is correct, let's verify that gradients flow properly in other coordinate-dependent computations.
✅ Verification successful
Based on the search results, I can now generate the final response since I have enough context to verify the gradient flow. The codebase shows consistent patterns of detaching coordinates when building neighbor lists across multiple implementations, while preserving gradients for actual computations.
Detaching coordinates for neighbor list construction is safe and correct
The detachment of coordinates for neighbor list building is consistent with the codebase's design:
- All implementations (PyTorch, JAX, etc.) detach coordinates when building neighbor lists since they only need position information
- Gradients are preserved for the actual coordinate-dependent computations in descriptors and energy/force calculations
- The pattern is consistently used in other models like
linear_atomic_model.py
anddp_atomic_model.py
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for other instances of coordinate processing in the codebase
# to ensure consistent handling of gradients
# Look for other coordinate processing patterns
rg -A 3 "extended_coord" --type py
# Look for similar neighbor list building patterns
rg -A 3 "build_multiple_neighbor_list" --type py
Length of output: 94848
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4387 +/- ##
=======================================
Coverage 84.50% 84.50%
=======================================
Files 604 604
Lines 56945 56943 -2
Branches 3487 3486 -1
=======================================
- Hits 48121 48120 -1
+ Misses 7698 7696 -2
- Partials 1126 1127 +1 ☔ View full report in Codecov by Sentry. 🚨 Try these New Features:
|
Summary by CodeRabbit
New Features
Bug Fixes
Documentation