From 76a45271ac1343b43c65612447da65c4d0121195 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Wed, 21 Jan 2026 23:20:19 -0800
Subject: [PATCH 1/4] init

---
 .../spark-expr/src/conversion_funcs/cast.rs   | 26 +++++++++++++++----
 .../org/apache/comet/CometCastSuite.scala     | 26 ++++++++++++++++---
 2 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 2ff1d8c551..5acb70347e 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -2350,6 +2350,11 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {

     match parse_decimal_str(trimmed) {
         Ok((mantissa, exponent)) => {
+            // Special case: zero always fits regardless of scale
+            if mantissa == 0 {
+                return Ok(Some(0));
+            }
+
             // Convert to target scale
             let target_scale = scale as i32;
             let scale_adjustment = target_scale - exponent;
@@ -2392,20 +2397,31 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
                 Some(rounded)
             };

             match scaled_value {
                 Some(value) => {
-                    // Check if it fits target precision
                     if is_validate_decimal_precision(value, precision) {
                         Ok(Some(value))
                     } else {
-                        Ok(None)
+                        // Value parsed successfully but exceeds target precision
+                        Err(SparkError::NumericValueOutOfRange {
+                            value: trimmed.to_string(),
+                            precision,
+                            scale,
+                        })
                     }
                 }
                 None => {
-                    // Overflow while scaling
-                    Ok(None)
+                    // Overflow while scaling. Throw an error and let the caller handle it based on EVAL mode
+                    Err(SparkError::NumericValueOutOfRange {
+                        value: trimmed.to_string(),
+                        precision,
+                        scale,
+                    })
                 }
             }
         }
-        Err(_) => Ok(None),
+        Err(e) => {
+            // Raise an error for malformed input
+            Err(SparkError::Internal(e))
+        }
     }
 }

diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 8a68df3820..4900bca14a 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -719,9 +719,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   test("cast StringType to DecimalType(2,2)") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
-      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      println("testing with simple input")
+      val values = Seq(" 3").toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(2,2) check that the right exception is thrown") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      println("testing with simple input")
+      val values = Seq(" 3").toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
     }
   }
@@ -731,7 +739,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
       // TODO fix for Spark 4.0.0
       assume(!isSpark40Plus)
-      val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
+      val values = Seq("0e31").toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision - 0 mantissa") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = Seq("0e31").toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
     }
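Note: the divide branch above rescales the mantissa by 10^k and rounds half away from zero ("round half up" in the patch's comment). A standalone sketch of just that rule, runnable on its own (illustrative only; rescale_down is a made-up name, not part of the patch):

    fn rescale_down(mantissa: i128, k: u32) -> i128 {
        // Divide the mantissa by 10^k, rounding half away from zero,
        // mirroring the quotient/remainder logic in the hunk above.
        // The patch only reaches this branch with k >= 1 (and k <= 38,
        // so 10^k fits in i128).
        assert!(k >= 1);
        let divisor = 10_i128.pow(k);
        let quotient = mantissa / divisor;
        let remainder = mantissa % divisor;
        if remainder.abs() >= divisor / 2 {
            if mantissa >= 0 { quotient + 1 } else { quotient - 1 }
        } else {
            quotient
        }
    }

    fn main() {
        assert_eq!(rescale_down(12345, 2), 123);   // 123.45 -> 123 (0.45 < 0.5)
        assert_eq!(rescale_down(12350, 2), 124);   // 123.50 -> 124 (half rounds up)
        assert_eq!(rescale_down(-12350, 2), -124); // negative half rounds away from zero
    }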
+ val values = Seq("0e31").toDF("a") Seq(true, false).foreach(ansiEnabled => castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) } From 0c8f0a4703345f5b8e92dbb9081637119f5d3062 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 22 Jan 2026 14:52:27 -0800 Subject: [PATCH 2/4] fix_issue --- .../spark-expr/src/conversion_funcs/cast.rs | 154 +++++++++--------- 1 file changed, 76 insertions(+), 78 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 5acb70347e..a1d3408b47 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2348,88 +2348,86 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult { - // Special case: zero always fits regardless of scale - if mantissa == 0 { - return Ok(Some(0)); - } + let (mantissa, exponent) = parse_decimal_str( + trimmed, + "STRING", + &format!("DECIMAL({},{})", precision, scale), + )?; - // Convert to target scale - let target_scale = scale as i32; - let scale_adjustment = target_scale - exponent; + // return early when mantissa is 0 + if mantissa == 0 { + return Ok(Some(0)); + } - let scaled_value = if scale_adjustment >= 0 { - // Need to multiply (increase scale) but return None if scale is too high to fit i128 - if scale_adjustment > 38 { - return Ok(None); - } - mantissa.checked_mul(10_i128.pow(scale_adjustment as u32)) - } else { - // Need to multiply (increase scale) but return None if scale is too high to fit i128 - let abs_scale_adjustment = (-scale_adjustment) as u32; - if abs_scale_adjustment > 38 { - return Ok(Some(0)); - } + // scale adjustment + let target_scale = scale as i32; + let scale_adjustment = target_scale - exponent; - let divisor = 10_i128.pow(abs_scale_adjustment); - let quotient_opt = mantissa.checked_div(divisor); - // Check if divisor is 0 - if quotient_opt.is_none() { - return Ok(None); - } - let quotient = quotient_opt.unwrap(); - let remainder = mantissa % divisor; - - // Round half up: if abs(remainder) >= divisor/2, round away from zero - let half_divisor = divisor / 2; - let rounded = if remainder.abs() >= half_divisor { - if mantissa >= 0 { - quotient + 1 - } else { - quotient - 1 - } - } else { - quotient - }; - Some(rounded) - }; + let scaled_value = if scale_adjustment >= 0 { + // Need to multiply (increase scale) but return None if scale is too high to fit i128 + if scale_adjustment > 38 { + return Ok(None); + } + mantissa.checked_mul(10_i128.pow(scale_adjustment as u32)) + } else { + // Need to divide (decrease scale) + let abs_scale_adjustment = (-scale_adjustment) as u32; + if abs_scale_adjustment > 38 { + return Ok(Some(0)); + } - match scaled_value { - Some(value) => { - if is_validate_decimal_precision(value, precision) { - Ok(Some(value)) - } else { - // Value parsed successfully but exceeds target precision - Err(SparkError::NumericValueOutOfRange { - value: trimmed.to_string(), - precision, - scale, - }) - } - } - None => { - // Overflow while scaling . 
@@ -2438,7 +2436,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
         // Parse exponent
         let exp: i32 = exponent_part
             .parse()
-            .map_err(|e| format!("Invalid exponent: {}", e))?;
+            .map_err(|_| invalid_value(s, from_type, to_type))?;

         (mantissa_part, exp)
     } else {
@@ -2453,13 +2451,13 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     };

     if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
-        return Err("Invalid sign format".to_string());
+        return Err(invalid_value(s, from_type, to_type));
     }

     let (integral_part, fractional_part) = match mantissa_str.find('.') {
         Some(dot_pos) => {
             if mantissa_str[dot_pos + 1..].contains('.') {
-                return Err("Multiple decimal points".to_string());
+                return Err(invalid_value(s, from_type, to_type));
             }
             (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
         }
@@ -2467,15 +2465,15 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     };

     if integral_part.is_empty() && fractional_part.is_empty() {
-        return Err("No digits found".to_string());
+        return Err(invalid_value(s, from_type, to_type));
     }

     if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
-        return Err("Invalid integral part".to_string());
+        return Err(invalid_value(s, from_type, to_type));
     }

     if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
-        return Err("Invalid fractional part".to_string());
+        return Err(invalid_value(s, from_type, to_type));
     }

     // Parse integral part
@@ -2485,7 +2483,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     } else {
         integral_part
             .parse()
-            .map_err(|_| "Invalid integral part".to_string())?
+            .map_err(|_| invalid_value(s, from_type, to_type))?
     };

     // Parse fractional part
@@ -2495,14 +2493,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     } else {
         fractional_part
             .parse()
-            .map_err(|_| "Invalid fractional part".to_string())?
+            .map_err(|_| invalid_value(s, from_type, to_type))?
     };

     // Combine: value = integral * 10^fractional_scale + fractional
     let mantissa = integral_value
         .checked_mul(10_i128.pow(fractional_scale as u32))
         .and_then(|v| v.checked_add(fractional_value))
-        .ok_or("Overflow in mantissa calculation")?;
+        .ok_or_else(|| invalid_value(s, from_type, to_type))?;

     let final_mantissa = if negative { -mantissa } else { mantissa };
     // final scale = fractional_scale - exponent
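Note: patch 2 changes parse_decimal_str to report malformed input through invalid_value as a SparkError rather than a bare String. A hypothetical unit test against the new signature (a sketch only; the function is private to the module, so this would have to live in a #[cfg(test)] module beside it):

    #[test]
    fn parse_decimal_str_examples() {
        let to = "DECIMAL(38,10)";
        // Pairs follow the documented convention: "123.45" -> (12345, 2).
        assert_eq!(parse_decimal_str("123.45", "STRING", to).unwrap(), (12345, 2));
        assert_eq!(parse_decimal_str("-0.001", "STRING", to).unwrap(), (-1, 3));
        // An exponent subtracts from the fractional scale: 1.5e2 -> (15, 1 - 2).
        assert_eq!(parse_decimal_str("1.5e2", "STRING", to).unwrap(), (15, -1));
        // Malformed inputs now surface as errors instead of Err(String).
        assert!(parse_decimal_str("", "STRING", to).is_err());
        assert!(parse_decimal_str("12..3", "STRING", to).is_err());
    }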
From 61987e19fbe442523535fb50d44ea4a748943876 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Thu, 22 Jan 2026 14:56:36 -0800
Subject: [PATCH 3/4] fix_issue

---
 .../scala/org/apache/comet/CometCastSuite.scala | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 4900bca14a..cfad751c0a 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -720,7 +720,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   test("cast StringType to DecimalType(2,2)") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
       println("testing with simple input")
-      val values = Seq(" 3").toDF("a")
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
     }
@@ -735,15 +735,21 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

-  test("cast StringType to DecimalType(38,10) high precision") {
+  test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
       val values = Seq("0e31").toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
     }
-  }
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
+    }
+  }

   test("cast StringType to DecimalType(38,10) high precision - 0 mantissa") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
"000e3375", "0e5887677", "0e40").toDF("a") Seq(true, false).foreach(ansiEnabled => castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) }