diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 0ca6f8ea97..48c3601390 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -105,7 +105,6 @@ Cast operations in Comet fall into three levels of support: **Notes:** - - **decimal -> string**: There can be formatting differences in some case due to Spark using scientific notation where Comet does not - **double -> decimal**: There can be rounding differences - **double -> string**: There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 @@ -113,7 +112,7 @@ Cast operations in Comet fall into three levels of support: - **float -> string**: There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 - **string -> date**: Only supports years between 262143 BC and 262142 AD - **string -> decimal**: Does not support fullwidth unicode digits (e.g \\uFF10) - or strings containing null bytes (e.g \\u0000) +or strings containing null bytes (e.g \\u0000) - **string -> timestamp**: Not all valid formats are supported @@ -140,7 +139,6 @@ Cast operations in Comet fall into three levels of support: **Notes:** - - **decimal -> string**: There can be formatting differences in some case due to Spark using scientific notation where Comet does not - **double -> decimal**: There can be rounding differences - **double -> string**: There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 @@ -148,7 +146,7 @@ Cast operations in Comet fall into three levels of support: - **float -> string**: There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 - **string -> date**: Only supports years between 262143 BC and 262142 AD - **string -> decimal**: Does not support fullwidth unicode digits (e.g \\uFF10) - or strings containing null bytes (e.g \\u0000) +or strings containing null bytes (e.g \\u0000) - **string -> timestamp**: Not all valid formats are supported @@ -175,7 +173,6 @@ Cast operations in Comet fall into three levels of support: **Notes:** - - **decimal -> string**: There can be formatting differences in some case due to Spark using scientific notation where Comet does not - **double -> decimal**: There can be rounding differences - **double -> string**: There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 @@ -183,7 +180,7 @@ Cast operations in Comet fall into three levels of support: - **float -> string**: There can be differences in precision. 
For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 - **string -> date**: Only supports years between 262143 BC and 262142 AD - **string -> decimal**: Does not support fullwidth unicode digits (e.g \\uFF10) - or strings containing null bytes (e.g \\u0000) +or strings containing null bytes (e.g \\u0000) - **string -> timestamp**: ANSI mode not supported diff --git a/native/Cargo.lock b/native/Cargo.lock index c2bfb84004..a3d315e86d 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -17,6 +17,41 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.7.8" @@ -1061,6 +1096,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + [[package]] name = "blocking" version = "1.6.2" @@ -1233,6 +1277,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name = "cc" version = "1.2.52" @@ -1323,6 +1376,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1563,6 +1626,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1587,6 +1651,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.20.11" @@ -1871,12 +1944,17 @@ dependencies = [ name = "datafusion-comet-spark-expr" version = "0.13.0" dependencies = [ + "aes", + "aes-gcm", "arrow", "base64", + "cbc", "chrono", "chrono-tz", + "cipher", "criterion", "datafusion", + "ecb", "futures", "hex", "num", @@ -2577,6 +2655,15 @@ version = "1.0.20" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecb" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7" +dependencies = [ + "cipher", +] + [[package]] name = "either" version = "1.15.0" @@ -2899,6 +2986,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.32.3" @@ -3407,6 +3504,16 @@ dependencies = [ "str_stack", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "block-padding", + "generic-array", +] + [[package]] name = "integer-encoding" version = "3.0.4" @@ -4078,6 +4185,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "opendal" version = "0.55.0" @@ -4401,6 +4514,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.11.1" @@ -6031,6 +6156,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "unsafe-any-ors" version = "1.0.0" diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 94653d8864..9d11a4424b 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -40,6 +40,11 @@ twox-hash = "2.1.2" rand = { workspace = true } hex = "0.4.3" base64 = "0.22.1" +aes = "0.8" +aes-gcm = "0.10" +cbc = { version = "0.1", features = ["alloc"] } +cipher = "0.4" +ecb = "0.1" [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 8384a4646a..03dc71ddf4 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use crate::encryption_funcs::spark_aes_encrypt; use crate::hash_funcs::*; use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; @@ -165,6 +166,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_xxhash64); make_comet_scalar_udf!("xxhash64", func, without data_type) } + "aes_encrypt" => { + let func = Arc::new(spark_aes_encrypt); + make_comet_scalar_udf!("aes_encrypt", func, without data_type) + } "isnan" => { let func = Arc::new(spark_isnan); make_comet_scalar_udf!("isnan", func, without data_type) diff --git a/native/spark-expr/src/encryption_funcs/aes_encrypt.rs b/native/spark-expr/src/encryption_funcs/aes_encrypt.rs new file mode 100644 index 0000000000..db8df3f4ee --- /dev/null +++ b/native/spark-expr/src/encryption_funcs/aes_encrypt.rs @@ -0,0 +1,468 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, BinaryArray, BinaryBuilder, StringArray}; +use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use std::sync::Arc; + +use super::cipher_modes::get_cipher_mode; + +pub fn spark_aes_encrypt(args: &[ColumnarValue]) -> Result<ColumnarValue> { + if args.len() < 2 || args.len() > 6 { + return Err(DataFusionError::Execution(format!( + "aes_encrypt expects 2-6 arguments, got {}", + args.len() + ))); + } + + let mode_default = ColumnarValue::Scalar(ScalarValue::Utf8(Some("GCM".to_string()))); + let padding_default = ColumnarValue::Scalar(ScalarValue::Utf8(Some("DEFAULT".to_string()))); + let iv_default = ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![]))); + let aad_default = ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![]))); + + let input_arg = &args[0]; + let key_arg = &args[1]; + let mode_arg = args.get(2).unwrap_or(&mode_default); + let padding_arg = args.get(3).unwrap_or(&padding_default); + let iv_arg = args.get(4).unwrap_or(&iv_default); + let aad_arg = args.get(5).unwrap_or(&aad_default); + + let batch_size = get_batch_size(args)?; + + if batch_size == 1 { + encrypt_scalar(input_arg, key_arg, mode_arg, padding_arg, iv_arg, aad_arg) + } else { + encrypt_batch( + input_arg, + key_arg, + mode_arg, + padding_arg, + iv_arg, + aad_arg, + batch_size, + ) + } +} + +fn encrypt_scalar( + input_arg: &ColumnarValue, + key_arg: &ColumnarValue, + mode_arg: &ColumnarValue, + padding_arg: &ColumnarValue, + iv_arg: &ColumnarValue, + aad_arg: &ColumnarValue, +) -> Result<ColumnarValue> { + let input = match input_arg { + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => opt, + _ => return Err(DataFusionError::Execution("Invalid input type".to_string())), + }; + + let key = match key_arg { + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => opt, + _ => return
Err(DataFusionError::Execution("Invalid key type".to_string())), + }; + + if input.is_none() || key.is_none() { + return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))); + } + + let mode = match mode_arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.as_str(), + _ => "GCM", + }; + + let padding = match padding_arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.as_str(), + _ => "DEFAULT", + }; + + let iv = match iv_arg { + ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) if !v.is_empty() => Some(v.as_slice()), + _ => None, + }; + + let aad = match aad_arg { + ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) if !v.is_empty() => Some(v.as_slice()), + _ => None, + }; + + let cipher = get_cipher_mode(mode, padding)?; + + let encrypted = cipher + .encrypt(input.as_ref().unwrap(), key.as_ref().unwrap(), iv, aad) + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encrypted)))) +} + +fn encrypt_batch( + input_arg: &ColumnarValue, + key_arg: &ColumnarValue, + mode_arg: &ColumnarValue, + padding_arg: &ColumnarValue, + iv_arg: &ColumnarValue, + aad_arg: &ColumnarValue, + batch_size: usize, +) -> Result { + let input_array = to_binary_array(input_arg, batch_size)?; + let key_array = to_binary_array(key_arg, batch_size)?; + let mode_array = to_string_array(mode_arg, batch_size)?; + let padding_array = to_string_array(padding_arg, batch_size)?; + let iv_array = to_binary_array(iv_arg, batch_size)?; + let aad_array = to_binary_array(aad_arg, batch_size)?; + + let mut builder = BinaryBuilder::new(); + + for i in 0..batch_size { + if input_array.is_null(i) || key_array.is_null(i) { + builder.append_null(); + continue; + } + + let input = input_array.value(i); + let key = key_array.value(i); + let mode = mode_array.value(i); + let padding = padding_array.value(i); + let iv = if iv_array.is_null(i) || iv_array.value(i).is_empty() { + None + } else { + Some(iv_array.value(i)) + }; + let aad = if aad_array.is_null(i) || aad_array.value(i).is_empty() { + None + } else { + Some(aad_array.value(i)) + }; + + match get_cipher_mode(mode, padding) { + Ok(cipher) => match cipher.encrypt(input, key, iv, aad) { + Ok(encrypted) => builder.append_value(&encrypted), + Err(_) => builder.append_null(), + }, + Err(_) => builder.append_null(), + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) +} + +fn get_batch_size(args: &[ColumnarValue]) -> Result { + for arg in args { + if let ColumnarValue::Array(array) = arg { + return Ok(array.len()); + } + } + Ok(1) +} + +fn to_binary_array(col: &ColumnarValue, size: usize) -> Result { + match col { + ColumnarValue::Array(array) => Ok(array + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Execution("Expected binary array".to_string()))? + .clone()), + ColumnarValue::Scalar(ScalarValue::Binary(opt_val)) => { + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if let Some(val) = opt_val { + builder.append_value(val); + } else { + builder.append_null(); + } + } + Ok(builder.finish()) + } + _ => Err(DataFusionError::Execution( + "Invalid argument type".to_string(), + )), + } +} + +fn to_string_array(col: &ColumnarValue, size: usize) -> Result { + match col { + ColumnarValue::Array(array) => Ok(array + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Execution("Expected string array".to_string()))? 
+ .clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(opt_val)) => { + let val = opt_val.as_deref().unwrap_or("GCM"); + Ok(StringArray::from(vec![val; size])) + } + _ => Err(DataFusionError::Execution( + "Invalid argument type".to_string(), + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::BinaryArray; + use datafusion::common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_aes_encrypt_basic_gcm() { + let input = ScalarValue::Binary(Some(b"Spark".to_vec())); + let key = ScalarValue::Binary(Some(b"0000111122223333".to_vec())); + + let args = vec![ColumnarValue::Scalar(input), ColumnarValue::Scalar(key)]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + + if let ColumnarValue::Scalar(ScalarValue::Binary(Some(encrypted))) = result.unwrap() { + assert!(encrypted.len() > 12); + } else { + panic!("Expected binary scalar result"); + } + } + + #[test] + fn test_aes_encrypt_with_mode() { + let input = ScalarValue::Binary(Some(b"Spark SQL".to_vec())); + let key = ScalarValue::Binary(Some(b"1234567890abcdef".to_vec())); + let mode = ScalarValue::Utf8(Some("ECB".to_string())); + + let args = vec![ + ColumnarValue::Scalar(input), + ColumnarValue::Scalar(key), + ColumnarValue::Scalar(mode), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + } + + #[test] + fn test_aes_encrypt_with_mode_padding() { + let input = ScalarValue::Binary(Some(b"test".to_vec())); + let key = ScalarValue::Binary(Some(b"1234567890abcdef".to_vec())); + let mode = ScalarValue::Utf8(Some("CBC".to_string())); + let padding = ScalarValue::Utf8(Some("PKCS".to_string())); + + let args = vec![ + ColumnarValue::Scalar(input), + ColumnarValue::Scalar(key), + ColumnarValue::Scalar(mode), + ColumnarValue::Scalar(padding), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + } + + #[test] + fn test_aes_encrypt_with_iv() { + let input = ScalarValue::Binary(Some(b"Apache Spark".to_vec())); + let key = ScalarValue::Binary(Some(b"1234567890abcdef".to_vec())); + let mode = ScalarValue::Utf8(Some("CBC".to_string())); + let padding = ScalarValue::Utf8(Some("PKCS".to_string())); + let iv = ScalarValue::Binary(Some(vec![0u8; 16])); + + let args = vec![ + ColumnarValue::Scalar(input), + ColumnarValue::Scalar(key), + ColumnarValue::Scalar(mode), + ColumnarValue::Scalar(padding), + ColumnarValue::Scalar(iv.clone()), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + + if let ColumnarValue::Scalar(ScalarValue::Binary(Some(encrypted))) = result.unwrap() { + if let ScalarValue::Binary(Some(iv_bytes)) = iv { + assert_eq!(&encrypted[..16], &iv_bytes[..]); + } + } + } + + #[test] + fn test_aes_encrypt_gcm_with_aad() { + let input = ScalarValue::Binary(Some(b"Spark".to_vec())); + let key = ScalarValue::Binary(Some(b"abcdefghijklmnop12345678ABCDEFGH".to_vec())); + let mode = ScalarValue::Utf8(Some("GCM".to_string())); + let padding = ScalarValue::Utf8(Some("DEFAULT".to_string())); + let iv = ScalarValue::Binary(Some(vec![0u8; 12])); + let aad = ScalarValue::Binary(Some(b"This is an AAD mixed into the input".to_vec())); + + let args = vec![ + ColumnarValue::Scalar(input), + ColumnarValue::Scalar(key), + ColumnarValue::Scalar(mode), + ColumnarValue::Scalar(padding), + ColumnarValue::Scalar(iv), + ColumnarValue::Scalar(aad), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + } + + #[test] + fn test_aes_encrypt_invalid_key_length() { + let input = ScalarValue::Binary(Some(b"test".to_vec())); + let key = 
ScalarValue::Binary(Some(b"short".to_vec())); + + let args = vec![ColumnarValue::Scalar(input), ColumnarValue::Scalar(key)]; + + let result = spark_aes_encrypt(&args); + assert!( + result.is_err() + || matches!( + result.unwrap(), + ColumnarValue::Scalar(ScalarValue::Binary(None)) + ) + ); + } + + #[test] + fn test_aes_encrypt_null_input() { + let input = ScalarValue::Binary(None); + let key = ScalarValue::Binary(Some(b"0000111122223333".to_vec())); + + let args = vec![ColumnarValue::Scalar(input), ColumnarValue::Scalar(key)]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + assert!(matches!( + result.unwrap(), + ColumnarValue::Scalar(ScalarValue::Binary(None)) + )); + } + + #[test] + fn test_aes_encrypt_null_key() { + let input = ScalarValue::Binary(Some(b"test".to_vec())); + let key = ScalarValue::Binary(None); + + let args = vec![ColumnarValue::Scalar(input), ColumnarValue::Scalar(key)]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + assert!(matches!( + result.unwrap(), + ColumnarValue::Scalar(ScalarValue::Binary(None)) + )); + } + + #[test] + fn test_aes_encrypt_vectorized() { + let input_array = BinaryArray::from(vec![ + Some(b"message1".as_ref()), + Some(b"message2".as_ref()), + Some(b"message3".as_ref()), + ]); + let key_array = BinaryArray::from(vec![ + Some(b"key1key1key1key1".as_ref()), + Some(b"key2key2key2key2".as_ref()), + Some(b"key3key3key3key3".as_ref()), + ]); + + let args = vec![ + ColumnarValue::Array(Arc::new(input_array)), + ColumnarValue::Array(Arc::new(key_array)), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + + if let ColumnarValue::Array(array) = result.unwrap() { + assert_eq!(array.len(), 3); + let binary_array = array.as_any().downcast_ref::().unwrap(); + for i in 0..3 { + assert!(!binary_array.is_null(i)); + assert!(!binary_array.value(i).is_empty()); + } + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_aes_encrypt_vectorized_with_nulls() { + let input_array = BinaryArray::from(vec![ + Some(b"message1".as_ref()), + None, + Some(b"message3".as_ref()), + ]); + let key_array = BinaryArray::from(vec![ + Some(b"key1key1key1key1".as_ref()), + Some(b"key2key2key2key2".as_ref()), + Some(b"key3key3key3key3".as_ref()), + ]); + + let args = vec![ + ColumnarValue::Array(Arc::new(input_array)), + ColumnarValue::Array(Arc::new(key_array)), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + + if let ColumnarValue::Array(array) = result.unwrap() { + let binary_array = array.as_any().downcast_ref::().unwrap(); + assert!(!binary_array.is_null(0)); + assert!(binary_array.is_null(1)); + assert!(!binary_array.is_null(2)); + } + } + + #[test] + fn test_aes_encrypt_mixed_scalar_array() { + let input_array = + BinaryArray::from(vec![Some(b"message1".as_ref()), Some(b"message2".as_ref())]); + let key = ScalarValue::Binary(Some(b"0000111122223333".to_vec())); + + let args = vec![ + ColumnarValue::Array(Arc::new(input_array)), + ColumnarValue::Scalar(key), + ]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_ok()); + + if let ColumnarValue::Array(array) = result.unwrap() { + assert_eq!(array.len(), 2); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_aes_encrypt_too_few_args() { + let input = ScalarValue::Binary(Some(b"test".to_vec())); + let args = vec![ColumnarValue::Scalar(input)]; + + let result = spark_aes_encrypt(&args); + assert!(result.is_err()); + } + + #[test] + fn test_aes_encrypt_too_many_args() { + let 
args: Vec<ColumnarValue> = (0..7) + .map(|_| ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![0u8; 16])))) + .collect(); + + let result = spark_aes_encrypt(&args); + assert!(result.is_err()); + } +} diff --git a/native/spark-expr/src/encryption_funcs/cipher_modes.rs b/native/spark-expr/src/encryption_funcs/cipher_modes.rs new file mode 100644 index 0000000000..1e6a850b00 --- /dev/null +++ b/native/spark-expr/src/encryption_funcs/cipher_modes.rs @@ -0,0 +1,490 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::crypto_utils::{ + generate_random_iv, validate_iv_length, validate_key_length, CryptoError, +}; + +pub trait CipherMode: Send + Sync + std::fmt::Debug { + #[allow(dead_code)] + fn name(&self) -> &str; + #[allow(dead_code)] + fn iv_length(&self) -> usize; + #[allow(dead_code)] + fn supports_aad(&self) -> bool; + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError>; +} + +#[derive(Debug)] +pub struct EcbMode; +#[derive(Debug)] +pub struct CbcMode; +#[derive(Debug)] +pub struct GcmMode; + +impl CipherMode for EcbMode { + fn name(&self) -> &str { + "ECB" + } + + fn iv_length(&self) -> usize { + 0 + } + + fn supports_aad(&self) -> bool { + false + } + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError> { + use aes::{Aes128, Aes192, Aes256}; + use cipher::{block_padding::Pkcs7, BlockEncryptMut, KeyInit}; + use ecb::Encryptor; + + validate_key_length(key)?; + + if iv.is_some() { + return Err(CryptoError::UnsupportedIv("ECB".to_string())); + } + if aad.is_some() { + return Err(CryptoError::UnsupportedAad("ECB".to_string())); + } + + let encrypted = match key.len() { + 16 => { + let cipher = Encryptor::<Aes128>::new(key.into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 24 => { + let cipher = Encryptor::<Aes192>::new(key.into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 32 => { + let cipher = Encryptor::<Aes256>::new(key.into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + _ => unreachable!("Key length validated above"), + }; + + Ok(encrypted) + } +} + +impl CipherMode for CbcMode { + fn name(&self) -> &str { + "CBC" + } + + fn iv_length(&self) -> usize { + 16 + } + + fn supports_aad(&self) -> bool { + false + } + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError> { + use aes::{Aes128, Aes192, Aes256}; + use cbc::cipher::{block_padding::Pkcs7, BlockEncryptMut, KeyIvInit}; + use cbc::Encryptor; + + validate_key_length(key)?; + + if aad.is_some() { + return Err(CryptoError::UnsupportedAad("CBC".to_string())); + } + + let iv_bytes = match iv { + Some(iv) => { + validate_iv_length(iv, 16)?; + iv.to_vec() + } + None => generate_random_iv(16),
}; + + let ciphertext = match key.len() { + 16 => { + let cipher = Encryptor::<Aes128>::new(key.into(), iv_bytes.as_slice().into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 24 => { + let cipher = Encryptor::<Aes192>::new(key.into(), iv_bytes.as_slice().into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 32 => { + let cipher = Encryptor::<Aes256>::new(key.into(), iv_bytes.as_slice().into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + _ => unreachable!("Key length validated above"), + }; + + let mut result = iv_bytes; + result.extend_from_slice(&ciphertext); + Ok(result) + } +} + +impl CipherMode for GcmMode { + fn name(&self) -> &str { + "GCM" + } + + fn iv_length(&self) -> usize { + 12 + } + + fn supports_aad(&self) -> bool { + true + } + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError> { + use aes::Aes192; + use aes_gcm::aead::{Aead, Payload}; + use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce}; + use cipher::consts::U12; + + validate_key_length(key)?; + + let iv_bytes = match iv { + Some(iv) => { + validate_iv_length(iv, 12)?; + iv.to_vec() + } + None => generate_random_iv(12), + }; + + let nonce = Nonce::from_slice(&iv_bytes); + + let ciphertext = match key.len() { + 16 => { + let cipher = Aes128Gcm::new(key.into()); + let payload = match aad { + Some(aad_data) => Payload { + msg: input, + aad: aad_data, + }, + None => Payload { + msg: input, + aad: &[], + }, + }; + cipher + .encrypt(nonce, payload) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))? + } + 24 => { + // Spark accepts 192-bit keys; aes-gcm has no Aes192Gcm alias, so + // compose one from the generic AesGcm type (12-byte nonce). + let cipher = AesGcm::<Aes192, U12>::new(key.into()); + let payload = match aad { + Some(aad_data) => Payload { + msg: input, + aad: aad_data, + }, + None => Payload { + msg: input, + aad: &[], + }, + }; + cipher + .encrypt(nonce, payload) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))? + } + 32 => { + let cipher = Aes256Gcm::new(key.into()); + let payload = match aad { + Some(aad_data) => Payload { + msg: input, + aad: aad_data, + }, + None => Payload { + msg: input, + aad: &[], + }, + }; + cipher + .encrypt(nonce, payload) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?
+ } + _ => unreachable!("Key length validated above"), + }; + + let mut result = iv_bytes; + result.extend_from_slice(&ciphertext); + Ok(result) + } +} + +pub fn get_cipher_mode(mode: &str, padding: &str) -> Result<Box<dyn CipherMode>, CryptoError> { + let mode_upper = mode.to_uppercase(); + let padding_upper = padding.to_uppercase(); + + match (mode_upper.as_str(), padding_upper.as_str()) { + ("ECB", "PKCS" | "DEFAULT") => Ok(Box::new(EcbMode)), + ("CBC", "PKCS" | "DEFAULT") => Ok(Box::new(CbcMode)), + ("GCM", "NONE" | "DEFAULT") => Ok(Box::new(GcmMode)), + _ => Err(CryptoError::UnsupportedMode( + mode.to_string(), + padding.to_string(), + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ecb_mode_properties() { + let mode = EcbMode; + assert_eq!(mode.name(), "ECB"); + assert_eq!(mode.iv_length(), 0); + assert!(!mode.supports_aad()); + } + + #[test] + fn test_ecb_encrypt_basic() { + let mode = EcbMode; + let input = b"Spark SQL"; + let key = b"1234567890abcdef"; + + let result = mode.encrypt(input, key, None, None); + assert!(result.is_ok()); + let encrypted = result.unwrap(); + assert!(!encrypted.is_empty()); + assert_ne!(&encrypted[..], input); + } + + #[test] + fn test_ecb_rejects_iv() { + let mode = EcbMode; + let input = b"test"; + let key = b"1234567890abcdef"; + let iv = vec![0u8; 16]; + + let result = mode.encrypt(input, key, Some(&iv), None); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), CryptoError::UnsupportedIv(_))); + } + + #[test] + fn test_ecb_rejects_aad() { + let mode = EcbMode; + let input = b"test"; + let key = b"1234567890abcdef"; + let aad = b"metadata"; + + let result = mode.encrypt(input, key, None, Some(aad)); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + CryptoError::UnsupportedAad(_) + )); + } + + #[test] + fn test_ecb_invalid_key() { + let mode = EcbMode; + let input = b"test"; + let key = b"short"; + + let result = mode.encrypt(input, key, None, None); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + CryptoError::InvalidKeyLength(_) + )); + } + + #[test] + fn test_cbc_mode_properties() { + let mode = CbcMode; + assert_eq!(mode.name(), "CBC"); + assert_eq!(mode.iv_length(), 16); + assert!(!mode.supports_aad()); + } + + #[test] + fn test_cbc_encrypt_generates_iv() { + let mode = CbcMode; + let input = b"Apache Spark"; + let key = b"1234567890abcdef"; + + let result = mode.encrypt(input, key, None, None); + assert!(result.is_ok()); + let encrypted = result.unwrap(); + assert!(encrypted.len() > 16); + } + + #[test] + fn test_cbc_encrypt_with_provided_iv() { + let mode = CbcMode; + let input = b"test"; + let key = b"1234567890abcdef"; + let iv = vec![0u8; 16]; + + let result = mode.encrypt(input, key, Some(&iv), None); + assert!(result.is_ok()); + let encrypted = result.unwrap(); + assert_eq!(&encrypted[..16], &iv[..]); + } + + #[test] + fn test_cbc_rejects_aad() { + let mode = CbcMode; + let input = b"test"; + let key = b"1234567890abcdef"; + let aad = b"metadata"; + + let result = mode.encrypt(input, key, None, Some(aad)); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + CryptoError::UnsupportedAad(_) + )); + } + + #[test] + fn test_cbc_invalid_iv_length() { + let mode = CbcMode; + let input = b"test"; + let key = b"1234567890abcdef"; + let iv = vec![0u8; 8]; + + let result = mode.encrypt(input, key, Some(&iv), None); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + CryptoError::InvalidIvLength { ..
} + )); + } + + #[test] + fn test_gcm_mode_properties() { + let mode = GcmMode; + assert_eq!(mode.name(), "GCM"); + assert_eq!(mode.iv_length(), 12); + assert!(mode.supports_aad()); + } + + #[test] + fn test_gcm_encrypt_generates_iv() { + let mode = GcmMode; + let input = b"Spark"; + let key = b"0000111122223333"; + + let result = mode.encrypt(input, key, None, None); + assert!(result.is_ok()); + let encrypted = result.unwrap(); + assert!(encrypted.len() > 12); + } + + #[test] + fn test_gcm_encrypt_with_aad() { + let mode = GcmMode; + let input = b"Spark"; + let key = b"abcdefghijklmnop12345678ABCDEFGH"; + let iv = vec![0u8; 12]; + let aad = b"This is an AAD mixed into the input"; + + let result = mode.encrypt(input, key, Some(&iv), Some(aad)); + assert!(result.is_ok()); + } + + #[test] + fn test_gcm_invalid_iv_length() { + let mode = GcmMode; + let input = b"test"; + let key = b"1234567890abcdef"; + let iv = vec![0u8; 16]; + + let result = mode.encrypt(input, key, Some(&iv), None); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + CryptoError::InvalidIvLength { .. } + )); + } + + #[test] + fn test_get_cipher_mode_ecb() { + let mode = get_cipher_mode("ECB", "PKCS"); + assert!(mode.is_ok()); + assert_eq!(mode.unwrap().name(), "ECB"); + } + + #[test] + fn test_get_cipher_mode_cbc() { + let mode = get_cipher_mode("CBC", "PKCS"); + assert!(mode.is_ok()); + assert_eq!(mode.unwrap().name(), "CBC"); + } + + #[test] + fn test_get_cipher_mode_gcm() { + let mode = get_cipher_mode("GCM", "NONE"); + assert!(mode.is_ok()); + assert_eq!(mode.unwrap().name(), "GCM"); + } + + #[test] + fn test_get_cipher_mode_default_padding() { + let mode = get_cipher_mode("GCM", "DEFAULT"); + assert!(mode.is_ok()); + assert_eq!(mode.unwrap().name(), "GCM"); + } + + #[test] + fn test_get_cipher_mode_invalid() { + let mode = get_cipher_mode("CTR", "NONE"); + assert!(mode.is_err()); + assert!(matches!( + mode.unwrap_err(), + CryptoError::UnsupportedMode(_, _) + )); + } + + #[test] + fn test_get_cipher_mode_invalid_combination() { + let mode = get_cipher_mode("GCM", "PKCS"); + assert!(mode.is_err()); + } +} diff --git a/native/spark-expr/src/encryption_funcs/crypto_utils.rs b/native/spark-expr/src/encryption_funcs/crypto_utils.rs new file mode 100644 index 0000000000..3345f77ef7 --- /dev/null +++ b/native/spark-expr/src/encryption_funcs/crypto_utils.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
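// [Editor's sketch, not part of this patch] A determinism property that the
// cipher_modes tests above could pin down: ECB takes no IV, so equal
// (key, plaintext) pairs must produce equal ciphertexts, whereas the CBC and
// GCM paths prepend a freshly generated random IV when none is supplied and
// should differ from run to run. Sketched in the style of those tests:
#[test]
fn test_ecb_is_deterministic_and_cbc_is_not() {
    let key = b"1234567890abcdef";
    let a = EcbMode.encrypt(b"Spark", key, None, None).unwrap();
    let b = EcbMode.encrypt(b"Spark", key, None, None).unwrap();
    assert_eq!(a, b); // no IV and no randomness anywhere in the ECB path
    let c = CbcMode.encrypt(b"Spark", key, None, None).unwrap();
    let d = CbcMode.encrypt(b"Spark", key, None, None).unwrap();
    assert_ne!(c, d); // a random 16-byte IV is prepended to each output
}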
+ +use datafusion::common::DataFusionError; + +#[derive(Debug, PartialEq)] +pub enum CryptoError { + InvalidKeyLength(usize), + InvalidIvLength { expected: usize, actual: usize }, + UnsupportedMode(String, String), + UnsupportedIv(String), + UnsupportedAad(String), + EncryptionFailed(String), +} + +impl From<CryptoError> for DataFusionError { + fn from(err: CryptoError) -> Self { + DataFusionError::Execution(format!("{:?}", err)) + } +} + +pub fn validate_key_length(key: &[u8]) -> Result<(), CryptoError> { + match key.len() { + 16 | 24 | 32 => Ok(()), + len => Err(CryptoError::InvalidKeyLength(len)), + } +} + +pub fn generate_random_iv(length: usize) -> Vec<u8> { + use rand::Rng; + let mut iv = vec![0u8; length]; + rand::rng().fill(&mut iv[..]); + iv +} + +pub fn validate_iv_length(iv: &[u8], expected: usize) -> Result<(), CryptoError> { + if iv.len() == expected { + Ok(()) + } else { + Err(CryptoError::InvalidIvLength { + expected, + actual: iv.len(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_key_length_valid_16() { + let key = vec![0u8; 16]; + assert!(validate_key_length(&key).is_ok()); + } + + #[test] + fn test_validate_key_length_valid_24() { + let key = vec![0u8; 24]; + assert!(validate_key_length(&key).is_ok()); + } + + #[test] + fn test_validate_key_length_valid_32() { + let key = vec![0u8; 32]; + assert!(validate_key_length(&key).is_ok()); + } + + #[test] + fn test_validate_key_length_invalid_short() { + let key = vec![0u8; 8]; + let result = validate_key_length(&key); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CryptoError::InvalidKeyLength(8)); + } + + #[test] + fn test_validate_key_length_invalid_long() { + let key = vec![0u8; 64]; + let result = validate_key_length(&key); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CryptoError::InvalidKeyLength(64)); + } + + #[test] + fn test_validate_key_length_invalid_zero() { + let key = vec![]; + let result = validate_key_length(&key); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CryptoError::InvalidKeyLength(0)); + } + + #[test] + fn test_generate_random_iv_length_12() { + let iv = generate_random_iv(12); + assert_eq!(iv.len(), 12); + } + + #[test] + fn test_generate_random_iv_length_16() { + let iv = generate_random_iv(16); + assert_eq!(iv.len(), 16); + } + + #[test] + fn test_generate_random_iv_is_random() { + let iv1 = generate_random_iv(16); + let iv2 = generate_random_iv(16); + assert_ne!(iv1, iv2); + } + + #[test] + fn test_validate_iv_length_valid() { + let iv = vec![0u8; 16]; + assert!(validate_iv_length(&iv, 16).is_ok()); + } + + #[test] + fn test_validate_iv_length_too_short() { + let iv = vec![0u8; 8]; + let result = validate_iv_length(&iv, 16); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + CryptoError::InvalidIvLength { + expected: 16, + actual: 8 + } + ); + } + + #[test] + fn test_validate_iv_length_too_long() { + let iv = vec![0u8; 20]; + let result = validate_iv_length(&iv, 16); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err(), + CryptoError::InvalidIvLength { + expected: 16, + actual: 20 + } + ); + } +} diff --git a/native/spark-expr/src/encryption_funcs/mod.rs b/native/spark-expr/src/encryption_funcs/mod.rs new file mode 100644 index 0000000000..21f1bb2e11 --- /dev/null +++ b/native/spark-expr/src/encryption_funcs/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aes_encrypt; +mod cipher_modes; +mod crypto_utils; + +pub use aes_encrypt::spark_aes_encrypt; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index f26fd911d8..0647c05477 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -42,6 +42,7 @@ mod agg_funcs; mod array_funcs; mod bitwise_funcs; mod comet_scalar_funcs; +pub mod encryption_funcs; pub mod hash_funcs; mod string_funcs; diff --git a/spark/.gitignore b/spark/.gitignore new file mode 100644 index 0000000000..ba9e3b3c2d --- /dev/null +++ b/spark/.gitignore @@ -0,0 +1 @@ +spark-warehouse diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala index aa3bf775fb..66126db09e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalarFunction.scala @@ -22,13 +22,15 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} /** Serde for scalar function. 
*/ case class CometScalarFunction[T <: Expression](name: String) extends CometExpressionSerde[T] { override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto(name, childExpr: _*) + // Pass return type to avoid native lookup in DataFusion registry + val optExpr = + scalarFunctionExprToProtoWithReturnType(name, expr.dataType, false, childExpr: _*) optExprWithInfo(optExpr, expr, expr.children: _*) } } diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..411b76fb94 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils @@ -34,7 +34,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - "read_side_padding")) + "read_side_padding"), + ("aesEncrypt", classOf[ExpressionImplUtils]) -> CometScalarFunction("aes_encrypt")) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/scala/org/apache/comet/CometEncryptionSuite.scala b/spark/src/test/scala/org/apache/comet/CometEncryptionSuite.scala new file mode 100644 index 0000000000..7ab3333951 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometEncryptionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase + +class CometEncryptionSuite extends CometTestBase { + + test("aes_encrypt basic") { + withTable("t1") { + sql(""" + CREATE TABLE t1(data STRING, key STRING) USING parquet + """) + sql(""" + INSERT INTO t1 VALUES + ('Spark', '0000111122223333'), + ('SQL', 'abcdefghijklmnop') + """) + + val query = """ + SELECT + data, + hex(aes_encrypt(cast(data as binary), cast(key as binary))) as encrypted + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } + + test("aes_encrypt with mode") { + withTable("t1") { + sql(""" + CREATE TABLE t1(data STRING, key STRING) USING parquet + """) + sql(""" + INSERT INTO t1 VALUES ('test', '1234567890123456') + """) + + val query = """ + SELECT hex(aes_encrypt(cast(data as binary), cast(key as binary), 'GCM')) + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometStaticInvokeSuite.scala b/spark/src/test/scala/org/apache/comet/CometStaticInvokeSuite.scala new file mode 100644 index 0000000000..e3825a6a5f --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometStaticInvokeSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase + +class CometStaticInvokeSuite extends CometTestBase { + + test("aes_encrypt basic - verify native execution") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("""INSERT INTO t1 VALUES + ('Spark', '0000111122223333'), + ('SQL', 'abcdefghijklmnop')""") + + val query = """ + SELECT + data, + hex(aes_encrypt(cast(data as binary), cast(key as binary))) as encrypted + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + + val df = sql(query) + val plan = df.queryExecution.executedPlan.toString + assert( + plan.contains("CometProject") || plan.contains("CometNative"), + s"Expected native execution but got Spark fallback:\n$plan") + } + } + + test("aes_encrypt with mode") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('test', '1234567890123456')") + + val query = """ + SELECT hex(aes_encrypt(cast(data as binary), cast(key as binary), 'GCM')) + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } + + test("aes_encrypt with all parameters") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('test', '1234567890123456')") + + val query = """ + SELECT hex(aes_encrypt( + cast(data as binary), + cast(key as binary), + 'GCM', + 'DEFAULT', + cast('initialization' as binary), + cast('additional' as binary) + )) + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } + + test("aes_encrypt wrapped in multiple functions") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('test', '1234567890123456')") + + val query = """ + SELECT + upper(hex(aes_encrypt(cast(data as binary), cast(key as binary)))) as encrypted, + length(hex(aes_encrypt(cast(data as binary), cast(key as binary)))) as len + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometEncryptionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometEncryptionBenchmark.scala new file mode 100644 index 0000000000..e9792b0c07 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometEncryptionBenchmark.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +/** + * Configuration for an encryption expression benchmark. 
+ * @param name + * Name for the benchmark + * @param query + * SQL query to benchmark + * @param extraCometConfigs + * Additional Comet configurations for the scan+exec case + */ +case class EncryptionExprConfig( + name: String, + query: String, + extraCometConfigs: Map[String, String] = Map.empty) + +/** + * Benchmark to measure performance of Comet encryption expressions. To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometEncryptionBenchmark + * }}} + * Results will be written to "spark/benchmarks/CometEncryptionBenchmark-**results.txt". + */ +object CometEncryptionBenchmark extends CometBenchmarkBase { + + private val encryptionExpressions = List( + EncryptionExprConfig( + "aes_encrypt_gcm_basic", + "select hex(aes_encrypt(data, key)) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_gcm_with_mode", + "select hex(aes_encrypt(data, key, 'GCM')) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_cbc", + "select hex(aes_encrypt(data, key, 'CBC', 'PKCS')) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_ecb", + "select hex(aes_encrypt(data, key, 'ECB', 'PKCS')) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_gcm_with_iv", + "select hex(aes_encrypt(data, key, 'GCM', 'DEFAULT', iv)) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_gcm_with_aad", + "select hex(aes_encrypt(data, key, 'GCM', 'DEFAULT', iv, aad)) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_with_base64", + "select base64(aes_encrypt(data, key)) from parquetV1Table"), + EncryptionExprConfig( + "aes_encrypt_long_data", + "select hex(aes_encrypt(long_data, key)) from parquetV1Table")) + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + runBenchmarkWithTable("Encryption expressions", 100000) { v => + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s""" + SELECT + CAST(REPEAT(CAST(value AS STRING), 2) AS BINARY) AS data, + CAST('0000111122223333' AS BINARY) AS key, + CAST(unhex('000000000000000000000000') AS BINARY) AS iv, + CAST('This is AAD data' AS BINARY) AS aad, + CAST(REPEAT(CAST(value AS STRING), 100) AS BINARY) AS long_data + FROM $tbl + """)) + + encryptionExpressions.foreach { config => + runBenchmark(config.name) { + runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs) + } + } + } + } + } + } +}
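One closing note on wire-format compatibility: all three cipher modes emit the IV (when there is one) ahead of the ciphertext, which is the layout Spark's aes_decrypt expects. As a standalone sanity check (not part of this patch), the output of GcmMode::encrypt can be round-tripped with the same RustCrypto crates by splitting off the 12-byte IV and letting aes-gcm verify the trailing 16-byte tag; a minimal sketch for a 16-byte key:

use aes_gcm::aead::{Aead, Payload};
use aes_gcm::{Aes128Gcm, KeyInit, Nonce};

// Decrypts `out` = IV (12 bytes) || ciphertext || tag (16 bytes), the
// layout produced by GcmMode::encrypt above.
fn gcm_roundtrip(out: &[u8], key: &[u8; 16], aad: &[u8]) -> Vec<u8> {
    let (iv, ct_and_tag) = out.split_at(12);
    Aes128Gcm::new(key.into())
        .decrypt(Nonce::from_slice(iv), Payload { msg: ct_and_tag, aad })
        .expect("authentication tag verifies for untampered output")
}

For an empty AAD, passing aad = &[] matches the None branch of the encrypt path, so the plaintext returned should equal the original input byte for byte.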