45 changes: 44 additions & 1 deletion cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
  * SPDX-License-Identifier: Apache-2.0
  */

@@ -365,6 +365,9 @@ void multi_gpu_batch_build(const raft::resources& handle,
                  raft::resource::get_cuda_stream(handle));
   }
 
+  // Ensure all async copies complete before starting the parallel region
+  raft::resource::sync_stream(handle);
+
 #pragma omp parallel for num_threads(num_ranks)
   for (int rank = 0; rank < num_ranks; rank++) {
     auto dev_res = raft::resource::set_current_device_to_rank(handle, rank);
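Why this hunk's sync matters: the copies queued above run asynchronously on the main handle's stream, so without a synchronization point the OpenMP worker threads can switch to their own devices and read buffers that are still being written. Below is a minimal standalone sketch of this sync-before-parallel pattern, not the cuVS/RAFT code; the names (CUDA_CHECK, staged) and the rank-equals-device-id mapping are illustrative assumptions.

// sync_before_parallel.cu -- illustrative sketch of the pattern, not cuVS code.
#include <cuda_runtime.h>
#include <omp.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t err_ = (call);                                            \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
      std::exit(EXIT_FAILURE);                                            \
    }                                                                     \
  } while (0)

int main() {
  int num_ranks = 0;
  CUDA_CHECK(cudaGetDeviceCount(&num_ranks));

  std::vector<float> host_data(1 << 20, 1.0f);
  const size_t bytes = host_data.size() * sizeof(float);

  CUDA_CHECK(cudaSetDevice(0));
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));

  float* staged = nullptr;
  CUDA_CHECK(cudaMallocManaged(&staged, bytes));
  // The copy is only *enqueued* here; it may still be in flight when the
  // host thread reaches the parallel region below.
  CUDA_CHECK(cudaMemcpyAsync(staged, host_data.data(), bytes, cudaMemcpyDefault, stream));

  // Without this sync, worker threads on other GPUs could observe partially
  // copied data; this is the race the PR's added sync_stream() closes.
  CUDA_CHECK(cudaStreamSynchronize(stream));

#pragma omp parallel for num_threads(num_ranks)
  for (int rank = 0; rank < num_ranks; rank++) {
    CUDA_CHECK(cudaSetDevice(rank));  // each thread pins its own GPU
    // ... per-rank work reading `staged` goes here ...
  }

  CUDA_CHECK(cudaSetDevice(0));
  CUDA_CHECK(cudaFree(staged));
  CUDA_CHECK(cudaStreamDestroy(stream));
  return 0;
}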
@@ -526,9 +529,49 @@ void batch_build(
   auto global_neighbors = raft::make_managed_matrix<IdxT, IdxT>(handle, num_rows, k);
   auto global_distances = raft::make_managed_matrix<T, IdxT>(handle, num_rows, k);
 
+  if (raft::resource::is_multi_gpu(handle)) {
+    // Check whether any GPU is Turing (SM 7.5) or older. Those architectures have issues
+    // with multi-GPU managed-memory coherence for concurrent writes, so force the global
+    // matrices to be CPU-resident: every GPU then accesses them through host memory,
+    // avoiding page-migration issues. Ampere (SM 8.0+) and newer handle this correctly.
+    int num_ranks         = raft::resource::get_num_ranks(handle);
+    bool needs_workaround = false;
+    int original_device;
+    RAFT_CUDA_TRY(cudaGetDevice(&original_device));
+    for (int rank = 0; rank < num_ranks; rank++) {
+      raft::resource::set_current_device_to_rank(handle, rank);
+      int device_id;
+      RAFT_CUDA_TRY(cudaGetDevice(&device_id));
+      int major = 0;
+      RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
+      if (major < 8) {  // Turing is 7.x, Ampere is 8.x
+        needs_workaround = true;
+        break;
+      }
+    }
+    RAFT_CUDA_TRY(cudaSetDevice(original_device));
+
+    if (needs_workaround) {
+      RAFT_LOG_DEBUG("Applying managed memory workaround for pre-Ampere GPU architecture");
+      cudaMemLocation cpu_location = {cudaMemLocationTypeHost, 0};
+      RAFT_CUDA_TRY(cudaMemAdvise(global_neighbors.data_handle(),
+                                  num_rows * k * sizeof(IdxT),
+                                  cudaMemAdviseSetPreferredLocation,
+                                  cpu_location));
+      RAFT_CUDA_TRY(cudaMemAdvise(global_distances.data_handle(),
+                                  num_rows * k * sizeof(T),
+                                  cudaMemAdviseSetPreferredLocation,
+                                  cpu_location));
+    }
+  }
+
   reset_global_matrices(handle, params.metric, global_neighbors.view(), global_distances.view());
 
   if (raft::resource::is_multi_gpu(handle)) {
+    // For multi-GPU: sync the stream so the fill completes before other GPUs access the
+    // managed memory.
+    raft::resource::sync_stream(handle);
+
     multi_gpu_batch_build(handle,
                           params,
                           dataset,
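This hunk combines two runtime steps: a compute-capability probe across all ranks, and, when any pre-Ampere (SM < 8.0) device is found, a cudaMemAdvise hint that pins the managed matrices' preferred location to host memory. Below is a minimal standalone sketch of that workaround, not the cuVS code; it assumes the CUDA 13-style cudaMemAdvise overload that takes a cudaMemLocation, matching the call in the diff, and the helper names (CUDA_CHECK, any_pre_ampere_gpu, buf) are illustrative.

// pre_ampere_workaround.cu -- illustrative sketch, not cuVS code.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t err_ = (call);                                            \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
      std::exit(EXIT_FAILURE);                                            \
    }                                                                     \
  } while (0)

// Returns true if any visible GPU is older than Ampere (SM 8.0).
static bool any_pre_ampere_gpu() {
  int num_devices = 0;
  CUDA_CHECK(cudaGetDeviceCount(&num_devices));
  for (int dev = 0; dev < num_devices; dev++) {
    int major = 0;
    CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev));
    if (major < 8) return true;  // Turing is 7.x, Ampere is 8.x
  }
  return false;
}

int main() {
  const size_t bytes = size_t{1} << 20;
  float* buf         = nullptr;
  CUDA_CHECK(cudaMallocManaged(&buf, bytes));

  if (any_pre_ampere_gpu()) {
    // Prefer host residency so no single GPU becomes the pages' home;
    // all devices then access `buf` through host memory, avoiding the
    // page-migration coherence issues the diff describes. On toolkits
    // older than CUDA 13 the equivalent advice is:
    //   cudaMemAdvise(buf, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId);
    cudaMemLocation cpu_location = {cudaMemLocationTypeHost, 0};
    CUDA_CHECK(cudaMemAdvise(buf, bytes, cudaMemAdviseSetPreferredLocation, cpu_location));
  }

  // ... concurrent multi-GPU reads/writes of `buf` go here ...

  CUDA_CHECK(cudaFree(buf));
  return 0;
}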