diff --git a/encodings/bytebool/src/array.rs b/encodings/bytebool/src/array.rs index 813d1b0539b..d5027fbbd80 100644 --- a/encodings/bytebool/src/array.rs +++ b/encodings/bytebool/src/array.rs @@ -88,7 +88,7 @@ impl VTable for ByteBoolVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_host()?; + let buffer = buffers[0].clone().try_to_host_sync()?; Ok(ByteBoolArray::new(buffer, validity)) } diff --git a/encodings/fastlanes/src/bitpacking/vtable/mod.rs b/encodings/fastlanes/src/bitpacking/vtable/mod.rs index d4085c5443f..c15c5b1b0ff 100644 --- a/encodings/fastlanes/src/bitpacking/vtable/mod.rs +++ b/encodings/fastlanes/src/bitpacking/vtable/mod.rs @@ -176,7 +176,7 @@ impl VTable for BitPackedVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let packed = buffers[0].clone().try_to_host()?; + let packed = buffers[0].clone().try_to_host_sync()?; let load_validity = |child_idx: usize| { if children.len() == child_idx { diff --git a/encodings/fsst/src/array.rs b/encodings/fsst/src/array.rs index 32e5aa83b78..e4765114510 100644 --- a/encodings/fsst/src/array.rs +++ b/encodings/fsst/src/array.rs @@ -111,8 +111,8 @@ impl VTable for FSSTVTable { if buffers.len() != 2 { vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len()); } - let symbols = Buffer::::from_byte_buffer(buffers[0].clone().try_to_host()?); - let symbol_lengths = Buffer::::from_byte_buffer(buffers[1].clone().try_to_host()?); + let symbols = Buffer::::from_byte_buffer(buffers[0].clone().try_to_host_sync()?); + let symbol_lengths = Buffer::::from_byte_buffer(buffers[1].clone().try_to_host_sync()?); if children.len() != 2 { vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len()); diff --git a/encodings/pco/src/array.rs b/encodings/pco/src/array.rs index 87504c5b423..15266f32ae5 100644 --- a/encodings/pco/src/array.rs +++ b/encodings/pco/src/array.rs 
@@ -131,11 +131,11 @@ impl VTable for PcoVTable { vortex_ensure!(buffers.len() >= metadata.0.chunks.len()); let chunk_metas = buffers[..metadata.0.chunks.len()] .iter() - .map(|b| b.clone().try_to_host()) + .map(|b| b.clone().try_to_host_sync()) .collect::>>()?; let pages = buffers[metadata.0.chunks.len()..] .iter() - .map(|b| b.clone().try_to_host()) + .map(|b| b.clone().try_to_host_sync()) .collect::>>()?; let expected_n_pages = metadata diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 6fdfa6582cc..4ac744d9b2a 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -131,7 +131,7 @@ impl VTable for SparseVTable { } let fill_value = Scalar::new( dtype.clone(), - ScalarValue::from_protobytes(&buffers[0].clone().try_to_host()?)?, + ScalarValue::from_protobytes(&buffers[0].clone().try_to_host_sync()?)?, ); SparseArray::try_new(patch_indices, patch_values, len, fill_value) diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index f3f2bef268a..8adef0542a1 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -132,16 +132,16 @@ impl VTable for ZstdVTable { None, buffers .iter() - .map(|b| b.clone().try_to_host()) + .map(|b| b.clone().try_to_host_sync()) .collect::>>()?, ) } else { // with dictionary ( - Some(buffers[0].clone().try_to_host()?), + Some(buffers[0].clone().try_to_host_sync()?), buffers[1..] .iter() - .map(|b| b.clone().try_to_host()) + .map(|b| b.clone().try_to_host_sync()) .collect::>>()?, ) }; @@ -368,7 +368,7 @@ impl ZstdArray { n_values }; - let value_bytes = values.buffer_handle().try_to_host()?; + let value_bytes = values.buffer_handle().try_to_host_sync()?; // Align frames to buffer alignment. This is necessary for overaligned buffers. 
let alignment = *value_bytes.alignment(); let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment; diff --git a/vortex-array/src/array/visitor.rs b/vortex-array/src/array/visitor.rs index 9f3c4454b9b..0c418987d62 100644 --- a/vortex-array/src/array/visitor.rs +++ b/vortex-array/src/array/visitor.rs @@ -115,7 +115,7 @@ impl ArrayVisitorExt for A {} pub trait ArrayBufferVisitor { fn visit_buffer_handle(&mut self, handle: &BufferHandle) -> VortexResult<()> { - self.visit_buffer(&handle.clone().try_to_host()?); + self.visit_buffer(&handle.clone().try_to_host_sync()?); Ok(()) } fn visit_buffer(&mut self, buffer: &ByteBuffer); diff --git a/vortex-array/src/arrays/bool/vtable/mod.rs b/vortex-array/src/arrays/bool/vtable/mod.rs index cfc24b26331..709776da262 100644 --- a/vortex-array/src/arrays/bool/vtable/mod.rs +++ b/vortex-array/src/arrays/bool/vtable/mod.rs @@ -100,7 +100,7 @@ impl VTable for BoolVTable { vortex_bail!("Expected 0 or 1 child, got {}", children.len()); }; - let buffer = buffers[0].clone().try_to_host()?; + let buffer = buffers[0].clone().try_to_host_sync()?; let bits = BitBuffer::new_with_offset(buffer, len, metadata.offset as usize); BoolArray::try_new(bits, validity) diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index e9a0ff52d16..264c66eddd9 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -80,7 +80,7 @@ impl VTable for ConstantVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_host()?; + let buffer = buffers[0].clone().try_to_host_sync()?; let sv = ScalarValue::from_protobytes(&buffer)?; let scalar = Scalar::new(dtype.clone(), sv); Ok(ConstantArray::new(scalar, len)) diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 41e131b8727..30c99fdabfb 100644 --- 
a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -91,7 +91,7 @@ impl VTable for DecimalVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let buffer = buffers[0].clone().try_to_host()?; + let buffer = buffers[0].clone().try_to_host_sync()?; let validity = if children.is_empty() { Validity::from(dtype.nullability()) diff --git a/vortex-array/src/arrays/primitive/array/conversion.rs b/vortex-array/src/arrays/primitive/array/conversion.rs index 8bfbd8a52d5..90d2c9dd4f4 100644 --- a/vortex-array/src/arrays/primitive/array/conversion.rs +++ b/vortex-array/src/arrays/primitive/array/conversion.rs @@ -83,7 +83,7 @@ impl PrimitiveArray { ) } - Buffer::from_byte_buffer(self.buffer_handle().to_host()) + Buffer::from_byte_buffer(self.buffer_handle().to_host_sync()) } /// Consume the array and get a host Buffer containing the data values. @@ -96,7 +96,7 @@ impl PrimitiveArray { ) } - Buffer::from_byte_buffer(self.buffer.into_host()) + Buffer::from_byte_buffer(self.buffer.into_host_sync()) } /// Extract a mutable buffer from the PrimitiveArray. 
Attempts to do this with zero-copy @@ -115,7 +115,7 @@ impl PrimitiveArray { self.ptype() ) } - let buffer = Buffer::::from_byte_buffer(self.buffer.into_host()); + let buffer = Buffer::::from_byte_buffer(self.buffer.into_host_sync()); buffer.try_into_mut() } } diff --git a/vortex-array/src/arrays/primitive/vtable/visitor.rs b/vortex-array/src/arrays/primitive/vtable/visitor.rs index 65910a17658..b52bb33dbc9 100644 --- a/vortex-array/src/arrays/primitive/vtable/visitor.rs +++ b/vortex-array/src/arrays/primitive/vtable/visitor.rs @@ -10,7 +10,7 @@ use crate::vtable::VisitorVTable; impl VisitorVTable for PrimitiveVTable { fn visit_buffers(array: &PrimitiveArray, visitor: &mut dyn ArrayBufferVisitor) { - visitor.visit_buffer(&array.buffer_handle().to_host()); + visitor.visit_buffer(&array.buffer_handle().to_host_sync()); } fn visit_children(array: &PrimitiveArray, visitor: &mut dyn ArrayChildVisitor) { diff --git a/vortex-array/src/arrays/varbin/vtable/mod.rs b/vortex-array/src/arrays/varbin/vtable/mod.rs index 1194dcbada6..caa68afc08f 100644 --- a/vortex-array/src/arrays/varbin/vtable/mod.rs +++ b/vortex-array/src/arrays/varbin/vtable/mod.rs @@ -102,7 +102,7 @@ impl VTable for VarBinVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let bytes = buffers[0].clone().try_to_host()?; + let bytes = buffers[0].clone().try_to_host_sync()?; VarBinArray::try_new(offsets, bytes, dtype.clone(), validity) } diff --git a/vortex-array/src/arrays/varbinview/vtable/mod.rs b/vortex-array/src/arrays/varbinview/vtable/mod.rs index 3cb17a0d02e..8068ec724f8 100644 --- a/vortex-array/src/arrays/varbinview/vtable/mod.rs +++ b/vortex-array/src/arrays/varbinview/vtable/mod.rs @@ -82,7 +82,7 @@ impl VTable for VarBinViewVTable { } let mut buffers: Vec = buffers .iter() - .map(|b| b.clone().try_to_host()) + .map(|b| b.clone().try_to_host_sync()) .collect::>>()?; let views = buffers.pop().vortex_expect("buffers non-empty"); diff --git 
a/vortex-array/src/buffer.rs b/vortex-array/src/buffer.rs index 9d5934b07ee..7f084f1ee4e 100644 --- a/vortex-array/src/buffer.rs +++ b/vortex-array/src/buffer.rs @@ -8,6 +8,7 @@ use std::hash::Hasher; use std::ops::Range; use std::sync::Arc; +use futures::future::BoxFuture; use vortex_buffer::ALIGNMENT_TO_HOST_COPY; use vortex_buffer::Alignment; use vortex_buffer::ByteBuffer; @@ -59,7 +60,23 @@ pub trait DeviceBuffer: 'static + Send + Sync + Debug + DynEq + DynHash { /// # Errors /// /// This operation may fail, depending on the device implementation and the underlying hardware. - fn copy_to_host(&self, alignment: Alignment) -> VortexResult; + fn copy_to_host_sync(&self, alignment: Alignment) -> VortexResult; + + /// Copies the device buffer to a host buffer asynchronously. + /// + /// Schedules an async copy and returns a future that completes when the copy is finished. + /// + /// # Arguments + /// + /// * `alignment` - The memory alignment to use for the host buffer. + /// + /// # Errors + /// + /// Returns an error if the async copy operation fails. + fn copy_to_host( + &self, + alignment: Alignment, + ) -> VortexResult>>; /// Create a new buffer that references a subrange of this buffer at the given /// slice indices. @@ -147,7 +164,7 @@ impl BufferHandle { /// let values = buffer![1u32, 2u32, 3u32, 4u32]; /// let handle = BufferHandle::new_host(values.into_byte_buffer()); /// let sliced = handle.slice_typed::(1..4); - /// let result = Buffer::::from_byte_buffer(sliced.to_host()); + /// let result = Buffer::::from_byte_buffer(sliced.to_host_sync()); /// assert_eq!(result, buffer![2, 3, 4]); /// ``` pub fn slice_typed(&self, range: Range) -> Self { @@ -214,8 +231,8 @@ impl BufferHandle { /// result in a panic. /// /// See also: [`try_to_host`][Self::try_to_host]. 
- pub fn to_host(&self) -> ByteBuffer { - self.try_to_host() + pub fn to_host_sync(&self) -> ByteBuffer { + self.try_to_host_sync() .vortex_expect("to_host: copy from device to host failed") } @@ -228,8 +245,8 @@ impl BufferHandle { /// # Panics /// /// See the panic documentation on [`to_host`][Self::to_host]. - pub fn into_host(self) -> ByteBuffer { - self.try_into_host() + pub fn into_host_sync(self) -> ByteBuffer { + self.try_into_host_sync() .vortex_expect("into_host: copy from device to host failed") } @@ -239,22 +256,93 @@ impl BufferHandle { /// /// If it is a device allocation, then this issues an operation that attempts to copy the data /// from the device into a host-resident buffer, and returns a handle to that buffer. - pub fn try_to_host(&self) -> VortexResult { + pub fn try_to_host_sync(&self) -> VortexResult { match &self.0 { Inner::Host(b) => Ok(b.clone()), - Inner::Device(device) => device.copy_to_host(ALIGNMENT_TO_HOST_COPY), + Inner::Device(device) => device.copy_to_host_sync(ALIGNMENT_TO_HOST_COPY), } } /// Attempts to load this buffer into a host-resident allocation, consuming the handle. /// /// See also [`try_to_host`][Self::try_to_host]. - pub fn try_into_host(self) -> VortexResult { + pub fn try_into_host_sync(self) -> VortexResult { match self.0 { Inner::Host(b) => Ok(b), + Inner::Device(device) => device.copy_to_host_sync(ALIGNMENT_TO_HOST_COPY), + } + } + + /// Asynchronously copies the buffer to the host. + /// + /// This is a no-op if the buffer is already on the host. + /// + /// # Returns + /// + /// A future that resolves to the host buffer when the copy completes. + /// + /// # Errors + /// + /// Returns an error if the async copy operation fails. 
+ pub fn try_to_host(&self) -> VortexResult>> { + match &self.0 { + Inner::Host(b) => { + let buffer = b.clone(); + Ok(Box::pin(async move { Ok(buffer) })) + } Inner::Device(device) => device.copy_to_host(ALIGNMENT_TO_HOST_COPY), } } + + /// Asynchronously copies the buffer to the host, consuming the handle. + /// + /// This is a no-op if the buffer is already on the host. + /// + /// # Returns + /// + /// A future that resolves to the host buffer when the copy completes. + /// + /// # Errors + /// + /// Returns an error if the async copy operation fails. + pub fn try_into_host(self) -> VortexResult>> { + match self.0 { + Inner::Host(b) => Ok(Box::pin(async move { Ok(b) })), + Inner::Device(device) => device.copy_to_host(ALIGNMENT_TO_HOST_COPY), + } + } + + /// Asynchronously copies the buffer to the host. + /// + /// # Panics + /// + /// Any errors triggered by the copying from device to host will result in a panic. + pub fn to_host(&self) -> BoxFuture<'static, ByteBuffer> { + let future = self + .try_to_host() + .vortex_expect("to_host: failed to initiate copy from device to host"); + Box::pin(async move { + future + .await + .vortex_expect("to_host: copy from device to host failed") + }) + } + + /// Asynchronously copies the buffer to the host, consuming the handle. + /// + /// # Panics + /// + /// Any errors triggered by the copying from device to host will result in a panic. 
+ pub fn into_host(self) -> BoxFuture<'static, ByteBuffer> { + let future = self + .try_into_host() + .vortex_expect("into_host: failed to initiate copy from device to host"); + Box::pin(async move { + future + .await + .vortex_expect("into_host: copy from device to host failed") + }) + } } impl ArrayHash for BufferHandle { diff --git a/vortex-array/src/compute/conformance/take.rs b/vortex-array/src/compute/conformance/take.rs index 8a6f31a0847..0385ace1327 100644 --- a/vortex-array/src/compute/conformance/take.rs +++ b/vortex-array/src/compute/conformance/take.rs @@ -72,8 +72,8 @@ fn test_take_all(array: &dyn Array) { ) { (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => { assert_eq!( - orig_prim.buffer_handle().to_host(), - result_prim.buffer_handle().to_host() + orig_prim.buffer_handle().to_host_sync(), + result_prim.buffer_handle().to_host_sync() ); } _ => { diff --git a/vortex-array/src/serde.rs b/vortex-array/src/serde.rs index 85fff5a8621..35821184740 100644 --- a/vortex-array/src/serde.rs +++ b/vortex-array/src/serde.rs @@ -491,7 +491,7 @@ impl ArrayParts { segment: BufferHandle, ) -> VortexResult { // TODO: this can also work with device buffers. - let segment = segment.try_to_host()?; + let segment = segment.try_to_host_sync()?; // We align each buffer individually, so we remove alignment requirements on the buffer. let segment = segment.aligned(Alignment::none()); @@ -612,6 +612,6 @@ impl TryFrom for ArrayParts { type Error = VortexError; fn try_from(value: BufferHandle) -> Result { - Self::try_from(value.try_to_host()?) + Self::try_from(value.try_to_host_sync()?) 
} } diff --git a/vortex-cuda/src/device_buffer.rs b/vortex-cuda/src/device_buffer.rs index 721b1b3e6e3..26021d4582b 100644 --- a/vortex-cuda/src/device_buffer.rs +++ b/vortex-cuda/src/device_buffer.rs @@ -9,6 +9,8 @@ use cudarc::driver::CudaSlice; use cudarc::driver::CudaView; use cudarc::driver::DevicePtr; use cudarc::driver::DeviceRepr; +use cudarc::driver::sys; +use futures::future::BoxFuture; use vortex_array::buffer::BufferHandle; use vortex_array::buffer::DeviceBuffer; use vortex_buffer::Alignment; @@ -17,6 +19,8 @@ use vortex_buffer::ByteBuffer; use vortex_error::VortexResult; use vortex_error::vortex_err; +use crate::stream::await_stream_callback; + /// A CUDA device buffer with offset and length tracking. pub struct CudaDeviceBuffer { inner: Arc>, @@ -114,22 +118,86 @@ impl DeviceBuffer for CudaDeviceBuffer /// # Errors /// /// Returns an error if the CUDA memory copy operation fails. - fn copy_to_host(&self, alignment: Alignment) -> VortexResult { + fn copy_to_host_sync(&self, alignment: Alignment) -> VortexResult { let mut host_buffer = BufferMut::::with_capacity_aligned(self.len, alignment); - let view = self.as_view(); - self.inner - .stream() - // TODO(0ax1): make the copy async - .memcpy_dtoh(&view, unsafe { - host_buffer.set_len(self.len); - host_buffer.as_mut_slice() - }) + // Add offset to device pointer to account for any previous slicing operations. + let src_ptr = self.device_ptr + (self.offset * size_of::()) as u64; + + // SAFETY: We pass a valid pointer to a buffer with sufficient capacity. + // `cuMemcpyDtoH_v2` fully initializes the memory. + unsafe { + sys::cuMemcpyDtoH_v2( + host_buffer.spare_capacity_mut().as_mut_ptr().cast(), + src_ptr, + self.len * size_of::(), + ) + .result() .map_err(|e| vortex_err!("Failed to copy from device to host: {}", e))?; + } + + // SAFETY: `cuMemcpyDtoH_v2` fully initialized the buffer. 
+ unsafe { + host_buffer.set_len(self.len); + } Ok(host_buffer.freeze().into_byte_buffer()) } + /// Copies a device buffer to host memory asynchronously. + /// + /// Allocates host memory, schedules an async copy, and returns a future + /// that completes when the copy is finished. + /// + /// # Arguments + /// + /// * `alignment` - The memory alignment to use for the host buffer. + /// + /// # Returns + /// + /// A future that resolves to the host buffer when the copy completes. + fn copy_to_host( + &self, + alignment: Alignment, + ) -> VortexResult>> { + let stream = self.inner.stream(); + + // Add offset to device pointer to account for any previous slicing operations. + let src_ptr = self.device_ptr + (self.offset * size_of::()) as u64; + + let mut host_buffer: BufferMut = BufferMut::with_capacity_aligned(self.len, alignment); + let len = self.len; + + // SAFETY: We pass a valid pointer to a buffer with sufficient capacity. + // `cuMemcpyDtoHAsync_v2` fully initializes the memory. + unsafe { + sys::cuMemcpyDtoHAsync_v2( + host_buffer.spare_capacity_mut().as_mut_ptr().cast(), + src_ptr, + len * size_of::(), + stream.cu_stream(), + ) + .result() + .map_err(|e| vortex_err!("Failed to schedule async copy to host: {}", e))?; + } + + let cuda_slice = Arc::clone(&self.inner); + + Ok(Box::pin(async move { + await_stream_callback(cuda_slice.stream()).await?; + + // Keep device memory alive until copy completes. + let _keep_alive = cuda_slice; + + // SAFETY: `cuMemcpyDtoHAsync_v2` fully initialized the buffer. + unsafe { + host_buffer.set_len(len); + } + + Ok(host_buffer.freeze().into_byte_buffer()) + })) + } + /// Slices the CUDA device buffer to a subrange. 
fn slice(&self, range: Range) -> Arc { let new_offset = self.offset + range.start; diff --git a/vortex-cuda/src/executor.rs b/vortex-cuda/src/executor.rs index 557774d2a10..4ae59c6dd2b 100644 --- a/vortex-cuda/src/executor.rs +++ b/vortex-cuda/src/executor.rs @@ -12,14 +12,9 @@ use cudarc::driver::CudaSlice; use cudarc::driver::CudaStream; use cudarc::driver::DevicePtrMut; use cudarc::driver::DeviceRepr; -use cudarc::driver::DriverError; use cudarc::driver::LaunchArgs; -use cudarc::driver::result; use cudarc::driver::result::memcpy_htod_async; -use cudarc::driver::sys; use futures::future::BoxFuture; -use kanal::Sender; -use result::stream; use vortex_array::Array; use vortex_array::ArrayRef; use vortex_array::Canonical; @@ -33,76 +28,7 @@ use vortex_error::vortex_err; use crate::CudaDeviceBuffer; use crate::CudaSession; use crate::session::CudaSessionExt; - -/// Registers a callback and asynchronously waits for its completion. -/// -/// This function can be used to asynchronously wait for events previously -/// submitted to the stream to complete, e.g. async device buffer allocations. -/// -/// Note: This is not equivalent to calling sync on a stream but only awaits -/// the registered callback to complete. -/// -/// # Arguments -/// -/// * `stream` - The CUDA stream to wait on -pub async fn await_stream_callback(stream: &CudaStream) -> Result<(), DriverError> { - let rx = register_stream_callback(stream)?; - - rx.recv() - .await - .map_err(|_| DriverError(sys::CUresult::CUDA_ERROR_UNKNOWN)) -} - -/// Registers a host function callback on the stream. -/// -/// # Returns -/// -/// An async receiver that receives a message when all preceding work on the -/// stream completes. -/// -/// # Errors -/// -/// Returns an error if registering the host callback function fails. -fn register_stream_callback(stream: &CudaStream) -> Result, DriverError> { - let (tx, rx) = kanal::bounded::<()>(1); - - // There are 2 different scenarios how `tx` gets freed. 
When the callback - // is invoked or during cleanup in case the registration fails. - let tx_ptr = Box::into_raw(Box::new(tx)); - - /// Called from CUDA driver thread when all preceding work on the stream completes. - unsafe extern "C" fn callback(user_data: *mut std::ffi::c_void) { - // SAFETY: The memory of `tx` is manually managed has not been freed - // before. We have unique ownership and can therefore free it. - let tx = unsafe { Box::from_raw(user_data as *mut Sender<()>) }; - - // Blocking send as we're in a callback invoked by the CUDA driver. - #[expect(clippy::expect_used)] - tx.send(()) - // A send should never fail. Panic otherwise. - .expect("CUDA callback receiver dropped unexpectedly"); - } - - // SAFETY: - // 1. Valid handle from the borrowed `CudaStream`. - // 2. Valid function pointer with the the correct signature - // 3. Valid user data pointer which is consumed exactly once - unsafe { - stream::launch_host_function( - stream.cu_stream(), - callback, - tx_ptr as *mut std::ffi::c_void, - ) - .inspect_err(|_| { - // SAFETY: Registration failed, so callback will never run. - // Therefore, we need to free the `user_data` passed to the - // callback in the error case. - drop(Box::from_raw(tx_ptr)); - })?; - } - - Ok(rx.to_async()) -} +use crate::stream::await_stream_callback; /// CUDA kernel events recorded before and after kernel launch. #[derive(Debug)] @@ -189,26 +115,25 @@ impl CudaExecutionCtx { Ok(CudaDeviceBuffer::new(cuda_slice)) } - /// Copies a pinned host buffer to the device asynchronously. + /// Copies a host buffer to the device asynchronously. /// /// Allocates device memory, schedules an async copy, and returns a future /// that completes when the copy is finished. /// /// # Arguments /// - /// * `handle` - The host buffer to copy. Must be a host buffer (not already on device). + /// * `handle` - The host buffer to copy. Must be a host buffer. 
/// - /// # Safety + /// # Returns /// - /// The returned future captures the source `BufferHandle` to keep the host - /// memory alive until the copy completes. + /// A future that resolves to the device buffer handle when the copy completes. pub fn copy_buffer_to_device_async( &self, handle: BufferHandle, ) -> VortexResult>> { let host_buffer = handle .as_host_opt() - .ok_or_else(|| vortex_err!("Buffer is neither on host nor device"))?; + .ok_or_else(|| vortex_err!("Buffer is not on host"))?; let mut cuda_slice: CudaSlice = self.device_alloc(host_buffer.len() / size_of::())?; let device_ptr = cuda_slice.device_ptr_mut(&self.stream).0; @@ -226,9 +151,7 @@ impl CudaExecutionCtx { Ok(Box::pin(async move { // Await async copy completion using callback-based async wait. - await_stream_callback(&stream) - .await - .map_err(|e| vortex_err!("CUDA stream wait failed: {}", e))?; + await_stream_callback(&stream).await?; // Keep source memory alive until copy completes. let _keep_alive = handle; diff --git a/vortex-cuda/src/kernel/encodings/for_.rs b/vortex-cuda/src/kernel/encodings/for_.rs index 33d04be614f..bea4c690081 100644 --- a/vortex-cuda/src/kernel/encodings/for_.rs +++ b/vortex-cuda/src/kernel/encodings/for_.rs @@ -109,9 +109,9 @@ mod tests { let mut cuda_ctx = CudaSession::create_execution_ctx(VortexSession::empty()) .vortex_expect("failed to create execution context"); - // Create u8 offset values that cycle through 0-255, creating 5000 elements + // Create u8 offset values that cycle through 0-245, creating 5000 elements #[allow(clippy::cast_possible_truncation)] - let input_data: Vec = (0..5000).map(|i| (i % 256) as u8).collect(); + let input_data: Vec = (0..5000).map(|i| (i % 246) as u8).collect(); let for_array = FoRArray::try_new( PrimitiveArray::new(Buffer::from(input_data.clone()), NonNullable).into_array(), @@ -126,14 +126,12 @@ mod tests { .vortex_expect("GPU decompression failed"); let result_buf = - 
Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host()); + Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host().await); + assert_eq!(result_buf.len(), input_data.len()); assert_eq!( result_buf, - input_data - .iter() - .map(|&val| val.wrapping_add(10)) - .collect::>() + input_data.iter().map(|&val| val + 10).collect::>() ); } @@ -162,13 +160,14 @@ mod tests { .vortex_expect("GPU decompression failed"); let result_buf = - Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host()); + Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host().await); + assert_eq!(result_buf.len(), input_data.len()); assert_eq!( result_buf, input_data .iter() - .map(|&val| val.wrapping_add(1000)) + .map(|&val| val + 1000) .collect::>() ); } @@ -198,13 +197,14 @@ mod tests { .vortex_expect("GPU decompression failed"); let result_buf = - Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host()); + Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host().await); + assert_eq!(result_buf.len(), input_data.len()); assert_eq!( result_buf, input_data .iter() - .map(|&val| val.wrapping_add(100000)) + .map(|&val| val + 100000) .collect::>() ); } @@ -234,13 +234,14 @@ mod tests { .vortex_expect("GPU decompression failed"); let result_buf = - Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host()); + Buffer::::from_byte_buffer(result.as_primitive().buffer_handle().to_host().await); + assert_eq!(result_buf.len(), input_data.len()); assert_eq!( result_buf, input_data .iter() - .map(|&val| val.wrapping_add(1000000u64)) + .map(|&val| val + 1000000u64) .collect::>() ); } diff --git a/vortex-cuda/src/lib.rs b/vortex-cuda/src/lib.rs index 8075a70283a..cc0a434de3b 100644 --- a/vortex-cuda/src/lib.rs +++ b/vortex-cuda/src/lib.rs @@ -7,6 +7,7 @@ mod device_buffer; pub mod executor; mod kernel; mod session; +mod stream; use std::process::Command; diff --git 
a/vortex-cuda/src/stream.rs b/vortex-cuda/src/stream.rs new file mode 100644 index 00000000000..ba1f264ee60 --- /dev/null +++ b/vortex-cuda/src/stream.rs @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! CUDA stream utility functions. + +use cudarc::driver::CudaStream; +use cudarc::driver::result::stream; +use kanal::Sender; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +/// Registers a callback and asynchronously waits for its completion. +/// +/// This function can be used to asynchronously wait for events previously +/// submitted to the stream to complete, e.g. async buffer allocations. +/// +/// Note: This is not equivalent to calling sync on a stream but only awaits +/// the registered callback to complete. +/// +/// # Arguments +/// +/// * `stream` - The CUDA stream to wait on +/// +/// # Errors +/// +/// Returns an error if registering the stream callback fails or if the callback +/// channel closes unexpectedly. +pub async fn await_stream_callback(stream: &CudaStream) -> VortexResult<()> { + let rx = register_stream_callback(stream)?; + + rx.recv() + .await + .map_err(|e| vortex_err!("CUDA stream callback channel closed unexpectedly: {}", e)) +} + +/// Registers a host function callback on the stream. +/// +/// # Returns +/// +/// An async receiver that receives a message when all preceding work on the +/// stream completes. +/// +/// # Errors +/// +/// Returns an error if registering the host callback function fails. +fn register_stream_callback(stream: &CudaStream) -> VortexResult> { + let (tx, rx) = kanal::bounded::<()>(1); + + let tx_ptr = Box::into_raw(Box::new(tx)); + + /// Called from CUDA driver thread when all preceding work on the stream completes. + unsafe extern "C" fn callback(user_data: *mut std::ffi::c_void) { + // SAFETY: The memory of `tx` is manually managed and has not been freed + // before. We have unique ownership and can therefore free it. 
+ let tx = unsafe { Box::from_raw(user_data as *mut Sender<()>) }; + + // Blocking send as we're in a callback invoked by the CUDA driver. + #[expect(clippy::expect_used)] + tx.send(()) + // A send should never fail. Panic otherwise. + .expect("CUDA callback receiver dropped unexpectedly"); + } + + // SAFETY: + // 1. Valid handle from the borrowed `CudaStream`. + // 2. Valid function pointer with the correct signature + // 3. Valid user data pointer which is consumed exactly once + unsafe { + stream::launch_host_function( + stream.cu_stream(), + callback, + tx_ptr as *mut std::ffi::c_void, + ) + .map_err(|err| { + // SAFETY: Registration failed, so the callback will never run. + // We have unique ownership and can therefore free it. + drop(Box::from_raw(tx_ptr)); + vortex_err!("Failed to register CUDA stream callback: {}", err) + })?; + } + + Ok(rx.to_async()) +} diff --git a/vortex-python/src/serde/parts.rs b/vortex-python/src/serde/parts.rs index 900ec95ff02..faf9ab29537 100644 --- a/vortex-python/src/serde/parts.rs +++ b/vortex-python/src/serde/parts.rs @@ -88,7 +88,7 @@ impl PyArrayParts { let mut buffers = Vec::with_capacity(slf.nbuffers()); for buffer in (0..slf.nbuffers()).map(|i| slf.buffer(i)) { - let buffer: ByteBuffer = buffer.and_then(|b| b.try_to_host())?; + let buffer: ByteBuffer = buffer.and_then(|b| b.try_to_host_sync())?; let addr = buffer.as_ptr() as usize; let size = buffer.len();