diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs index 26841cee9..41ebcf6fd 100644 --- a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs @@ -189,11 +189,15 @@ impl Joiner for SemiJoiner

{ let mut hashes_idx = 0; for row_idx in 0..probed_batch.num_rows() { - if probed_valids + let key_is_valid = probed_valids .as_ref() .map(|nb| nb.is_valid(row_idx)) - .unwrap_or(true) - { + .unwrap_or(true); + if P.mode == Anti && P.probe_is_join_side && !key_is_valid { + probed_joined.set(row_idx, true); + continue; + } + if key_is_valid { let map_value = map_values[hashes_idx]; hashes_idx += 1; diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs index 9125ed53e..671ecd732 100644 --- a/native-engine/datafusion-ext-plans/src/joins/test.rs +++ b/native-engine/datafusion-ext-plans/src/joins/test.rs @@ -600,6 +600,54 @@ mod tests { Ok(()) } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn join_anti_with_null_keys() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), None, Some(4), Some(5)]), + ("b1", &vec![Some(4), Some(5), Some(6), None, Some(8)]), + ("c1", &vec![Some(7), Some(8), Some(9), Some(10), Some(11)]), + ); + let right = build_table_i32_nullable( + ("a2", &vec![Some(10), Some(20), Some(30)]), + ("b1", &vec![Some(4), Some(5), Some(7)]), + ("c2", &vec![Some(70), Some(80), Some(90)]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + for test_type in [BHJLeftProbed, SHJLeftProbed] { + let (_, batches) = + join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?; + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| | 6 | 9 |", + "| 5 | 8 | 11 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + + for test_type in [SMJ, BHJRightProbed, SHJRightProbed] { + let (_, batches) = + join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?; + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| | 6 | 9 |", + "| 4 | | 10 |", + "| 5 | 8 | 11 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn join_with_duplicated_column_names() -> Result<()> { for test_type in ALL_TEST_TYPE {