feat: Support lora in dtensor grpo workflow by merging weight #1797
base: main
Conversation
Force-pushed 68263ea to 7db41f8
📝 Walkthrough
This pull request adds comprehensive LoRA (Low-Rank Adaptation) support to the GRPO framework. Changes include new LoRA configuration blocks, runtime quantization compatibility checks, weight merging logic in the DTensor policy worker, patched LinearLoRA forward implementations, and extensive functional and unit tests covering standard, async, and non-colocated deployment scenarios.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as Test Harness
    participant Setup as setup_model_and_optimizer
    participant Worker as dtensor_policy_worker_v2
    participant Patches as patches.py
    participant Model as LinearLoRA Module
    Test->>Setup: Initialize with lora_enabled=true
    Setup->>Setup: Check for quantized modules
    Setup-->>Setup: Assert no quantization + LoRA
    Setup->>Worker: Create worker with LoRA config
    Worker->>Patches: apply_lora_linear_forward_patch()
    Patches->>Model: Monkey-patch forward method
    Note over Worker: During inference/training
    Worker->>Worker: dtensor_params_generator iterates params
    Worker->>Worker: Skip lora_A.weight, lora_B.weight
    Worker->>Worker: _maybe_merge_lora_weight(fqn, tensor)
    Worker-->>Worker: Merge lora_B @ lora_A * scale
    Worker->>Model: Forward pass with merged weights
    Model->>Model: patched_lora_linear_forward(x)
    Model-->>Model: Apply dropout & linear transform
    Model-->>Worker: Output tensor
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
83-95: Stale docstring parameter. The docstring mentions `peft_config` as a parameter, but it doesn't exist in the function signature. This appears to be leftover from earlier iterations.
📝 Suggested fix
```diff
 def dtensor_params_generator(
     model: nn.Module, target_dtype: torch.dtype
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format.

     Args:
         model: The model whose parameters to generate.
         target_dtype: The dtype to convert tensors to.
-        peft_config: Optional LoRA config for filtering which layers to merge.

     Yields:
         Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous.
     """
```
🤖 Fix all issues with AI agents
In `@nemo_rl/models/automodel/setup.py`:
- Around line 482-496: Change the broad exception and the assert: replace the
bare except Exception around "import bitsandbytes as bnb" with "except
ImportError" to narrowly catch import failures, and replace the "assert False,
'Quantized modules are not supported with LoRA'" inside the loop that checks
module.quant_state (within the lora_enabled block iterating model.modules())
with a raised exception such as "raise AssertionError('Quantized modules are not
supported with LoRA')" so the guard can't be stripped by optimized Python.
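A minimal sketch of the hardened guard this suggests; the function name and attribute access are illustrative, not the actual `setup.py` code:

```python
import torch.nn as nn


def assert_no_quantized_lora(model: nn.Module, lora_enabled: bool) -> None:
    """Raise if LoRA is enabled on a model containing bitsandbytes-quantized modules."""
    try:
        import bitsandbytes as bnb  # noqa: F401
    except ImportError:
        return  # bitsandbytes not installed, so nothing can be quantized with it
    if not lora_enabled:
        return
    for module in model.modules():
        # bitsandbytes quantized layers expose a `quant_state` attribute.
        if getattr(module, "quant_state", None) is not None:
            # raise instead of `assert False` so the check survives `python -O`
            raise AssertionError("Quantized modules are not supported with LoRA")
```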
In `@tests/unit/models/generation/test_vllm_generation.py`:
- Around line 1005-1007: Move the LoRA+FP8 compatibility assertion so it runs
before the vLLM policy is instantiated: place the assert not (enable_lora and
vllm_precision == "fp8") check immediately before the code that creates/assigns
vllm_policy (the vllm_policy creation block) so the test validates configuration
(enable_lora and vllm_precision) first and avoids initializing vllm_policy when
the combination is invalid.
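A sketch of the suggested ordering; the function and parameter names are hypothetical stand-ins rather than the actual test code:

```python
from typing import Any, Callable


def run_generation_test(
    enable_lora: bool, vllm_precision: str, build_policy: Callable[[], Any]
) -> Any:
    # Validate the configuration before any expensive initialization.
    assert not (enable_lora and vllm_precision == "fp8"), (
        "LoRA is not supported together with FP8 generation"
    )
    # Only build the vLLM policy once the combination is known to be valid.
    return build_policy()
```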
🧹 Nitpick comments (1)
tests/unit/models/generation/test_vllm_generation.py (1)
708-716: **Consider using or removing the unused `enable_lora` parameter.** The `enable_lora` parameter is passed to `run_hf_train_process` but is not used within the function body. While it may be intentionally included for test identification or future use, it creates dead code that could confuse maintainers. Either use the parameter for LoRA-specific behavior validation within the function, or prefix it with an underscore (`_enable_lora`) to indicate it's intentionally unused.
yuki-97
left a comment
Thanks @RayenTian, overall LGTM!
I remember we previously had some memory issues after adding LoRA in GRPO? How is it now?
Force-pushed a787336 to 4900dfc
I retested it on main TOT, and it also hit an OOM.
Summary
Support LoRA in the DTensor GRPO workflow by merging weights; enable it in `grpo_math_1B.yaml` and introduce a Qwen3-8B LoRA recipe.
Issues
closes #1597
Changes
- `nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py`: skip `lora_A`/`lora_B` tensors and release temporary tensors to reduce memory.
- `examples/configs/grpo_math_1B.yaml`
- `examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml`
- `tests/functional/*`
- `tests/unit/models/generation/test_vllm_generation.py`

Testing
A comparison of the two implementation approaches: Split Weight and Merge Weight.
Split Weight
PRs:
Design
Enable LoRA for both the vLLM and DTensor backends. Update the base model weights and LoRA weights at the first step, and only refit the LoRA weights in subsequent steps.
Both inference and training compute $xW^T + scale \times x(BA)^T$.
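For reference, a minimal sketch of such a split LoRA linear; the module layout and dropout placement are illustrative, not NeMo Automodel's actual LinearLoRA:

```python
import torch
import torch.nn as nn


class LinearLoRASketch(nn.Module):
    def __init__(self, d_in: int, d_out: int, rank: int, alpha: float, dropout: float = 0.0):
        super().__init__()
        self.base = nn.Linear(d_in, d_out, bias=False)    # frozen W
        self.lora_A = nn.Linear(d_in, rank, bias=False)   # A: d_in -> rank
        self.lora_B = nn.Linear(rank, d_out, bias=False)  # B: rank -> d_out
        self.scale = alpha / rank
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T + scale * x (B A)^T, with dropout only on the LoRA path.
        return self.base(x) + self.scale * self.lora_B(self.lora_A(self.dropout(x)))
```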
vLLM handles `lm_head` differently for various models, so a patch is needed for vLLM to skip the `lm_head` (refer to https://nvidia.slack.com/archives/C0A1J3A7G85/p1768803141238469).
Results:
Time: lora_dim(128), SeqLen(4096), step(33)
Merge Weight
PR:
Design
Only enable LoRA in the DTensor backend; vLLM is agnostic to LoRA. Merge the base model weights and LoRA weights before each refit.
DTensor computes LoRA via $xW^T + scale \times x(BA)^T$
vLLM computes LoRA via $x W_{vllm}^T = x(W^T + scale \times (BA)^T)$
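A quick numerical check of this equivalence (shapes, scale, and tolerances chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
d_in, d_out, rank, scale = 64, 32, 8, 2.0
x = torch.randn(4, d_in)
W = torch.randn(d_out, d_in)
A = torch.randn(rank, d_in)
B = torch.randn(d_out, rank)

split = x @ W.T + scale * (x @ A.T) @ B.T   # DTensor-side LoRA compute
merged = x @ (W + scale * (B @ A)).T        # vLLM sees only the merged weight
assert torch.allclose(split, merged, atol=1e-4, rtol=1e-4)
```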
No special `lm_head` handling across different frameworks. `train/gen_kl_error` exhibits a slight mismatch; notably, this divergence does not escalate without bound but rather stabilizes at a certain level. `gen_kl_error` can be controlled to within 0.001, so we believe this is acceptable.
Result
lora_dim(128), SeqLen(4096), step(33)
lora_dim(128), SeqLen(1024), step(33)
Summary by CodeRabbit
New Features
Tests