diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 2ff1d8c551..86bc75fce7 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2348,72 +2348,81 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult { - // Convert to target scale - let target_scale = scale as i32; - let scale_adjustment = target_scale - exponent; - - let scaled_value = if scale_adjustment >= 0 { - // Need to multiply (increase scale) but return None if scale is too high to fit i128 - if scale_adjustment > 38 { - return Ok(None); - } - mantissa.checked_mul(10_i128.pow(scale_adjustment as u32)) - } else { - // Need to multiply (increase scale) but return None if scale is too high to fit i128 - let abs_scale_adjustment = (-scale_adjustment) as u32; - if abs_scale_adjustment > 38 { - return Ok(Some(0)); - } + let (mantissa, exponent) = parse_decimal_str( + trimmed, + "STRING", + &format!("DECIMAL({},{})", precision, scale), + )?; - let divisor = 10_i128.pow(abs_scale_adjustment); - let quotient_opt = mantissa.checked_div(divisor); - // Check if divisor is 0 - if quotient_opt.is_none() { - return Ok(None); - } - let quotient = quotient_opt.unwrap(); - let remainder = mantissa % divisor; - - // Round half up: if abs(remainder) >= divisor/2, round away from zero - let half_divisor = divisor / 2; - let rounded = if remainder.abs() >= half_divisor { - if mantissa >= 0 { - quotient + 1 - } else { - quotient - 1 - } - } else { - quotient - }; - Some(rounded) - }; + // scale adjustment + let target_scale = scale as i32; + let scale_adjustment = target_scale - exponent; - match scaled_value { - Some(value) => { - // Check if it fits target precision - if is_validate_decimal_precision(value, precision) { - Ok(Some(value)) - } else { - Ok(None) - } - } - None => { - // Overflow while scaling - Ok(None) - } + let scaled_value = if scale_adjustment >= 0 { + // Need to multiply (increase scale) but return None if scale is too high to fit i128 + if scale_adjustment > 38 { + return Ok(None); + } + mantissa.checked_mul(10_i128.pow(scale_adjustment as u32)) + } else { + // Need to divide (decrease scale) + let abs_scale_adjustment = (-scale_adjustment) as u32; + if abs_scale_adjustment > 38 { + return Ok(Some(0)); + } + + let divisor = 10_i128.pow(abs_scale_adjustment); + let quotient_opt = mantissa.checked_div(divisor); + // Check if divisor is 0 + if quotient_opt.is_none() { + return Ok(None); + } + let quotient = quotient_opt.unwrap(); + let remainder = mantissa % divisor; + + // Round half up: if abs(remainder) >= divisor/2, round away from zero + let half_divisor = divisor / 2; + let rounded = if remainder.abs() >= half_divisor { + if mantissa >= 0 { + quotient + 1 + } else { + quotient - 1 } + } else { + quotient + }; + Some(rounded) + }; + + match scaled_value { + Some(value) => { + if is_validate_decimal_precision(value, precision) { + Ok(Some(value)) + } else { + // Value ok but exceeds precision mentioned . THrow error + Err(SparkError::NumericValueOutOfRange { + value: trimmed.to_string(), + precision, + scale, + }) + } + } + None => { + // Overflow when scaling raise exception + Err(SparkError::NumericValueOutOfRange { + value: trimmed.to_string(), + precision, + scale, + }) } - Err(_) => Ok(None), } } /// Parse a decimal string into mantissa and scale /// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) -fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { +fn parse_decimal_str(s: &str, from_type: &str, to_type: &str) -> SparkResult<(i128, i32)> { if s.is_empty() { - return Err("Empty string".to_string()); + return Err(invalid_value(s, from_type, to_type)); } let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) { @@ -2422,7 +2431,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { // Parse exponent let exp: i32 = exponent_part .parse() - .map_err(|e| format!("Invalid exponent: {}", e))?; + .map_err(|_| invalid_value(s, from_type, to_type))?; (mantissa_part, exp) } else { @@ -2437,13 +2446,13 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { }; if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') { - return Err("Invalid sign format".to_string()); + return Err(invalid_value(s, from_type, to_type)); } let (integral_part, fractional_part) = match mantissa_str.find('.') { Some(dot_pos) => { if mantissa_str[dot_pos + 1..].contains('.') { - return Err("Multiple decimal points".to_string()); + return Err(invalid_value(s, from_type, to_type)); } (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..]) } @@ -2451,15 +2460,15 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { }; if integral_part.is_empty() && fractional_part.is_empty() { - return Err("No digits found".to_string()); + return Err(invalid_value(s, from_type, to_type)); } if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) { - return Err("Invalid integral part".to_string()); + return Err(invalid_value(s, from_type, to_type)); } if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) { - return Err("Invalid fractional part".to_string()); + return Err(invalid_value(s, from_type, to_type)); } // Parse integral part @@ -2469,7 +2478,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { } else { integral_part .parse() - .map_err(|_| "Invalid integral part".to_string())? + .map_err(|_| invalid_value(s, from_type, to_type))? }; // Parse fractional part @@ -2479,14 +2488,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> { } else { fractional_part .parse() - .map_err(|_| "Invalid fractional part".to_string())? + .map_err(|_| invalid_value(s, from_type, to_type))? }; // Combine: value = integral * 10^fractional_scale + fractional let mantissa = integral_value .checked_mul(10_i128.pow(fractional_scale as u32)) .and_then(|v| v.checked_add(fractional_value)) - .ok_or("Overflow in mantissa calculation")?; + .ok_or_else(|| invalid_value(s, from_type, to_type))?; let final_mantissa = if negative { -mantissa } else { mantissa }; // final scale = fractional_scale - exponent diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 8a68df3820..097f8025f6 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -719,19 +719,43 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("cast StringType to DecimalType(2,2)") { withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { - // TODO fix for Spark 4.0.0 - assume(!isSpark40Plus) + println("testing with simple input") val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") Seq(true, false).foreach(ansiEnabled => castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) } } + test("cast StringType to DecimalType(2,2) check if right exception is being thrown") { + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { + println("testing with simple input") + val values = Seq(" 3").toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled)) + } + } + + test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") { + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { + val values = Seq("0e31", "000e3375", "0e5887677", "0e40").toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) + } + } + test("cast StringType to DecimalType(38,10) high precision") { + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { + val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) + } + } + + test("cast StringType to DecimalType(38,10) high precision - 0 mantissa") { withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { // TODO fix for Spark 4.0.0 assume(!isSpark40Plus) - val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a") + val values = Seq("0e31").toDF("a") Seq(true, false).foreach(ansiEnabled => castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) }