Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 79 additions & 65 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2348,72 +2348,86 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
}

// validate and parse mantissa and exponent
match parse_decimal_str(trimmed) {
Ok((mantissa, exponent)) => {
// Convert to target scale
let target_scale = scale as i32;
let scale_adjustment = target_scale - exponent;

let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}
let (mantissa, exponent) = parse_decimal_str(
trimmed,
"STRING",
&format!("DECIMAL({},{})", precision, scale),
)?;

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// Check if divisor is 0
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half up: if abs(remainder) >= divisor/2, round away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};
// return early when mantissa is 0
if mantissa == 0 {
return Ok(Some(0));
}

match scaled_value {
Some(value) => {
// Check if it fits target precision
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
Ok(None)
}
}
None => {
// Overflow while scaling
Ok(None)
}
// scale adjustment
let target_scale = scale as i32;
let scale_adjustment = target_scale - exponent;

let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to divide (decrease scale)
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// Defensive: checked_div returns None only for a zero divisor; divisor is 10^k >= 1 here, so this should never trigger
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half away from zero: if abs(remainder) >= divisor/2, bump the quotient away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};

match scaled_value {
Some(value) => {
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
// Value parsed successfully but exceeds the target precision; throw an error
Err(SparkError::NumericValueOutOfRange {
value: trimmed.to_string(),
precision,
scale,
})
}
}
Err(_) => Ok(None),
None => {
// Overflow while scaling; raise an exception
Err(SparkError::NumericValueOutOfRange {
value: trimmed.to_string(),
precision,
scale,
})
}
}
}

/// Parse a decimal string into mantissa and scale
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i128, i32)> {
if s.is_empty() {
return Err("Empty string".to_string());
return Err(invalid_value(s, from_type, to_type));
}

let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
Expand All @@ -2422,7 +2436,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
// Parse exponent
let exp: i32 = exponent_part
.parse()
.map_err(|e| format!("Invalid exponent: {}", e))?;
.map_err(|_| invalid_value(s, from_type, to_type))?;

(mantissa_part, exp)
} else {
Expand All @@ -2437,29 +2451,29 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
};

if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
return Err("Invalid sign format".to_string());
return Err(invalid_value(s, from_type, to_type));
}

let (integral_part, fractional_part) = match mantissa_str.find('.') {
Some(dot_pos) => {
if mantissa_str[dot_pos + 1..].contains('.') {
return Err("Multiple decimal points".to_string());
return Err(invalid_value(s, from_type, to_type));
}
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
}
None => (mantissa_str, ""),
};

if integral_part.is_empty() && fractional_part.is_empty() {
return Err("No digits found".to_string());
return Err(invalid_value(s, from_type, to_type));
}

if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid integral part".to_string());
return Err(invalid_value(s, from_type, to_type));
}

if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid fractional part".to_string());
return Err(invalid_value(s, from_type, to_type));
}

// Parse integral part
Expand All @@ -2469,7 +2483,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
} else {
integral_part
.parse()
.map_err(|_| "Invalid integral part".to_string())?
.map_err(|_| invalid_value(s, from_type, to_type))?
};

// Parse fractional part
Expand All @@ -2479,14 +2493,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
} else {
fractional_part
.parse()
.map_err(|_| "Invalid fractional part".to_string())?
.map_err(|_| invalid_value(s, from_type, to_type))?
};

// Combine: value = integral * 10^fractional_scale + fractional
let mantissa = integral_value
.checked_mul(10_i128.pow(fractional_scale as u32))
.and_then(|v| v.checked_add(fractional_value))
.ok_or("Overflow in mantissa calculation")?;
.ok_or_else(|| invalid_value(s, from_type, to_type))?;

let final_mantissa = if negative { -mantissa } else { mantissa };
// final scale = fractional_scale - exponent
Expand Down
30 changes: 27 additions & 3 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -719,19 +719,43 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("cast StringType to DecimalType(2,2)") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
println("testing with simple input")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the debug logging should be removed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. I will update and remove debug statements

val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType(2,2) check if right exception is being thrown") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
println("testing with simple input")
val values = Seq(" 3").toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
val values = Seq("0e31").toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType(38,10) high precision") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
}
}

test("cast StringType to DecimalType(38,10) high precision - 0 mantissa") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
val values = Seq("0e31").toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
}
Expand Down
Loading