Conversation

HIT-cwh (Collaborator) commented Jan 7, 2026

Step 1 Install nvshmem

One can install nvshmem==3.4.5 following this.

Step 2 Install Optimized Communication Operator Library

git clone xxx  # the ib_tep repo
cd ib_tep
export TORCH_CUDA_ARCH_LIST=9.0
export NVSHMEM_DIR=/path/to/installed/nvshmem

# Development
python setup.py build
# Installation
python setup.py install
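
After installation, a quick sanity check can confirm the setup. This is a minimal sketch: the module name ib_wrapper and the PyTorch version constraint are taken from this PR's error messages, and the check itself is not part of the PR.

import os

import torch

# NVSHMEM_DIR should point at the nvshmem install from Step 1.
assert os.path.isdir(os.environ.get("NVSHMEM_DIR", "")), "NVSHMEM_DIR is unset or invalid"

# The patch in this PR only supports PyTorch 2.6 and 2.8.
assert torch.__version__.startswith(("2.6", "2.8")), f"unsupported torch {torch.__version__}"

import ib_wrapper  # the optimized communication operator library built above

print("ib_wrapper loaded from:", ib_wrapper.__file__)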

Step 3 Train with Optimized FSDP All-Gather and Reduce-Scatter

export XTUNER_ENABLE_CUSTOM_COMMUNICATION=1
export SYMM_BUF_SIZE=0 # Auto-resize the symmetric buffer at runtime if it is smaller than required
export XTUNER_USE_CUSTOM_AG_IN_FSDP=1 # Use the custom all-gather
export XTUNER_USE_CUSTOM_RS_IN_FSDP=1 # Use the custom reduce-scatter
export GRID_IB_AG=4
export GRID_IB_RS=4
export XTUNER_SELECT_COMM_SM_IN_FSDP=0

export XTUNER_SM_MARGIN=8
export DISTRIBUTED_COMMUNICATION_SM=8

torchrun --nproc-per-node=8  \
    xtuner/v1/train/cli/sft.py \
    --config /path/to/your/config.py

The timeline will be like this:

[timeline screenshot]
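
For reference, a minimal sketch of how these flags are expected to gate the patch; the flag and function names come from the snippets in this PR, but the exact wiring inside xtuner may differ:

import os

USE_CUSTOM_AG = os.getenv("XTUNER_USE_CUSTOM_AG_IN_FSDP", "0") == "1"
USE_CUSTOM_RS = os.getenv("XTUNER_USE_CUSTOM_RS_IN_FSDP", "0") == "1"

if os.getenv("XTUNER_ENABLE_CUSTOM_COMMUNICATION") and (USE_CUSTOM_AG or USE_CUSTOM_RS):
    # Swap FSDP's all-gather / reduce-scatter collectives for the
    # optimized ib_wrapper implementations.
    from xtuner.v1.patch import patch_fsdp_agrs

    patch_fsdp_agrs()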

Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces an optimization feature for FSDP (Fully Sharded Data Parallel) training by integrating custom communication operations through NVLink SHARP and IBGDA. The implementation provides optimized all-gather and reduce-scatter collectives that can be enabled via environment variables.

Key Changes

  • Added custom communication library integration with ib_wrapper for optimized FSDP operations
  • Implemented buffer management system with n-buffering support for concurrent operations (see the sketch after this list)
  • Created manager classes to handle all-gather and reduce-scatter operations with double buffering
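
As an illustration of the n-buffering idea, here is a hypothetical round-robin buffer manager. The class and method names are invented for this sketch; the real manager in torch_fsdp_comm.py allocates symmetric (nvshmem) memory rather than the plain CUDA tensors used here.

import torch


class RingBufferManager:
    """Rotate over num_buffers preallocated slabs so the next collective
    can stage its data while the previous one is still in flight."""

    def __init__(self, size_bytes: int, num_buffers: int = 2, alignment: int = 128):
        # Round each slab up to the alignment boundary.
        aligned = (size_bytes + alignment - 1) // alignment * alignment
        self.slabs = [
            torch.empty(aligned, dtype=torch.uint8, device="cuda")
            for _ in range(num_buffers)
        ]
        self.idx = 0

    def next_slab(self) -> torch.Tensor:
        slab = self.slabs[self.idx]
        self.idx = (self.idx + 1) % len(self.slabs)
        return slab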

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 28 comments.

  • xtuner/v1/train/trainer.py: adds import and initialization of the custom communication library with buffer setup.
  • xtuner/v1/patch/torch_fsdp_comm.py: new file implementing custom FSDP collectives with buffer managers, symmetric memory allocation, and PyTorch function patching.
  • xtuner/v1/patch/__init__.py: exports the new patch function for FSDP communication optimization.


Comment on lines +1083 to +1086
raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is set but ib_wrapper is not available.")

if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")):
raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is only supported in PyTorch 2.6 and 2.8.")

Copilot AI Jan 7, 2026

The error message uses shell-style brace expansion syntax '{AG,RS}' which may be unclear to users. Consider using a more explicit message like 'XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8.' to improve clarity.

Suggested change
-    raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is set but ib_wrapper is not available.")
-if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")):
-    raise ImportError("XTUNER_USE_CUSTOM_{AG,RS}_IN_FSDP is only supported in PyTorch 2.6 and 2.8.")
+    raise ImportError(
+        "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is set but ib_wrapper is not available."
+    )
+if not (torch.__version__.startswith("2.6") or torch.__version__.startswith("2.8")):
+    raise ImportError(
+        "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8."
+    )
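
As a side note, str.startswith accepts a tuple of prefixes, so the same check can be written more compactly:

if not torch.__version__.startswith(("2.6", "2.8")):
    raise ImportError(
        "XTUNER_USE_CUSTOM_AG_IN_FSDP or XTUNER_USE_CUSTOM_RS_IN_FSDP is only supported in PyTorch 2.6 and 2.8."
    )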

    return torch.empty((size,), dtype=dtype, device=device)


lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901

Copilot AI Jan 7, 2026

The torch.library.Library call uses a deprecated API pattern. The comment 'noqa: TOR901' suggests this is a known issue, but the 'FRAGMENT' argument may not be the correct descriptor for the library definition. Verify that this is the intended usage or consider updating to the current PyTorch library API if available.

Suggested change
-lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901
+lib = torch.library.Library("fsdp", "DEF")

    group=reduce_scatter_group,
    from_process_group=allocate_memory_from_process_group,
)
# reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))

Copilot AI Jan 7, 2026

The commented-out code should be removed rather than left in the codebase. This improves maintainability and reduces confusion about which implementation is active.

Suggested change
-# reduce_output = reduce_scatter_input.new_empty((reduce_scatter_output_numel,))

# self.rdc_scale: dict[int, torch.Tensor] = {}
self.copy_event_prev: torch.Event | None = None
self.copy_event: torch.Event | None = None
self.select_sm = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))

Copilot AI Jan 7, 2026

The variable name 'select_sm' is ambiguous and unclear. Consider renaming it to something more descriptive like 'select_streaming_multiprocessor' or 'use_sm_selection' to improve code readability.

Suggested change
-self.select_sm = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))
+self.select_streaming_multiprocessor = int(os.getenv("SELECT_COMM_SM_IN_FSDP", 0))

if (USE_CUSTOM_AG or USE_CUSTOM_RS) and world_size == dist.get_world_size():
    recv_bytes = all_gather_input_numel * world_size * all_gather_inputs[0].element_size()
    send_bytes = recv_bytes // world_size
    recv_bytes_aligned = (send_bytes + 127) // 128 * 128 * world_size

Copilot AI Jan 7, 2026

Magic number 128 is used for alignment calculations without explanation. Consider defining it as a named constant (e.g., MEMORY_ALIGNMENT = 128) at the module level to improve code maintainability and make the purpose clear.
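
A minimal sketch of the suggested refactor; the constant and helper names are illustrative, not from the PR:

MEMORY_ALIGNMENT = 128  # byte alignment expected by the symmetric memory allocator


def align_up(nbytes: int, alignment: int = MEMORY_ALIGNMENT) -> int:
    """Round nbytes up to the next multiple of alignment."""
    return (nbytes + alignment - 1) // alignment * alignment


# Equivalent to the original (send_bytes + 127) // 128 * 128 * world_size:
# recv_bytes_aligned = align_up(send_bytes) * world_size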

Comment on lines +61 to +67
"""Initialize the symmetric buffer manager with n buffering in
contiguous memory.
Args:
default_size (int): Default buffer size in bytes
alignment (int): Memory alignment requirement for the buffer
num_buffers (int): Number of buffers for n-buffering

Copilot AI Jan 7, 2026

The documentation states "Implements n buffering for concurrent operations with contiguous memory" but the actual parameter is named 'num_buffers' with a default of 3. However, the class is initialized with NUM_AG_BUFFERS (2 for AG) and NUM_RS_BUFFERS (1 for RS), which are not 3. The default value in the docstring should match the actual usage or be removed to avoid confusion.

Suggested change
-"""Initialize the symmetric buffer manager with n buffering in
-contiguous memory.
-Args:
-    default_size (int): Default buffer size in bytes
-    alignment (int): Memory alignment requirement for the buffer
-    num_buffers (int): Number of buffers for n-buffering
+"""Initialize the symmetric buffer manager with n-buffering in
+contiguous memory.
+Args:
+    default_size (int): Default buffer size in bytes.
+    alignment (int): Memory alignment requirement for the buffer.
+    num_buffers (int): Number of buffers for n-buffering. The actual
+        value is provided by callers (for example, ``NUM_AG_BUFFERS``
+        or ``NUM_RS_BUFFERS``) and may vary depending on usage.

)


def patch_fsdp_agrs() -> None:

Copilot AI Jan 7, 2026

The function name 'patch_fsdp_agrs' contains a typo. It should be 'patch_fsdp_args' (with 'args' instead of 'agrs').

Suggested change
-def patch_fsdp_agrs() -> None:
+def patch_fsdp_args() -> None:

Comment on lines +486 to 487
patch_fsdp_agrs()


Copilot AI Jan 7, 2026

The patch_fsdp_agrs function is called unconditionally during Trainer initialization, but it should only be called when custom communication is actually enabled. Currently, the function checks the environment variables internally, but calling it unconditionally may cause unnecessary module imports and overhead. Consider moving the call inside the conditional block that checks XTUNER_ENABLE_CUSTOM_COMMUNICATION at line 638.

Suggested change
-patch_fsdp_agrs()
+if os.getenv("XTUNER_ENABLE_CUSTOM_COMMUNICATION"):
+    patch_fsdp_agrs()

Comment on lines +702 to +705
    reduce_scatter_input_aligned = reduce_scatter_input
else:
    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
    reduce_scatter_input_aligned = reduce_scatter_input

Copilot AI Jan 7, 2026

Variable reduce_scatter_input_aligned is not used.

Suggested change
-    reduce_scatter_input_aligned = reduce_scatter_input
-else:
-    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
-    reduce_scatter_input_aligned = reduce_scatter_input
+else:
+    reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)

# reduce_scatter_input = torch.empty(
#     (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device
# )
reduce_scatter_input_aligned = reduce_scatter_input

Copilot AI Jan 7, 2026

Variable reduce_scatter_input_aligned is not used.
