Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions encodings/alp/src/alp/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ impl VTable for ALPVTable {
ALPArray::new(
array.encoded().slice(range.clone()),
array.exponents(),
array.patches().and_then(|p| p.slice(range)),
array
.patches()
.map(|p| p.slice(range))
.transpose()?
.flatten(),
)
.into_array(),
))
Expand Down Expand Up @@ -349,7 +353,7 @@ impl ALPArray {
/// None
/// ).unwrap();
///
/// assert_eq!(value.scalar_at(0), 0f32.into());
/// assert_eq!(value.scalar_at(0).unwrap(), 0f32.into());
/// ```
pub fn try_new(
encoded: ArrayRef,
Expand Down Expand Up @@ -696,7 +700,7 @@ mod tests {
for idx in 0..slice_len {
let expected_value = values[slice_start + idx];

let result_valid = result_primitive.validity_mask().value(idx);
let result_valid = result_primitive.validity_mask().unwrap().value(idx);
assert_eq!(
result_valid,
expected_value.is_some(),
Expand Down
2 changes: 1 addition & 1 deletion encodings/alp/src/alp/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ where

let encoded_array = PrimitiveArray::new(encoded, values.validity().clone()).into_array();

let validity = values.validity_mask();
let validity = values.validity_mask()?;
// exceptional_positions may contain exceptions at invalid positions (which contain garbage
// data). We remove null exceptions in order to keep the Patches small.
let (valid_exceptional_positions, valid_exceptional_values): (Buffer<u64>, Buffer<T>) =
Expand Down
13 changes: 7 additions & 6 deletions encodings/alp/src/alp/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use vortex_array::vtable::OperationsVTable;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::ALPArray;
Expand All @@ -11,16 +12,16 @@ use crate::ALPVTable;
use crate::match_each_alp_float_ptype;

impl OperationsVTable<ALPVTable> for ALPVTable {
fn scalar_at(array: &ALPArray, index: usize) -> Scalar {
fn scalar_at(array: &ALPArray, index: usize) -> VortexResult<Scalar> {
if let Some(patches) = array.patches()
&& let Some(patch) = patches.get_patched(index)
&& let Some(patch) = patches.get_patched(index)?
{
return patch.cast(array.dtype()).vortex_expect("cast failure");
return patch.cast(array.dtype());
}

let encoded_val = array.encoded().scalar_at(index);
let encoded_val = array.encoded().scalar_at(index)?;

match_each_alp_float_ptype!(array.ptype(), |T| {
Ok(match_each_alp_float_ptype!(array.ptype(), |T| {
let encoded_val: <T as ALPFloat>::ALPInt = encoded_val
.as_ref()
.try_into()
Expand All @@ -29,6 +30,6 @@ impl OperationsVTable<ALPVTable> for ALPVTable {
<T as ALPFloat>::decode_single(encoded_val, array.exponents()),
array.dtype().nullability(),
)
})
}))
}
}
6 changes: 4 additions & 2 deletions encodings/alp/src/alp_rd/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ impl VTable for ALPRDVTable {
fn slice(array: &Self::Array, range: std::ops::Range<usize>) -> VortexResult<Option<ArrayRef>> {
let left_parts_exceptions = array
.left_parts_patches()
.and_then(|patches| patches.slice(range.clone()));
.map(|patches| patches.slice(range.clone()))
.transpose()?
.flatten();

// SAFETY: slicing components does not change the encoded values
Ok(Some(unsafe {
Expand Down Expand Up @@ -341,7 +343,7 @@ impl ALPRDArray {

let left_parts_patches = left_parts_patches
.map(|patches| {
if !patches.values().all_valid() {
if !patches.values().all_valid()? {
vortex_bail!("patches must be all valid: {}", patches.values());
}
// TODO(ngates): assert the DType, don't cast it.
Expand Down
20 changes: 11 additions & 9 deletions encodings/alp/src/alp_rd/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
use vortex_array::Array;
use vortex_array::vtable::OperationsVTable;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::ALPRDArray;
use crate::ALPRDVTable;

impl OperationsVTable<ALPRDVTable> for ALPRDVTable {
fn scalar_at(array: &ALPRDArray, index: usize) -> Scalar {
fn scalar_at(array: &ALPRDArray, index: usize) -> VortexResult<Scalar> {
// The left value can either be a direct value, or an exception.
// The exceptions array represents exception positions with non-null values.
let maybe_patched_value = array
.left_parts_patches()
.and_then(|patches| patches.get_patched(index));
let maybe_patched_value = match array.left_parts_patches() {
Some(patches) => patches.get_patched(index)?,
None => None,
};
let left = match maybe_patched_value {
Some(patched_value) => patched_value
.as_primitive()
Expand All @@ -24,7 +26,7 @@ impl OperationsVTable<ALPRDVTable> for ALPRDVTable {
_ => {
let left_code: u16 = array
.left_parts()
.scalar_at(index)
.scalar_at(index)?
.as_primitive()
.as_::<u16>()
.vortex_expect("left_code must be non-null");
Expand All @@ -33,10 +35,10 @@ impl OperationsVTable<ALPRDVTable> for ALPRDVTable {
};

// combine left and right values
if array.is_f32() {
Ok(if array.is_f32() {
let right: u32 = array
.right_parts()
.scalar_at(index)
.scalar_at(index)?
.as_primitive()
.as_::<u32>()
.vortex_expect("non-null");
Expand All @@ -45,13 +47,13 @@ impl OperationsVTable<ALPRDVTable> for ALPRDVTable {
} else {
let right: u64 = array
.right_parts()
.scalar_at(index)
.scalar_at(index)?
.as_primitive()
.as_::<u64>()
.vortex_expect("non-null");
let packed = f64::from_bits(((left as u64) << array.right_bit_width()) | right);
Scalar::primitive(packed, array.dtype().nullability())
}
})
}
}

Expand Down
17 changes: 10 additions & 7 deletions encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,11 @@ impl BaseArrayVTable<ByteBoolVTable> for ByteBoolVTable {
}

impl OperationsVTable<ByteBoolVTable> for ByteBoolVTable {
fn scalar_at(array: &ByteBoolArray, index: usize) -> Scalar {
Scalar::bool(array.buffer()[index] == 1, array.dtype().nullability())
fn scalar_at(array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
Ok(Scalar::bool(
array.buffer()[index] == 1,
array.dtype().nullability(),
))
}
}

Expand Down Expand Up @@ -275,14 +278,14 @@ mod tests {
assert_eq!(v_len, arr.len());

for idx in 0..arr.len() {
assert!(arr.is_valid(idx));
assert!(arr.is_valid(idx).unwrap());
}

let v = vec![Some(true), None, Some(false)];
let arr = ByteBoolArray::from(v);
assert!(arr.is_valid(0));
assert!(!arr.is_valid(1));
assert!(arr.is_valid(2));
assert!(arr.is_valid(0).unwrap());
assert!(!arr.is_valid(1).unwrap());
assert!(arr.is_valid(2).unwrap());
assert_eq!(arr.len(), 3);

let v: Vec<Option<bool>> = vec![None, None];
Expand All @@ -292,7 +295,7 @@ mod tests {
assert_eq!(v_len, arr.len());

for idx in 0..arr.len() {
assert!(!arr.is_valid(idx));
assert!(!arr.is_valid(idx).unwrap());
}
assert_eq!(arr.len(), 2);
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/datetime-parts/src/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub fn decode_to_temporal(
}

Ok(TemporalArray::new_timestamp(
PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref()))
PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref())?)
.into_array(),
temporal_metadata.time_unit(),
temporal_metadata.time_zone().map(ToString::to_string),
Expand Down Expand Up @@ -147,7 +147,7 @@ mod test {
.unwrap();

assert_eq!(
date_times.validity_mask(),
date_times.validity_mask().unwrap(),
validity.to_mask(date_times.len())
);

Expand Down
15 changes: 8 additions & 7 deletions encodings/datetime-parts/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use vortex_array::vtable::OperationsVTable;
use vortex_dtype::DType;
use vortex_dtype::datetime::TemporalMetadata;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;
use vortex_scalar::Scalar;

Expand All @@ -15,7 +16,7 @@ use crate::timestamp;
use crate::timestamp::TimestampParts;

impl OperationsVTable<DateTimePartsVTable> for DateTimePartsVTable {
fn scalar_at(array: &DateTimePartsArray, index: usize) -> Scalar {
fn scalar_at(array: &DateTimePartsArray, index: usize) -> VortexResult<Scalar> {
let DType::Extension(ext) = array.dtype().clone() else {
vortex_panic!(
"DateTimePartsArray must have extension dtype, found {}",
Expand All @@ -27,25 +28,25 @@ impl OperationsVTable<DateTimePartsVTable> for DateTimePartsVTable {
vortex_panic!(ComputeError: "must decode TemporalMetadata from extension metadata");
};

if !array.is_valid(index) {
return Scalar::null(DType::Extension(ext));
if !array.is_valid(index)? {
return Ok(Scalar::null(DType::Extension(ext)));
}

let days: i64 = array
.days()
.scalar_at(index)
.scalar_at(index)?
.as_primitive()
.as_::<i64>()
.vortex_expect("days fits in i64");
let seconds: i64 = array
.seconds()
.scalar_at(index)
.scalar_at(index)?
.as_primitive()
.as_::<i64>()
.vortex_expect("seconds fits in i64");
let subseconds: i64 = array
.subseconds()
.scalar_at(index)
.scalar_at(index)?
.as_primitive()
.as_::<i64>()
.vortex_expect("subseconds fits in i64");
Expand All @@ -59,6 +60,6 @@ impl OperationsVTable<DateTimePartsVTable> for DateTimePartsVTable {
temporal_metadata.time_unit(),
);

Scalar::extension(ext, Scalar::from(ts))
Ok(Scalar::extension(ext, Scalar::from(ts)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl CompareKernel for DecimalBytePartsVTable {
// (depending on the `sign`) than all values in MSP.
// If the LHS or the RHS contain nulls, then we must fallback to the canonicalized
// implementation which does null-checking instead.
if lhs.all_valid() && rhs.all_valid() {
if lhs.all_valid()? && rhs.all_valid()? {
Ok(Some(
ConstantArray::new(
unconvertible_value(sign, operator, nullability),
Expand Down
Loading
Loading