diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 58ed0630660..73dda74190f 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -19,6 +19,7 @@ from .quantized_tensor import ( restore_from_saved, prepare_for_saving, + QuantizedTensor, ) @@ -255,6 +256,8 @@ def start_offload(self): Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream. Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded. This event is recorded in the start_offload or push_tensor call. + + Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor). """ self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"]) self.state = "offload_started" @@ -275,19 +278,18 @@ def start_offload(self): with torch.cuda.stream(self.offload_stream): if allocate_cpu_buffers: - # empty_like is defined also for QuantizedTensors offloaded_tensor = torch.empty_like( tensor, device=torch.device("cpu"), pin_memory=True ) self.cpu_tensor_group.tensor_list.append(offloaded_tensor) else: - assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, ( + offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] + assert offloaded_tensor.shape == tensor.shape, ( "CPU buffer shape does not match the offloaded tensor shape:" - f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} " - " Make sure that tensor shaped do not change between" + f" {offloaded_tensor.shape} != {tensor.shape} " + "Make sure that tensor shapes do not change between" " iterations if retain_pinned_cpu_buffers is True." ) - offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] offloaded_tensor.copy_(tensor, non_blocking=True) # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, @@ -318,6 +320,9 @@ def start_reload(self): """ Start reloading of tensors. It allocates new tensors on GPU and puts copy from CPU tasks on offload stream. + + Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor + and reconstructed in pop_tensor). """ self._validate_state(func_name="start_reload", allowed_states=["offload_finished"]) self.state = "reload_started" @@ -330,7 +335,6 @@ def start_reload(self): # cannot move tensors from pool of one stream to another without # calling cudaFree and cudaMalloc again. - # empty_like is defined also for QuantizedTensors. reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda")) self.offload_stream.wait_stream(torch.cuda.current_stream()) @@ -347,15 +351,26 @@ def start_reload(self): self.bwd_gpu_tensor_group ) - def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]: """ It is called when a tensor is saved for backward pass. If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group. If tensor is not offloaded, returns the tensor itself. + For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple. """ self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) + # For QuantizedTensor: decompose into component tensors, push each one recursively + if isinstance(tensor, QuantizedTensor): + # Make a copy because prepare_for_saving modifies the object (sets fields to None) + tensor_copy = tensor.detach() + # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass, + # so the generic prepare_for_saving would not call tensor.prepare_for_saving() + saved_tensors, tensor_obj = tensor_copy.prepare_for_saving() + push_results = [self.push_tensor(t) if t is not None else None for t in saved_tensors] + return (push_results, [tensor_obj]) + if self._check_if_offload(tensor): self.fwd_gpu_tensor_group.tensor_list.append(tensor) # The group is processed and offloaded at the end of the forward pass of current layer. @@ -370,23 +385,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: return len(self.fwd_gpu_tensor_group.tensor_list) - 1 return tensor - def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: + def pop_tensor( + self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list] + ) -> torch.Tensor: """ It is called when a tensor is used in backward pass. Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish. + For QuantizedTensor (tuple input), reconstructs from component tensors. """ self._validate_state( func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"] ) - # 1. tensor not offloaded + # 1. tensor not offloaded (regular tensor returned as-is from push) if isinstance(tensor_or_tensor_id, torch.Tensor): return tensor_or_tensor_id - # 2. the layer was not offloaded at all + + # 2. QuantizedTensor case: tuple of (push_results, tensor_objs) + if isinstance(tensor_or_tensor_id, tuple): + push_results, tensor_objs = tensor_or_tensor_id + # Recursively pop each component + reloaded_tensors = [ + self.pop_tensor(pr) if pr is not None else None for pr in push_results + ] + # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy + tensor_obj = tensor_objs[0] + tensor_obj.restore_from_saved(reloaded_tensors) + return tensor_obj + + # 3. Regular tensor index case if self.state == "not_offloaded": return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] - # 3. the layer was offloaded + # 4. the layer was offloaded assert self.state == "reload_started" # wait for the tensor to be reloaded torch.cuda.current_stream().wait_event( @@ -419,6 +450,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool: ) return False + # Only offload tensors with at least 256k elements (~1MB for float32) + if t.numel() < 256 * 1024: + return False + return True return False diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 80479dccf48..484d04c4366 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -254,7 +254,8 @@ std::vector multi_tensor_quantize(const std::vector &ten std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, - std::vector quantizer_list); + std::vector quantizer_list, + bool disable_bulk_allocation = false); /*************************************************************************************************** * Bias gradient fusions diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ac541435c7e..2c080873434 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -944,7 +944,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, - std::vector quantizer_list) { + std::vector quantizer_list, + bool disable_bulk_allocation) { init_extension(); // Check number of tensors @@ -996,22 +997,24 @@ std::vector split_quantize(const at::Tensor &tensor, enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; - if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsMXFP8Quantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_MXFP8; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsNVFP4Quantizers(quantizer.ptr()); - })) { - allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + if (!disable_bulk_allocation) { + if (std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr()); + })) { + allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE; + } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsMXFP8Quantizers(quantizer.ptr()); + })) { + allocation_method = AllocationMethod::BULK_MXFP8; + } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsNVFP4Quantizers(quantizer.ptr()); + })) { + allocation_method = AllocationMethod::BULK_NVFP4; + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } // Allocate output tensors diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0f450bc712..62337ee7e7b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), - py::arg("quantizer_list")); + py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c4d35a9c2cd..3b5bc73eb9e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -143,7 +143,9 @@ def forward( inp_view = inp.reshape(-1, in_features) inputmats: list if fp8 and not debug: - inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) + inputmats = tex.split_quantize( + inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading + ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b65f7005eb3..6cbfcd88597 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -428,7 +428,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - mark_not_offload(weight, weightmat, bias) + if cpu_offloading: + mark_not_offload(weight, weightmat, bias) # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat,