Skip to content

Conversation

@zpqiu
Copy link
Contributor

@zpqiu zpqiu commented Jan 23, 2026

This commit fixes incomplete handling of custom MoE models (e.g., from AutoModel) that output logits directly as tensors instead of objects with a .logits attribute.

Changes:

  • get_logprobs(): Updated to use complete logits extraction logic with proper handling for all three cases (tensor output, missing .logits, has .logits)
  • score(): Added missing isinstance check for tensor outputs and fixed the float32 conversion bug that incorrectly referenced outputs.logits
  • get_topk_logits(): Already had the correct logic, ensuring consistency

All methods now follow the same pattern as the train() method for extracting logits from model outputs, preventing AttributeError when using custom models.

Related to issue #1810

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • Refactor
    • Improved model output handling for better compatibility across different architectures.
    • Optimized memory management for enhanced efficiency.
    • Ensured consistent data type handling across processing operations.

✏️ Tip: You can customize this high-level summary in your review settings.

@zpqiu zpqiu requested a review from Copilot January 23, 2026 06:29
@zpqiu zpqiu requested review from a team as code owners January 23, 2026 06:29
@github-actions
Copy link

⚠️ File Consistency Check

Check based on commit: d6590d7 (PR #1815 from fix/custom-model-logits-extraction)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR standardizes logits extraction logic across multiple inference methods to properly handle custom MoE models that output logits directly as tensors instead of objects with a .logits attribute.

Changes:

  • Updated get_logprobs(), score(), and get_topk_logits() to use consistent three-case logits extraction pattern (tensor output, missing .logits, has .logits)
  • Fixed float32 conversion bug in score() that incorrectly referenced outputs.logits instead of the extracted logits variable
  • Added del outputs statements to release memory after logits extraction

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

This commit fixes incomplete handling of custom MoE models (e.g., from AutoModel)
that output logits directly as tensors instead of objects with a .logits attribute.

Changes:
- get_logprobs(): Updated to use complete logits extraction logic with proper
  handling for all three cases (tensor output, missing .logits, has .logits)
- score(): Added missing isinstance check for tensor outputs and fixed the
  float32 conversion bug that incorrectly referenced outputs.logits
- get_topk_logits(): Already had the correct logic, ensuring consistency

All methods now follow the same pattern as the train() method for extracting
logits from model outputs, preventing AttributeError when using custom models.

Related to issue #1810

Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
@zpqiu zpqiu force-pushed the fix/custom-model-logits-extraction branch from d6590d7 to 8809467 Compare January 23, 2026 06:32
@github-actions
Copy link

⚠️ File Consistency Check

Check based on commit: 8809467 (PR #1815 from fix/custom-model-logits-extraction)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 23, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Modified logits extraction logic in a policy worker to handle multiple model output shapes uniformly. Added explicit memory cleanup of outputs objects and standardized float32 conversion for logits across multiple methods.

Changes

Cohort / File(s) Summary
Logits extraction and memory management
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Unified logits extraction across get_logprobs, score, and get_topk_logits methods to support: direct Tensor/DTensor outputs, computation from last_hidden_state via lm_head when logits attribute missing, and fallback to outputs.logits. Added explicit deletion of outputs objects post-extraction for memory efficiency. Replaced conditional float32 conversions with uniform logits.to(torch.float32) call.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Test Results For Major Changes ⚠️ Warning PR description lacks test results or testing information for significant changes to critical inference methods affecting logits extraction across multiple code paths. Add test results demonstrating changes work for custom and Hugging Face models, validate no regressions occur, and fix critical NameError at line 932 by removing outputs from deletion statement.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: unifying custom model logits extraction across inference methods, which directly aligns with the PR's core objective of fixing incomplete handling of custom models.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py`:
- Around line 826-834: The later cleanup currently attempts to "del outputs,
logits" but "outputs" was already deleted earlier in the code, causing a
NameError; update that later cleanup to only delete "logits" (or guard deletion
with a check like 'if "outputs" in locals()' before deleting) so you don't
attempt to delete the already-removed variable — target the later del statement
that references outputs and logits and remove outputs from it (or make it
conditional) while keeping logits deletion.
♻️ Duplicate comments (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

1239-1239: Comment is inconsistent with code.

The comment says "Get logprobs" but the code extracts logits. This should be "Get logits" to be consistent with the comments at lines 826 and 1082.

Suggested fix
-                    # Get logprobs
+                    # Get logits

@github-actions
Copy link

⚠️ File Consistency Check

Check based on commit: fe4a0ac (PR #1815 from fix/custom-model-logits-extraction)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: alexchiu <alexq@nvidia.com>
@zpqiu zpqiu force-pushed the fix/custom-model-logits-extraction branch from fe4a0ac to abe38e1 Compare January 23, 2026 06:56
@github-actions
Copy link

⚠️ File Consistency Check

Check based on commit: abe38e1 (PR #1815 from fix/custom-model-logits-extraction)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@zpqiu zpqiu added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Jan 23, 2026
@zpqiu zpqiu requested a review from terrykong January 24, 2026 00:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L2 Run doctests, unit tests, functional tests, and convergence tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants