feat: support stateless group and decouple vLLM in train backend #1842
Conversation
Signed-off-by: shuyixiong <219646547+shuyixiong@users.noreply.github.com>
Force-pushed from e5fdfd1 to 76a01cf (Compare)
@shuyixiong thanks so much for supporting this! I'll take care of the rest.
Signed-off-by: Yuki Huang <yukih@nvidia.com>
📝 Walkthrough

This change introduces a new `StatelessProcessGroup` that coordinates ranks through a TCPStore and initializes an NCCL communicator directly, removing the train backend's dependency on vLLM.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant R0 as Rank 0
    participant R1 as Other Ranks
    participant TCPStore as TCPStore
    participant NCCL as NCCL Communicator
    R0->>TCPStore: Generate unique_id
    R0->>TCPStore: Store unique_id
    par
        R0->>NCCL: Init with unique_id
        R1->>TCPStore: Retrieve unique_id
        R1->>NCCL: Init with unique_id
    end
    R0->>NCCL: Warmup broadcast (verify connectivity)
    R1->>NCCL: Receive warmup broadcast
    NCCL->>NCCL: Synchronize all ranks
```
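A minimal Python sketch of the flow above, assuming the unique id is exchanged through a `torch.distributed.TCPStore`. The NCCL binding here is `cupy.cuda.nccl`, used only as a stand-in (the PR reportedly uses a different, newer binding whose exact API isn't shown here); `UNIQUE_ID_KEY` and the helper name are hypothetical.

```python
import pickle

import cupy
from cupy.cuda import nccl
from torch.distributed import TCPStore

UNIQUE_ID_KEY = "nccl_unique_id"  # hypothetical key name


def init_nccl_via_tcpstore(store: TCPStore, rank: int, world_size: int):
    """Exchange the NCCL unique id over a TCPStore, then warm up the communicator."""
    if rank == 0:
        unique_id = nccl.get_unique_id()                    # rank 0 generates the id
        store.set(UNIQUE_ID_KEY, pickle.dumps(unique_id))   # publish it for other ranks
    else:
        unique_id = pickle.loads(store.get(UNIQUE_ID_KEY))  # blocks until rank 0 sets it
    comm = nccl.NcclCommunicator(world_size, unique_id, rank)

    # Warmup broadcast from rank 0 to verify connectivity across all ranks.
    data = cupy.ones(1, dtype=cupy.float32) if rank == 0 else cupy.zeros(1, dtype=cupy.float32)
    comm.bcast(data.data.ptr, 1, nccl.NCCL_FLOAT32, 0, cupy.cuda.Stream.null.ptr)
    cupy.cuda.Stream.null.synchronize()
    if not bool((data == 1).all()):
        raise RuntimeError(f"NCCL warmup broadcast failed on rank {rank}")
    return comm
```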
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@nemo_rl/distributed/stateless_process_group.py`:
- Around lines 35-63: The warmup verification in `init_nccl_communicator` uses `assert`, which can be disabled. Replace it with an explicit check that raises a clear exception (e.g., `RuntimeError`) if the broadcast didn't produce the expected result: after `torch.cuda.current_stream().synchronize()`, verify `torch.allclose(data, torch.ones(1, device=device))` and, if false, raise an error that mentions the NCCL communicator initialization failure, the rank, and the `UNIQUE_ID_KEY`/`unique_id` context, so the failure causes a hard stop with useful logs rather than being skipped under `-O`.
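A rough sketch of the suggested replacement, using only the calls named in the comment above; the helper name and its parameters are hypothetical:

```python
import torch


def _verify_warmup_broadcast(data: torch.Tensor, device: torch.device,
                             rank: int, unique_id_key: str) -> None:
    # Wait for the warmup broadcast issued on the current CUDA stream to complete.
    torch.cuda.current_stream().synchronize()
    # Explicit check instead of `assert`, so the failure still raises (with
    # useful context) when Python runs with -O.
    if not torch.allclose(data, torch.ones(1, device=device)):
        raise RuntimeError(
            f"NCCL communicator initialization failed on rank {rank}: warmup "
            f"broadcast did not yield the expected tensor (key={unique_id_key})."
        )
```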
🧹 Nitpick comments (1)
nemo_rl/distributed/stateless_process_group.py (1)
22-34: Add class and method docstrings for public API. This class is part of a public interface used by other modules. Adding Google-style docstrings would improve maintainability and enable proper Sphinx documentation generation.
📝 Suggested docstring addition
```diff
 class StatelessProcessGroup:
+    """Stateless NCCL-based distributed communication group.
+
+    Uses TCPStore for rank coordination and NCCL unique ID exchange,
+    enabling distributed collective operations without persistent
+    process group state.
+
+    Attributes:
+        master_address: Host address for TCPStore coordination.
+        port: Port number for TCPStore.
+        rank: This process's rank in the group.
+        world_size: Total number of processes in the group.
+        tcp_store: TCPStore instance for coordination.
+        nccl_communicator: NCCL communicator (initialized via init_nccl_communicator).
+    """
+
     def __init__(self, master_address: str, port: int, rank: int, world_size: int):
+        """Initialize the stateless process group.
+
+        Args:
+            master_address: Host address for the TCPStore.
+            port: Port number for the TCPStore.
+            rank: Rank of this process (0-indexed).
+            world_size: Total number of processes in the group.
+        """
         self.master_address = master_address
```
terrykong left a comment
thanks @shuyixiong for cleaning this up! this will definitely help us maintain other backends more easily by decoupling vllm from the training backend!
could you update the megatron TP16 logprob error plot by clipping it? i just wanted to confirm that it's not steadily increasing
Hi @shuyixiong, I have a quick question. Is this change using nccl4py, which was developed recently?
@terrykong added. Both main and this PR have high logprob error in dsv3 TP32; this change doesn't cause the logprob error to rise. and from
Yes, it is using this.
Closes #501
As the title says: add stateless process group support so that we can remove the dependency on vLLM in the train backend.
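A rough usage sketch; the import path, constructor arguments, and `init_nccl_communicator` call are assumed from the review comments above, and the address/port values are placeholders:

```python
import os

from nemo_rl.distributed.stateless_process_group import StatelessProcessGroup

# Rank/world size are assumed to come from the launcher's environment.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

group = StatelessProcessGroup(
    master_address="10.0.0.1",  # assumed coordinator address
    port=29500,                 # assumed free port on the coordinator
    rank=rank,
    world_size=world_size,
)
group.init_nccl_communicator()  # exchange the NCCL unique id and run the warmup broadcast
```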
Test Result
Tested on llama3.1-8b-instruct and dsv3, and only on the non-colocated case, since the change only affects that path.
Summary by CodeRabbit
New Features
Chores