[Feature] Optimizing the communication of FSDP through NVLink SHARP and IBGDA #1415
Conversation
Pull request overview
This PR introduces an optimization feature for FSDP (Fully Sharded Data Parallel) training by integrating custom communication operations through NVLink SHARP and IBGDA. The implementation provides optimized all-gather and reduce-scatter collectives that can be enabled via environment variables.
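As a quick orientation, a minimal sketch of how such env-var gating typically looks (the variable names appear in this PR's error messages; treating `"1"` as the enabling value is an assumption for illustration):

```python
import os

# Feature toggles for the custom collectives; the names come from this PR,
# the enabling value "1" is assumed here for illustration only.
USE_CUSTOM_AG = os.getenv("XTUNER_USE_CUSTOM_AG_IN_FSDP", "0") == "1"
USE_CUSTOM_RS = os.getenv("XTUNER_USE_CUSTOM_RS_IN_FSDP", "0") == "1"

if USE_CUSTOM_AG or USE_CUSTOM_RS:
    # Only in this case are the NVLink SHARP / IBGDA collectives patched
    # into FSDP's all-gather / reduce-scatter paths.
    ...
```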
Key Changes
- Added custom communication library integration with `ib_wrapper` for optimized FSDP operations
- Implemented a buffer management system with n-buffering support for concurrent operations
- Created manager classes to handle all-gather and reduce-scatter operations with double buffering
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 28 comments.
| File | Description |
|---|---|
| `xtuner/v1/train/trainer.py` | Adds import and initialization of the custom communication library with buffer setup |
| `xtuner/v1/patch/torch_fsdp_comm.py` | New file implementing custom FSDP collectives with buffer managers, symmetric memory allocation, and PyTorch function patching |
| `xtuner/v1/patch/__init__.py` | Exports the new patch function for FSDP communication optimization |
| raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is set but ib_wrapper is not available.") | ||
|
|
||
| if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")): | ||
| raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is only supported in PyTorch 2.6 and 2.8.") |
Copilot AI · Jan 7, 2026
The error message uses shell-style brace expansion syntax '{AG,RS}' which may be unclear to users. Consider using a more explicit message like 'XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8.' to improve clarity.
| raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is set but ib_wrapper is not available.") | |
| if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")): | |
| raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is only supported in PyTorch 2.6 and 2.8.") | |
| raise ImportError( | |
| "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is set but ib_wrapper is not available." | |
| ) | |
| if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")): | |
| raise ImportError( | |
| "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8." | |
| ) |
```python
return torch.empty((size,), dtype=dtype, device=device)


lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901
```
Copilot AI · Jan 7, 2026
The torch.library.Library call uses a deprecated API pattern. The comment 'noqa: TOR901' suggests this is a known issue, but the 'FRAGMENT' argument may not be the correct descriptor for the library definition. Verify that this is the intended usage or consider updating to the current PyTorch library API if available.
Suggested change:

```python
lib = torch.library.Library("fsdp", "DEF")
```
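For reference, my reading of the two library kinds involved here (an illustrative sketch with a made-up namespace, not code from this PR): "DEF" creates a new operator namespace, while "FRAGMENT" adds operator definitions to a namespace that already exists, which is presumably why the PR uses it on the existing "fsdp" namespace.

```python
import torch

# "DEF" creates a brand-new operator namespace owned by this Library object.
lib = torch.library.Library("my_demo_ns", "DEF")
lib.define("my_op(Tensor t) -> Tensor")

# "FRAGMENT" attaches additional operator definitions to a namespace that
# already exists (analogous to how this PR extends the "fsdp" namespace).
frag = torch.library.Library("my_demo_ns", "FRAGMENT")
frag.define("my_other_op(Tensor t) -> Tensor")
```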
```python
    group=reduce_scatter_group,
    from_process_group=allocate_memory_from_process_group,
)
# reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))
```
Copilot AI · Jan 7, 2026
The commented-out code should be removed rather than left in the codebase. This improves maintainability and reduces confusion about which implementation is active.
Suggested change: delete the commented-out line above.
```python
# self.rdc_scale: dict[int, torch.Tensor] = {}
self.copy_event_prev: torch.Event | None = None
self.copy_event: torch.Event | None = None
self.select_sm = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))
```
Copilot AI · Jan 7, 2026
The variable name 'select_sm' is ambiguous and unclear. Consider renaming it to something more descriptive like 'select_streaming_multiprocessor' or 'use_sm_selection' to improve code readability.
Suggested change:

```python
self.select_streaming_multiprocessor = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))
```
```python
if (USE_CUSTOM_AG or USE_CUSTOM_RS) and world_size == dist.get_world_size():
    recv_bytes = all_gather_input_numel * world_size * all_gather_inputs[0].element_size()
    send_bytes = recv_bytes // world_size
    recv_bytes_aligned = (send_bytes + 127) // 128 * 128 * world_size
```
Copilot AI · Jan 7, 2026
Magic number 128 is used for alignment calculations without explanation. Consider defining it as a named constant (e.g., MEMORY_ALIGNMENT = 128) at the module level to improve code maintainability and make the purpose clear.
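A minimal sketch of that refactor (MEMORY_ALIGNMENT is the constant name proposed in the comment, not an identifier from the PR; the example values are for illustration only):

```python
MEMORY_ALIGNMENT = 128  # byte alignment assumed by the symmetric buffers

# Example values for illustration only.
world_size = 8
send_bytes = 1000

# Equivalent to (send_bytes + 127) // 128 * 128 * world_size in the PR:
# round each rank's send size up to the alignment before scaling by world_size.
recv_bytes_aligned = (send_bytes + MEMORY_ALIGNMENT - 1) // MEMORY_ALIGNMENT * MEMORY_ALIGNMENT * world_size
```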
| """Initialize the symmetric buffer manager with n buffering in | ||
| contiguous memory. | ||
| Args: | ||
| default_size (int): Default buffer size in bytes | ||
| alignment (int): Memory alignment requirement for the buffer | ||
| num_buffers (int): Number of buffers for n-buffering |
Copilot AI · Jan 7, 2026
The documentation states "Implements n buffering for concurrent operations with contiguous memory" but the actual parameter is named 'num_buffers' with a default of 3. However, the class is initialized with NUM_AG_BUFFERS (2 for AG) and NUM_RS_BUFFERS (1 for RS), which are not 3. The default value in the docstring should match the actual usage or be removed to avoid confusion.
| """Initialize the symmetric buffer manager with n buffering in | |
| contiguous memory. | |
| Args: | |
| default_size (int): Default buffer size in bytes | |
| alignment (int): Memory alignment requirement for the buffer | |
| num_buffers (int): Number of buffers for n-buffering | |
| """Initialize the symmetric buffer manager with n-buffering in | |
| contiguous memory. | |
| Args: | |
| default_size (int): Default buffer size in bytes. | |
| alignment (int): Memory alignment requirement for the buffer. | |
| num_buffers (int): Number of buffers for n-buffering. The actual | |
| value is provided by callers (for example, ``NUM_AG_BUFFERS`` | |
| or ``NUM_RS_BUFFERS``) and may vary depending on usage. |
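For readers unfamiliar with the pattern, a toy sketch of what n-buffering means here (illustrative only; the class and names below are not taken from the PR): with N buffers cycled round-robin, staging data into buffer (i+1) % N can overlap the collective still running on buffer i.

```python
class RoundRobinBuffers:
    """Toy illustration of n-buffering: cycle through N pre-allocated buffers
    so staging into the next buffer can overlap work on the current one.
    Not the PR's implementation."""

    def __init__(self, num_buffers: int, size: int):
        self.buffers = [bytearray(size) for _ in range(num_buffers)]
        self.idx = 0

    def next_buffer(self) -> bytearray:
        buf = self.buffers[self.idx]
        self.idx = (self.idx + 1) % len(self.buffers)
        return buf


# num_buffers=2 is classic double buffering (as used for all-gather here);
# num_buffers=1 degenerates to a single reusable buffer (reduce-scatter).
ring = RoundRobinBuffers(num_buffers=2, size=1024)
```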
```python
)


def patch_fsdp_agrs() -> None:
```
Copilot AI · Jan 7, 2026
The function name 'patch_fsdp_agrs' contains a typo. It should be 'patch_fsdp_args' (with 'args' instead of 'agrs').
Suggested change:

```python
def patch_fsdp_args() -> None:
```
```python
patch_fsdp_agrs()
```
Copilot AI · Jan 7, 2026
The patch_fsdp_agrs function is called unconditionally during Trainer initialization, but it should only be called when custom communication is actually enabled. Currently, the function checks the environment variables internally, but calling it unconditionally may cause unnecessary module imports and overhead. Consider moving the call inside the conditional block that checks XTUNER_ENABLE_CUSTOM_COMMUNICATION at line 638.
Suggested change:

```python
if os.getenv("XTUNER_ENABLE_CUSTOM_COMMUNICATION"):
    patch_fsdp_agrs()
```
```python
    reduce_scatter_input_aligned = reduce_scatter_input
else:
    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
    reduce_scatter_input_aligned = reduce_scatter_input
```
Copilot AI · Jan 7, 2026
Variable reduce_scatter_input_aligned is not used.
Suggested change (drop the unused alias):

```python
else:
    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
```
```python
# reduce_scatter_input = torch.empty(
#     (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
# )
reduce_scatter_input_aligned = reduce_scatter_input
```
Copilot AI · Jan 7, 2026
Variable reduce_scatter_input_aligned is not used.
Step 1 Install nvshmem
One can install nvshmem==3.4.5 by following this.
Step 2 Install Optimized Communication Operator Library
Step 3 Train with Optimized FSDP All-Gather and Reduce-Scatter
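For example, a minimal way to turn the optimized collectives on before constructing the trainer (the variable names come from this PR; treating "1" as the enabling value and setting them from Python rather than the launcher are assumptions):

```python
import os

# Hypothetical enabling values; the variable names appear in this PR, but the
# exact values the patch expects should be checked against torch_fsdp_comm.py.
os.environ["XTUNER_USE_CUSTOM_AG_IN_FSDP"] = "1"   # optimized all-gather
os.environ["XTUNER_USE_CUSTOM_RS_IN_FSDP"] = "1"   # optimized reduce-scatter
os.environ["SELECT_COMM_SM_IN_FSDP"] = "0"         # optional SM-selection knob

# ...then build and run the Trainer as usual; the FSDP collectives are patched
# during trainer initialization.
```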
The timeline will look like this: