Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl VTable for ByteBoolVTable {
if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
let buffer = buffers[0].clone().try_to_host()?;
let buffer = buffers[0].clone().try_to_host_sync()?;

Ok(ByteBoolArray::new(buffer, validity))
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/bitpacking/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl VTable for BitPackedVTable {
if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
let packed = buffers[0].clone().try_to_host()?;
let packed = buffers[0].clone().try_to_host_sync()?;

let load_validity = |child_idx: usize| {
if children.len() == child_idx {
Expand Down
4 changes: 2 additions & 2 deletions encodings/fsst/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ impl VTable for FSSTVTable {
if buffers.len() != 2 {
vortex_bail!(InvalidArgument: "Expected 2 buffers, got {}", buffers.len());
}
let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host()?);
let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host()?);
let symbols = Buffer::<Symbol>::from_byte_buffer(buffers[0].clone().try_to_host_sync()?);
let symbol_lengths = Buffer::<u8>::from_byte_buffer(buffers[1].clone().try_to_host_sync()?);

if children.len() != 2 {
vortex_bail!(InvalidArgument: "Expected 2 children, got {}", children.len());
Expand Down
4 changes: 2 additions & 2 deletions encodings/pco/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ impl VTable for PcoVTable {
vortex_ensure!(buffers.len() >= metadata.0.chunks.len());
let chunk_metas = buffers[..metadata.0.chunks.len()]
.iter()
.map(|b| b.clone().try_to_host())
.map(|b| b.clone().try_to_host_sync())
.collect::<VortexResult<Vec<_>>>()?;
let pages = buffers[metadata.0.chunks.len()..]
.iter()
.map(|b| b.clone().try_to_host())
.map(|b| b.clone().try_to_host_sync())
.collect::<VortexResult<Vec<_>>>()?;

let expected_n_pages = metadata
Expand Down
2 changes: 1 addition & 1 deletion encodings/sparse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl VTable for SparseVTable {
}
let fill_value = Scalar::new(
dtype.clone(),
ScalarValue::from_protobytes(&buffers[0].clone().try_to_host()?)?,
ScalarValue::from_protobytes(&buffers[0].clone().try_to_host_sync()?)?,
);

SparseArray::try_new(patch_indices, patch_values, len, fill_value)
Expand Down
8 changes: 4 additions & 4 deletions encodings/zstd/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,16 @@ impl VTable for ZstdVTable {
None,
buffers
.iter()
.map(|b| b.clone().try_to_host())
.map(|b| b.clone().try_to_host_sync())
.collect::<VortexResult<Vec<_>>>()?,
)
} else {
// with dictionary
(
Some(buffers[0].clone().try_to_host()?),
Some(buffers[0].clone().try_to_host_sync()?),
buffers[1..]
.iter()
.map(|b| b.clone().try_to_host())
.map(|b| b.clone().try_to_host_sync())
.collect::<VortexResult<Vec<_>>>()?,
)
};
Expand Down Expand Up @@ -368,7 +368,7 @@ impl ZstdArray {
n_values
};

let value_bytes = values.buffer_handle().try_to_host()?;
let value_bytes = values.buffer_handle().try_to_host_sync()?;
// Align frames to buffer alignment. This is necessary for overaligned buffers.
let alignment = *value_bytes.alignment();
let step_width = (values_per_frame * byte_width).div_ceil(alignment) * alignment;
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl<A: Array + ?Sized> ArrayVisitorExt for A {}

pub trait ArrayBufferVisitor {
fn visit_buffer_handle(&mut self, handle: &BufferHandle) -> VortexResult<()> {
self.visit_buffer(&handle.clone().try_to_host()?);
self.visit_buffer(&handle.clone().try_to_host_sync()?);
Ok(())
}
fn visit_buffer(&mut self, buffer: &ByteBuffer);
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/bool/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl VTable for BoolVTable {
vortex_bail!("Expected 0 or 1 child, got {}", children.len());
};

let buffer = buffers[0].clone().try_to_host()?;
let buffer = buffers[0].clone().try_to_host_sync()?;
let bits = BitBuffer::new_with_offset(buffer, len, metadata.offset as usize);

BoolArray::try_new(bits, validity)
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/constant/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl VTable for ConstantVTable {
if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
let buffer = buffers[0].clone().try_to_host()?;
let buffer = buffers[0].clone().try_to_host_sync()?;
let sv = ScalarValue::from_protobytes(&buffer)?;
let scalar = Scalar::new(dtype.clone(), sv);
Ok(ConstantArray::new(scalar, len))
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/decimal/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl VTable for DecimalVTable {
if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
let buffer = buffers[0].clone().try_to_host()?;
let buffer = buffers[0].clone().try_to_host_sync()?;

let validity = if children.is_empty() {
Validity::from(dtype.nullability())
Expand Down
6 changes: 3 additions & 3 deletions vortex-array/src/arrays/primitive/array/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl PrimitiveArray {
)
}

Buffer::from_byte_buffer(self.buffer_handle().to_host())
Buffer::from_byte_buffer(self.buffer_handle().to_host_sync())
}

/// Consume the array and get a host Buffer containing the data values.
Expand All @@ -96,7 +96,7 @@ impl PrimitiveArray {
)
}

Buffer::from_byte_buffer(self.buffer.into_host())
Buffer::from_byte_buffer(self.buffer.into_host_sync())
}

/// Extract a mutable buffer from the PrimitiveArray. Attempts to do this with zero-copy
Expand All @@ -115,7 +115,7 @@ impl PrimitiveArray {
self.ptype()
)
}
let buffer = Buffer::<T>::from_byte_buffer(self.buffer.into_host());
let buffer = Buffer::<T>::from_byte_buffer(self.buffer.into_host_sync());
buffer.try_into_mut()
}
}
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/primitive/vtable/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::vtable::VisitorVTable;

impl VisitorVTable<PrimitiveVTable> for PrimitiveVTable {
fn visit_buffers(array: &PrimitiveArray, visitor: &mut dyn ArrayBufferVisitor) {
visitor.visit_buffer(&array.buffer_handle().to_host());
visitor.visit_buffer(&array.buffer_handle().to_host_sync());
}

fn visit_children(array: &PrimitiveArray, visitor: &mut dyn ArrayChildVisitor) {
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/varbin/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl VTable for VarBinVTable {
if buffers.len() != 1 {
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
}
let bytes = buffers[0].clone().try_to_host()?;
let bytes = buffers[0].clone().try_to_host_sync()?;

VarBinArray::try_new(offsets, bytes, dtype.clone(), validity)
}
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/varbinview/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl VTable for VarBinViewVTable {
}
let mut buffers: Vec<ByteBuffer> = buffers
.iter()
.map(|b| b.clone().try_to_host())
.map(|b| b.clone().try_to_host_sync())
.collect::<VortexResult<Vec<_>>>()?;
let views = buffers.pop().vortex_expect("buffers non-empty");

Expand Down
106 changes: 97 additions & 9 deletions vortex-array/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::hash::Hasher;
use std::ops::Range;
use std::sync::Arc;

use futures::future::BoxFuture;
use vortex_buffer::ALIGNMENT_TO_HOST_COPY;
use vortex_buffer::Alignment;
use vortex_buffer::ByteBuffer;
Expand Down Expand Up @@ -59,7 +60,23 @@ pub trait DeviceBuffer: 'static + Send + Sync + Debug + DynEq + DynHash {
/// # Errors
///
/// This operation may fail, depending on the device implementation and the underlying hardware.
fn copy_to_host(&self, alignment: Alignment) -> VortexResult<ByteBuffer>;
fn copy_to_host_sync(&self, alignment: Alignment) -> VortexResult<ByteBuffer>;

/// Copies the device buffer to a host buffer asynchronously.
///
/// Schedules an async copy and returns a future that completes when the copy is finished.
///
/// # Arguments
///
/// * `alignment` - The memory alignment to use for the host buffer.
///
/// # Errors
///
/// Returns an error if the async copy operation fails.
fn copy_to_host(
&self,
alignment: Alignment,
) -> VortexResult<BoxFuture<'static, VortexResult<ByteBuffer>>>;

/// Create a new buffer that references a subrange of this buffer at the given
/// slice indices.
Expand Down Expand Up @@ -147,7 +164,7 @@ impl BufferHandle {
/// let values = buffer![1u32, 2u32, 3u32, 4u32];
/// let handle = BufferHandle::new_host(values.into_byte_buffer());
/// let sliced = handle.slice_typed::<u32>(1..4);
/// let result = Buffer::<u32>::from_byte_buffer(sliced.to_host());
/// let result = Buffer::<u32>::from_byte_buffer(sliced.to_host_sync());
/// assert_eq!(result, buffer![2, 3, 4]);
/// ```
pub fn slice_typed<T: Sized>(&self, range: Range<usize>) -> Self {
Expand Down Expand Up @@ -214,8 +231,8 @@ impl BufferHandle {
/// result in a panic.
///
/// See also: [`try_to_host`][Self::try_to_host].
pub fn to_host(&self) -> ByteBuffer {
self.try_to_host()
pub fn to_host_sync(&self) -> ByteBuffer {
self.try_to_host_sync()
.vortex_expect("to_host: copy from device to host failed")
}

Expand All @@ -228,8 +245,8 @@ impl BufferHandle {
/// # Panics
///
/// See the panic documentation on [`to_host`][Self::to_host].
pub fn into_host(self) -> ByteBuffer {
self.try_into_host()
pub fn into_host_sync(self) -> ByteBuffer {
self.try_into_host_sync()
.vortex_expect("into_host: copy from device to host failed")
}

Expand All @@ -239,22 +256,93 @@ impl BufferHandle {
///
/// If it is a device allocation, then this issues an operation that attempts to copy the data
/// from the device into a host-resident buffer, and returns a handle to that buffer.
pub fn try_to_host(&self) -> VortexResult<ByteBuffer> {
pub fn try_to_host_sync(&self) -> VortexResult<ByteBuffer> {
match &self.0 {
Inner::Host(b) => Ok(b.clone()),
Inner::Device(device) => device.copy_to_host(ALIGNMENT_TO_HOST_COPY),
Inner::Device(device) => device.copy_to_host_sync(ALIGNMENT_TO_HOST_COPY),
}
}

/// Attempts to load this buffer into a host-resident allocation, consuming the handle.
///
/// See also [`try_to_host`][Self::try_to_host].
pub fn try_into_host(self) -> VortexResult<ByteBuffer> {
pub fn try_into_host_sync(self) -> VortexResult<ByteBuffer> {
match self.0 {
Inner::Host(b) => Ok(b),
Inner::Device(device) => device.copy_to_host_sync(ALIGNMENT_TO_HOST_COPY),
}
}

/// Asynchronously copies the buffer to the host.
///
/// This is a no-op if the buffer is already on the host.
///
/// # Returns
///
/// A future that resolves to the host buffer when the copy completes.
///
/// # Errors
///
/// Returns an error if the async copy operation fails.
pub fn try_to_host(&self) -> VortexResult<BoxFuture<'static, VortexResult<ByteBuffer>>> {
match &self.0 {
Inner::Host(b) => {
let buffer = b.clone();
Ok(Box::pin(async move { Ok(buffer) }))
}
Inner::Device(device) => device.copy_to_host(ALIGNMENT_TO_HOST_COPY),
}
}

/// Asynchronously copies the buffer to the host, consuming the handle.
///
/// This is a no-op if the buffer is already on the host.
///
/// # Returns
///
/// A future that resolves to the host buffer when the copy completes.
///
/// # Errors
///
/// Returns an error if the async copy operation fails.
pub fn try_into_host(self) -> VortexResult<BoxFuture<'static, VortexResult<ByteBuffer>>> {
match self.0 {
Inner::Host(b) => Ok(Box::pin(async move { Ok(b) })),
Inner::Device(device) => device.copy_to_host(ALIGNMENT_TO_HOST_COPY),
}
}

/// Asynchronously copies the buffer to the host.
///
/// # Panics
///
/// Any errors triggered by the copying from device to host will result in a panic.
pub fn to_host(&self) -> BoxFuture<'static, ByteBuffer> {
let future = self
.try_to_host()
.vortex_expect("to_host: failed to initiate copy from device to host");
Box::pin(async move {
future
.await
.vortex_expect("to_host: copy from device to host failed")
})
}

/// Asynchronously copies the buffer to the host, consuming the handle.
///
/// # Panics
///
/// Any errors triggered by the copying from device to host will result in a panic.
pub fn into_host(self) -> BoxFuture<'static, ByteBuffer> {
let future = self
.try_into_host()
.vortex_expect("into_host: failed to initiate copy from device to host");
Box::pin(async move {
future
.await
.vortex_expect("into_host: copy from device to host failed")
})
}
}

impl ArrayHash for BufferHandle {
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/compute/conformance/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ fn test_take_all(array: &dyn Array) {
) {
(Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
assert_eq!(
orig_prim.buffer_handle().to_host(),
result_prim.buffer_handle().to_host()
orig_prim.buffer_handle().to_host_sync(),
result_prim.buffer_handle().to_host_sync()
);
}
_ => {
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ impl ArrayParts {
segment: BufferHandle,
) -> VortexResult<Self> {
// TODO: this can also work with device buffers.
let segment = segment.try_to_host()?;
let segment = segment.try_to_host_sync()?;
// We align each buffer individually, so we remove alignment requirements on the buffer.
let segment = segment.aligned(Alignment::none());

Expand Down Expand Up @@ -612,6 +612,6 @@ impl TryFrom<BufferHandle> for ArrayParts {
type Error = VortexError;

fn try_from(value: BufferHandle) -> Result<Self, Self::Error> {
Self::try_from(value.try_to_host()?)
Self::try_from(value.try_to_host_sync()?)
}
}
Loading
Loading