
Conversation


@RayenTian RayenTian commented Jan 20, 2026

Summary

  • Merge LoRA adapter weights into base linear weights when exporting dtensor state and skip standalone LoRA adapter tensors.
  • Add LoRA configuration defaults to grpo_math_1B.yaml and introduce a Qwen3-8B LoRA recipe.
  • Expand LoRA coverage in functional and unit tests (vLLM generation + GRPO LoRA suites).

Issues

closes #1597

Changes

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
    • Merge LoRA weights into base weights during state export (see the sketch after this list).
    • Skip lora_A/lora_B tensors and release temporary tensors to reduce memory.
  • examples/configs/grpo_math_1B.yaml
    • Add LoRA config section with defaults/documentation.
  • examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
    • New LoRA recipe for Qwen3-8B.
  • tests/functional/*
    • Add GRPO LoRA functional tests (sync/async/non-colocated) and include in nightly.
  • tests/unit/models/generation/test_vllm_generation.py
    • Add LoRA config coverage and parameters in vLLM tests.
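
The merge-on-export idea is sketched below, assuming a PEFT-style layout where adapted linears carry `lora_A`/`lora_B` submodules and `scale = alpha / rank`; the helper name `iter_merged_state` is illustrative and not the actual code in `dtensor_policy_worker_v2.py`:

```python
import torch
import torch.nn as nn


def iter_merged_state(model: nn.Module, scale: float):
    """Yield (name, tensor) pairs with LoRA deltas folded into base weights.

    Hypothetical sketch: standalone lora_A/lora_B tensors are skipped, and each
    adapted base weight is exported as W + scale * (B @ A).
    """
    lora_modules = {
        name: mod
        for name, mod in model.named_modules()
        if hasattr(mod, "lora_A") and hasattr(mod, "lora_B")
    }
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name:
            continue  # adapter tensors are never exported on their own
        tensor = param.detach()
        owner = name.rsplit(".", 1)[0]
        if name.endswith("weight") and owner in lora_modules:
            mod = lora_modules[owner]
            # Fold the low-rank update into the base weight: (out, r) @ (r, in).
            delta = scale * (mod.lora_B.weight @ mod.lora_A.weight)
            tensor = tensor + delta
            del delta  # release the temporary promptly to limit peak memory
        yield name, tensor
```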

Testing

(screenshot of test results)

A comparison of the two implementation approaches: Split Weight and Merge Weight

Split Weight

PRs:

Design

Enable LoRA for both the vLLM and DTensor backends. Update the base model weights and LoRA weights at the first step, and refit only the LoRA weights in subsequent steps.

Both inference and training compute $xW^T + \mathrm{scale}\times x(BA)^T$.
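
As a sketch, this additive formulation corresponds to a forward pass along the lines below (a minimal illustration assuming an `nn.Linear`-based wrapper; the class and defaults are assumptions, not the actual NeMo RL implementation):

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative additive LoRA: y = x W^T + scale * x (B A)^T."""

    def __init__(self, base: nn.Linear, rank: int = 128, alpha: int = 128):
        super().__init__()
        self.base = base
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # adapter starts as a no-op
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base path plus scaled low-rank update, computed additively.
        return self.base(x) + self.scale * self.lora_B(self.lora_A(x))
```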

  • The code is more complex; additional processing is required during refit.
  • ❗ vLLM and AutoModel handle adding LoRA to the lm_head differently across models, so a patch is needed for vLLM to skip the lm_head (refer to https://nvidia.slack.com/archives/C0A1J3A7G85/p1768803141238469).
  • ❗ Performance degrades because vLLM generation is much slower with LoRA enabled.

Results:

  • Time

    Qwen3-0.6B, lora_dim=128, SeqLen=4096, 33 steps:

    | | Total Time (s)/Step | Generation Time (s)/Step | Train Time (s)/Step | Refit Time (s)/Step |
    | --- | --- | --- | --- | --- |
    | W/O LoRA | 55 | 37 | 10.2 | 0.4 |
    | LoRA Split Weight | 67 | 48 | 9.9 | 0.5 |
    • Enabling LoRA in vLLM leads to a significant performance drop.
    • Refit latency: even though this approach updates only the LoRA weights, we still see a slight increase in refit time.

Merge Weight

PR:

Design

LoRA is enabled only in the DTensor backend, and vLLM stays agnostic to LoRA. The base model weights and LoRA weights are merged before each refit.

The DTensor backend computes LoRA as $xW^T + \mathrm{scale}\times x(BA)^T$.

vLLM computes LoRA as $x W_{\text{vllm}}^T = x\left(W^T + \mathrm{scale}\times(BA)^T\right)$.

  • 👍 The code is simpler; no extra handling of the lm_head is needed across frameworks.
  • 👍 Superior performance.
  • ❗ An extra training-inference discrepancy is introduced, because $xW^T + \mathrm{scale}\times x(BA)^T$ and $x W_{\text{vllm}}^T = x\left(W^T + \mathrm{scale}\times(BA)^T\right)$ are not bitwise identical across hardware.
    • After several dozen training steps, train/gen_kl_error shows a slight mismatch. Notably, this divergence does not grow without bound; it stabilizes at a certain level.
    • This result is observable on both Qwen and Llama.
    • Our findings confirm that identical data and the NeMo RL DTensor backend do not guarantee identical logprobs: a non-zero KL divergence is observed between the merged approach (base and LoRA weights combined upfront) and the additive approach (base and LoRA outputs computed independently and summed).
    • However, gen_kl_error can be kept within 0.001, which we consider acceptable (a small reproduction sketch follows this list).
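
A tiny self-contained check (shapes, dtype, and scale chosen arbitrarily for illustration) reproduces the effect: the merged and additive paths are mathematically identical but agree only up to floating-point tolerance.

```python
import torch

torch.manual_seed(0)
out_f, in_f, rank, scale = 512, 512, 128, 1.0
x = torch.randn(4, in_f, dtype=torch.bfloat16)
W = torch.randn(out_f, in_f, dtype=torch.bfloat16)
A = torch.randn(rank, in_f, dtype=torch.bfloat16) * 0.02
B = torch.randn(out_f, rank, dtype=torch.bfloat16) * 0.02

additive = x @ W.T + scale * (x @ (B @ A).T)  # DTensor-style computation
merged = x @ (W + scale * (B @ A)).T          # vLLM-style computation

print(torch.equal(additive, merged))     # typically False: not bitwise equal
print((additive - merged).abs().max())   # small but non-zero difference
```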

Results

Qwen3-0.6B, lora_dim=128, SeqLen=4096, 33 steps:

| | Total Time (s)/Step | Generation Time (s)/Step | Train Time (s)/Step | Refit Time (s)/Step |
| --- | --- | --- | --- | --- |
| W/O LoRA | 55 | 37 | 10.2 | 0.4 |
| LoRA Split Weight | 67 | 48 | 9.9 | 0.5 |
| LoRA Merge Weight | 56.9 | 36.58 | 10.57 | 0.549 |

llama3.1-8B, lora_dim=128, SeqLen=1024, 33 steps:

| | Total Time (s)/Step | Generation Time (s)/Step | Train Time (s)/Step | Refit Time (s)/Step |
| --- | --- | --- | --- | --- |
| W/O LoRA | 28.82 | 10.4 | 8.55 | 0.54 |
| LoRA Merge Weight | 30.98 | 10.9 | 7.6 | 0.74 |

Summary by CodeRabbit

  • New Features

    • Added LoRA (Low-Rank Adaptation) support for policy training with configurable rank, scaling, and dropout options
    • New Qwen3 8B model configuration with LoRA enabled for GRPO experiments
    • Added validation to prevent incompatible LoRA configurations
  • Tests

    • Added comprehensive functional and unit tests for LoRA scenarios including async and non-colocated training configurations


@github-actions

⚠️ File Consistency Check

Check based on commit: 1ccb5be (PR #1797 from ruit/lora_merge_weight)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@RayenTian RayenTian added the CI:L1 Run doctests, unit tests, and functional tests label Jan 20, 2026
@RayenTian RayenTian removed the CI:L1 Run doctests, unit tests, and functional tests label Jan 20, 2026

@RayenTian RayenTian added the CI:L1 Run doctests, unit tests, and functional tests label Jan 21, 2026

@RayenTian RayenTian force-pushed the ruit/lora_merge_weight branch from 68263ea to 7db41f8 Compare January 26, 2026 06:52

@RayenTian RayenTian added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 26, 2026
@RayenTian RayenTian marked this pull request as ready for review January 28, 2026 02:09
@RayenTian RayenTian requested review from a team as code owners January 28, 2026 02:09
@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

This pull request adds comprehensive LoRA (Low-Rank Adaptation) support to the GRPO framework. Changes include new LoRA configuration blocks, runtime quantization compatibility checks, weight merging logic in the DTensor policy worker, patched LinearLoRA forward implementations, and extensive functional and unit tests covering standard, async, and non-colocated deployment scenarios.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **LoRA Configuration Blocks**<br>examples/configs/grpo_math_1B.yaml, examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml | Added `lora_cfg` configuration with parameters for rank, alpha, dropout, module targeting, and Triton kernel usage; Qwen3 8B base model configuration with LoRA enabled (dim=128, alpha=128). |
| **LoRA Runtime Validation**<br>nemo_rl/models/automodel/setup.py | Added a guard to prevent LoRA usage with quantized models by checking for bitsandbytes QuantState; raises an assertion if an incompatibility is detected. |
| **LoRA Weight Integration**<br>nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py | Introduced `_maybe_merge_lora_weight` helper for merging LinearLoRA weights into base tensors; modified `dtensor_params_generator` to skip LoRA-specific parameters and apply merged weights; extended `prepare_refit_info` to exclude LoRA weights from state-dict metadata; integrated the LoRA forward patch during initialization. |
| **LoRA Forward Patch**<br>nemo_rl/models/policy/workers/patches.py | Added `patched_lora_linear_forward` for LinearLoRA modules with configurable dropout timing; added `apply_lora_linear_forward_patch` utility to monkey-patch the forward method at runtime (see the sketch after this table). |
| **Functional Tests (GPU)**<br>tests/functional/grpo_automodel_lora.sh, tests/functional/grpo_automodel_lora_async.sh, tests/functional/grpo_automodel_lora_non_colocated.sh | Three new test scripts covering LoRA with standard synchronous, async vLLM engine, and non-colocated worker setups; each includes TensorBoard metrics validation (train/reward > 0.06 at step 3). |
| **Functional Test Registry**<br>tests/functional/L1_Functional_Tests_GPU.sh | Registered the three new LoRA test scripts in the functional test harness. |
| **vLLM Generation Unit Tests**<br>tests/unit/models/generation/test_vllm_generation.py | Extended test fixtures with LoRAConfig; added an `enable_lora` parameter to `run_hf_train_process` and test functions; wired LoRA configuration into the vLLM and DTensor paths; added a guard preventing LoRA with FP8 precision. |
| **Test Suite Integration**<br>tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh | New Qwen3 8B LoRA test script with metrics validation and checkpoint cleanup. |
| **Nightly Test Registry**<br>tests/test_suites/nightly.txt | Registered the Qwen3 8B LoRA test in the nightly suite under a new "lora" subsection. |
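
For intuition, a patched forward with configurable dropout timing might look like the sketch below; the attribute names (`dropout_position`, `lora_dropout_p`, `scale`) are assumptions for illustration, not the actual contents of `patches.py`:

```python
import types

import torch.nn.functional as F


def patched_lora_linear_forward(self, x):
    """Additive LoRA forward with configurable dropout timing (illustrative)."""
    position = getattr(self, "dropout_position", "pre")
    base_out = F.linear(x, self.weight, self.bias)
    # Optionally apply dropout to the adapter input ("pre") or output ("post").
    h = F.dropout(x, p=self.lora_dropout_p, training=self.training) if position == "pre" else x
    # Low-rank path: x -> A -> B, scaled by alpha / rank.
    lora_out = self.scale * F.linear(F.linear(h, self.lora_A.weight), self.lora_B.weight)
    if position == "post":
        lora_out = F.dropout(lora_out, p=self.lora_dropout_p, training=self.training)
    return base_out + lora_out


def apply_lora_linear_forward_patch(module):
    # Monkey-patch the bound forward method at runtime (illustrative).
    module.forward = types.MethodType(patched_lora_linear_forward, module)
```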

Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as Test Harness
    participant Setup as setup_model_and_optimizer
    participant Worker as dtensor_policy_worker_v2
    participant Patches as patches.py
    participant Model as LinearLoRA Module

    Test->>Setup: Initialize with lora_enabled=true
    Setup->>Setup: Check for quantized modules
    Setup-->>Setup: Assert no quantization + LoRA
    Setup->>Worker: Create worker with LoRA config
    Worker->>Patches: apply_lora_linear_forward_patch()
    Patches->>Model: Monkey-patch forward method
    Note over Worker: During inference/training
    Worker->>Worker: dtensor_params_generator iterates params
    Worker->>Worker: Skip lora_A.weight, lora_B.weight
    Worker->>Worker: _maybe_merge_lora_weight(fqn, tensor)
    Worker-->>Worker: Merge lora_B @ lora_A * scale
    Worker->>Model: Forward pass with merged weights
    Model->>Model: patched_lora_linear_forward(x)
    Model-->>Model: Apply dropout & linear transform
    Model-->>Worker: Output tensor
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • joyang-nv
  • yuki-97
  • terrykong
🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 61.54%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Test Results For Major Changes | ⚠️ Warning | The PR introduces a major LoRA feature with complex weight merging and forward-pass modifications, but explicitly states testing was not run and provides no test results or convergence validation. | Execute functional tests and document results in the PR description, including convergence validation and numerical output verification to demonstrate the feature works correctly. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding LoRA support in the DTensor GRPO workflow with weight merging, which aligns with the core implementation across multiple modified files. |
| Description Check | ✅ Passed | Check skipped; CodeRabbit's high-level summary is enabled. |


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

83-95: Stale docstring parameter.

The docstring mentions peft_config as a parameter, but it doesn't exist in the function signature. This appears to be leftover from earlier iterations.

📝 Suggested fix
```diff
 def dtensor_params_generator(
     model: nn.Module, target_dtype: torch.dtype
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format.

     Args:
         model: The model whose parameters to generate.
         target_dtype: The dtype to convert tensors to.
-        peft_config: Optional LoRA config for filtering which layers to merge.

     Yields:
         Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous.
     """
```
🤖 Fix all issues with AI agents
In `@nemo_rl/models/automodel/setup.py`:
- Around line 482-496: Change the broad exception and the assert: replace the
bare except Exception around "import bitsandbytes as bnb" with "except
ImportError" to narrowly catch import failures, and replace the "assert False,
'Quantized modules are not supported with LoRA'" inside the loop that checks
module.quant_state (within the lora_enabled block iterating model.modules())
with a raised exception such as "raise AssertionError('Quantized modules are not
supported with LoRA')" so the guard can't be stripped by optimized Python.
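
A minimal sketch of the suggested guard shape (illustrative only; the function name and wiring are assumptions, with the quantization check modeled on bitsandbytes' `quant_state` attribute):

```python
def assert_no_quantized_modules(model, lora_enabled: bool) -> None:
    """Raise if LoRA is enabled on a model containing quantized modules."""
    if not lora_enabled:
        return
    try:
        import bitsandbytes as bnb  # noqa: F401  (narrowly catch import failure only)
    except ImportError:
        return  # no bitsandbytes, so no bnb-quantized modules to worry about
    for module in model.modules():
        if getattr(module, "quant_state", None) is not None:
            # raise instead of `assert False` so python -O cannot strip the guard
            raise AssertionError("Quantized modules are not supported with LoRA")
```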

In `@tests/unit/models/generation/test_vllm_generation.py`:
- Around line 1005-1007: Move the LoRA+FP8 compatibility assertion so it runs
before the vLLM policy is instantiated: place the assert not (enable_lora and
vllm_precision == "fp8") check immediately before the code that creates/assigns
vllm_policy (the vllm_policy creation block) so the test validates configuration
(enable_lora and vllm_precision) first and avoids initializing vllm_policy when
the combination is invalid.
🧹 Nitpick comments (1)
tests/unit/models/generation/test_vllm_generation.py (1)

708-716: **Consider using or removing the unused enable_lora parameter.**

The enable_lora parameter is passed to run_hf_train_process but is not used within the function body. While it may be intentionally included for test identification or future use, it creates dead code that could confuse maintainers.

Either use the parameter for LoRA-specific behavior validation within the function, or prefix it with an underscore (_enable_lora) to indicate it's intentionally unused.

Contributor

@yuki-97 yuki-97 left a comment


thanks @RayenTian, overall LGTM!

I remember we previously saw some memory issues after adding LoRA in GRPO; how is it now?

Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>

Signed-off-by: ruit <ruit@nvidia.com>
@RayenTian RayenTian force-pushed the ruit/lora_merge_weight branch from a787336 to 4900dfc Compare January 28, 2026 08:05

@RayenTian
Contributor Author

> thanks @RayenTian, overall LGTM!
>
> I remember we previously saw some memory issues after adding LoRA in GRPO; how is it now?

I retested it on main ToT, and it also hits an OOM.

@RayenTian RayenTian added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 28, 2026


@RayenTian RayenTian added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 29, 2026
@terrykong terrykong enabled auto-merge (squash) January 29, 2026 08:20


@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 30, 2026

Labels

CI:L1 Run doctests, unit tests, and functional tests


Development

Successfully merging this pull request may close these issues.

LoRA DTensor GRPO

4 participants