fix: Unify custom model logits extraction across all inference methods #1815
Conversation
Pull request overview
This PR standardizes logits extraction logic across multiple inference methods to properly handle custom MoE models that output logits directly as tensors instead of objects with a .logits attribute.
Changes:
- Updated `get_logprobs()`, `score()`, and `get_topk_logits()` to use a consistent three-case logits extraction pattern (tensor output, missing `.logits`, has `.logits`)
- Fixed a float32 conversion bug in `score()` that incorrectly referenced `outputs.logits` instead of the extracted `logits` variable
- Added `del outputs` statements to release memory after logits extraction
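The three-case extraction pattern described above can be sketched as follows. `FakeTensor`, `HFOutput`, and `extract_logits` are illustrative stand-ins invented here so the branching logic is self-contained; the real worker would check `isinstance(outputs, torch.Tensor)` instead:

```python
# Sketch of the three-case logits extraction pattern (assumed shape of the
# fix; names here are hypothetical stand-ins, not the repo's actual code).

class FakeTensor:
    """Stand-in for a raw tensor returned directly by a custom MoE model."""

class HFOutput:
    """Stand-in for a standard model output object exposing .logits."""
    def __init__(self, logits):
        self.logits = logits

def extract_logits(outputs):
    if isinstance(outputs, FakeTensor):
        return outputs            # case 1: model returned logits directly
    if not hasattr(outputs, "logits"):
        return outputs[0]         # case 2: tuple-like output, logits first
    return outputs.logits         # case 3: standard output with .logits

t = FakeTensor()
assert extract_logits(t) is t                 # tensor output
assert extract_logits((t, None)) is t         # missing .logits
assert extract_logits(HFOutput(t)) is t       # has .logits
```

All three input shapes resolve to the same logits object, which is what lets the downstream float32 conversion and memory cleanup be written once.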
This commit fixes incomplete handling of custom MoE models (e.g., from AutoModel) that output logits directly as tensors instead of objects with a `.logits` attribute.

Changes:
- `get_logprobs()`: updated to use the complete logits extraction logic, with proper handling for all three cases (tensor output, missing `.logits`, has `.logits`)
- `score()`: added the missing isinstance check for tensor outputs and fixed the float32 conversion bug that incorrectly referenced `outputs.logits`
- `get_topk_logits()`: already had the correct logic, ensuring consistency

All methods now follow the same pattern as the `train()` method for extracting logits from model outputs, preventing AttributeError when using custom models.

Related to issue #1810

Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
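The failure mode this commit prevents can be reproduced with a minimal stand-in; `RawTensorOutput` is a hypothetical class used only for illustration (in practice this would be a `torch.Tensor`, which has no `.logits` attribute):

```python
# Minimal repro of the AttributeError the commit prevents.
# RawTensorOutput is a hypothetical stand-in for a model whose forward
# pass returns logits directly rather than an output object.

class RawTensorOutput:
    pass

outputs = RawTensorOutput()
try:
    _ = outputs.logits        # unconditional attribute access, as before the fix
    raised = False
except AttributeError:
    raised = True

assert raised  # custom models returning bare tensors hit this path
```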
Force-pushed d6590d7 to 8809467
Note: CodeRabbit detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Modified logits extraction logic in a policy worker to handle multiple model output shapes uniformly. Added explicit memory cleanup of `outputs` objects and standardized float32 conversion for logits across multiple methods.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (warnings)
Actionable comments posted: 1
🤖 Fix all issues with AI agents

In `nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py`, around lines 826-834: the later cleanup runs `del outputs, logits`, but `outputs` was already deleted earlier in the method, so the statement raises a NameError. Update the later cleanup to delete only `logits`, or guard the deletion with a check such as `if "outputs" in locals()` before deleting `outputs`.
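The reported bug and the guarded fix can be sketched as follows. Variable names follow the review comment; the worker's actual cleanup (which releases GPU memory through these references) may differ:

```python
# NameError repro and guarded cleanup (sketch under assumed variable names).

def buggy_cleanup():
    outputs, logits = object(), object()
    del outputs              # earlier cleanup after logits extraction
    del outputs, logits      # later cleanup: 'outputs' is already unbound

def fixed_cleanup():
    outputs, logits = object(), object()
    del outputs              # earlier cleanup
    if "outputs" in locals():    # guard: delete only if still bound
        del outputs
    del logits               # keep the logits deletion

try:
    buggy_cleanup()
    raised = False
except NameError:            # UnboundLocalError is a NameError subclass
    raised = True

assert raised
fixed_cleanup()              # completes without error
```

Dropping `outputs` from the second `del` entirely is the simpler fix; the `locals()` guard is the defensive alternative the comment suggests.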
♻️ Duplicate comments (1)

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

Line 1239: The comment is inconsistent with the code. It says "Get logprobs" but the code extracts logits; it should read "Get logits" to match the comments at lines 826 and 1082.

Suggested fix: change `# Get logprobs` to `# Get logits`.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: alexchiu <alexq@nvidia.com>
Force-pushed fe4a0ac to abe38e1
This commit fixes incomplete handling of custom MoE models (e.g., from AutoModel) that output logits directly as tensors instead of objects with a .logits attribute.
Changes:
- `get_logprobs()`: updated to use the complete logits extraction logic for all three cases (tensor output, missing `.logits`, has `.logits`)
- `score()`: added the missing isinstance check for tensor outputs and fixed the float32 conversion bug that incorrectly referenced `outputs.logits`
- `get_topk_logits()`: already had the correct logic, ensuring consistency
All methods now follow the same pattern as the train() method for extracting logits from model outputs, preventing AttributeError when using custom models.
Related to issue #1810
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage

# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit