Skip to content
10 changes: 7 additions & 3 deletions native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,15 @@ impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
let mut hashes_idx = 0;

for row_idx in 0..probed_batch.num_rows() {
if probed_valids
let key_is_valid = probed_valids
.as_ref()
.map(|nb| nb.is_valid(row_idx))
.unwrap_or(true)
{
.unwrap_or(true);
if P.mode == Anti && P.probe_is_join_side && !key_is_valid {
probed_joined.set(row_idx, true);
continue;
}
if key_is_valid {
let map_value = map_values[hashes_idx];
hashes_idx += 1;

Expand Down
48 changes: 48 additions & 0 deletions native-engine/datafusion-ext-plans/src/joins/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,54 @@ mod tests {
Ok(())
}

/// Checks `LeftAnti` join results when the left join key column `b1`
/// contains a NULL.
///
/// The right side has keys `4, 5, 7`, so the left rows with keys `4` and `5`
/// are matched and never appear in the output, while the rows with keys `6`
/// and `8` are always emitted. The left row `(4, NULL, 10)` has no valid key,
/// and this test pins two different outcomes for it:
///
/// * `BHJLeftProbed` / `SHJLeftProbed`: the NULL-key row is absent.
/// * `SMJ` / `BHJRightProbed` / `SHJRightProbed`: the NULL-key row is emitted.
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn join_anti_with_null_keys() -> Result<()> {
    // Left input: `a1` and `b1` each contain one NULL; `b1` is the join key.
    let left = build_table_i32_nullable(
        ("a1", &vec![Some(1), Some(2), None, Some(4), Some(5)]),
        ("b1", &vec![Some(4), Some(5), Some(6), None, Some(8)]),
        ("c1", &vec![Some(7), Some(8), Some(9), Some(10), Some(11)]),
    );
    // Right input: no NULLs; keys 4 and 5 match left rows, 7 matches nothing.
    let right = build_table_i32_nullable(
        ("a2", &vec![Some(10), Some(20), Some(30)]),
        ("b1", &vec![Some(4), Some(5), Some(7)]),
        ("c2", &vec![Some(70), Some(80), Some(90)]),
    );
    // Equi-join on left.b1 = right.b1.
    let on: JoinOn = vec![(
        Arc::new(Column::new_with_schema("b1", &left.schema())?),
        Arc::new(Column::new_with_schema("b1", &right.schema())?),
    )];

    // Left side is the probed side: the NULL-key row is not in the result.
    for test_type in [BHJLeftProbed, SHJLeftProbed] {
        let (_, batches) =
            join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?;
        // Rows are compared as rendered table text, so every cell is padded
        // to the column width given by the border line; NULL is an empty cell.
        let expected = vec![
            "+----+----+----+",
            "| a1 | b1 | c1 |",
            "+----+----+----+",
            "|    | 6  | 9  |",
            "| 5  | 8  | 11 |",
            "+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
    }

    // Sort-merge join and right-probed hash joins: the NULL-key row is kept.
    for test_type in [SMJ, BHJRightProbed, SHJRightProbed] {
        let (_, batches) =
            join_collect(test_type, left.clone(), right.clone(), on.clone(), LeftAnti).await?;
        let expected = vec![
            "+----+----+----+",
            "| a1 | b1 | c1 |",
            "+----+----+----+",
            "|    | 6  | 9  |",
            "| 4  |    | 10 |",
            "| 5  | 8  | 11 |",
            "+----+----+----+",
        ];
        assert_batches_sorted_eq!(expected, &batches);
    }
    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn join_with_duplicated_column_names() -> Result<()> {
for test_type in ALL_TEST_TYPE {
Expand Down
Loading