From 9b11f4ae8e3ae8165cf30d9f21103c581dd4d9c1 Mon Sep 17 00:00:00 2001
From: vic
Date: Tue, 20 Jan 2026 16:29:29 +0100
Subject: [PATCH 1/3] Fix multi-GPU All Neighbors memory coherence issue on older platforms

---
 .../all_neighbors/all_neighbors_batched.cuh | 38 ++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
index ba286d89fe..c23a87ef81 100644
--- a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
+++ b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
  * SPDX-License-Identifier: Apache-2.0
  */
 
@@ -365,6 +365,9 @@ void multi_gpu_batch_build(const raft::resources& handle,
                     raft::resource::get_cuda_stream(handle));
   }
 
+  // Ensure all async copies complete before starting parallel region
+  raft::resource::sync_stream(handle);
+
 #pragma omp parallel for num_threads(num_ranks)
   for (int rank = 0; rank < num_ranks; rank++) {
     auto dev_res = raft::resource::set_current_device_to_rank(handle, rank);
@@ -528,6 +531,39 @@ void batch_build(
 
   reset_global_matrices(handle, params.metric, global_neighbors.view(), global_distances.view());
 
+  // For multi-GPU: sync the stream to ensure fill completes before other GPUs access
+  // the managed memory.
+  if (raft::resource::is_multi_gpu(handle)) {
+    raft::resource::sync_stream(handle);
+
+    // Check if any GPU is Turing (SM 7.5) or older. These architectures have issues with
+    // multi-GPU managed memory coherence for concurrent writes. Force CPU-resident memory
+    // to ensure all GPUs access through host memory, avoiding page migration issues.
+    // Ampere (SM 8.0+) and newer architectures handle this correctly.
+    int num_ranks = raft::resource::get_num_ranks(handle);
+    bool needs_workaround = false;
+    for (int rank = 0; rank < num_ranks; rank++) {
+      raft::resource::set_current_device_to_rank(handle, rank);
+      int device_id;
+      RAFT_CUDA_TRY(cudaGetDevice(&device_id));
+      int major = 0;
+      RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
+      if (major < 8) { needs_workaround = true; }  // Turing is 7.x, Ampere is 8.x
+    }
+
+    if (needs_workaround) {
+      cudaMemLocation cpu_location = {cudaMemLocationTypeHost, 0};
+      RAFT_CUDA_TRY(cudaMemAdvise(global_neighbors.data_handle(),
+                                  num_rows * k * sizeof(IdxT),
+                                  cudaMemAdviseSetPreferredLocation,
+                                  cpu_location));
+      RAFT_CUDA_TRY(cudaMemAdvise(global_distances.data_handle(),
+                                  num_rows * k * sizeof(T),
+                                  cudaMemAdviseSetPreferredLocation,
+                                  cpu_location));
+    }
+  }
+
   if (raft::resource::is_multi_gpu(handle)) {
     multi_gpu_batch_build(handle,
                           params,

From 7f01efb0c8cba93166d11ffa6f47cf9c1508b7c8 Mon Sep 17 00:00:00 2001
From: vic
Date: Fri, 23 Jan 2026 11:30:55 +0100
Subject: [PATCH 2/3] Address review comments

---
 .../all_neighbors/all_neighbors_batched.cuh | 20 ++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
index c23a87ef81..9a99e2ca22 100644
--- a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
+++ b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
@@ -529,27 +529,27 @@ void batch_build(
   auto global_neighbors = raft::make_managed_matrix(handle, num_rows, k);
   auto global_distances = raft::make_managed_matrix(handle, num_rows, k);
 
-  reset_global_matrices(handle, params.metric, global_neighbors.view(), global_distances.view());
-
-  // For multi-GPU: sync the stream to ensure fill completes before other GPUs access
-  // the managed memory.
   if (raft::resource::is_multi_gpu(handle)) {
-    raft::resource::sync_stream(handle);
-
     // Check if any GPU is Turing (SM 7.5) or older. These architectures have issues with
     // multi-GPU managed memory coherence for concurrent writes. Force CPU-resident memory
     // to ensure all GPUs access through host memory, avoiding page migration issues.
     // Ampere (SM 8.0+) and newer architectures handle this correctly.
     int num_ranks = raft::resource::get_num_ranks(handle);
     bool needs_workaround = false;
+    int original_device;
+    RAFT_CUDA_TRY(cudaGetDevice(&original_device));
     for (int rank = 0; rank < num_ranks; rank++) {
       raft::resource::set_current_device_to_rank(handle, rank);
       int device_id;
       RAFT_CUDA_TRY(cudaGetDevice(&device_id));
       int major = 0;
       RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
-      if (major < 8) { needs_workaround = true; }  // Turing is 7.x, Ampere is 8.x
+      if (major < 8) {
+        needs_workaround = true;
+        break;
+      }  // Turing is 7.x, Ampere is 8.x
     }
+    RAFT_CUDA_TRY(cudaSetDevice(original_device));
 
     if (needs_workaround) {
       cudaMemLocation cpu_location = {cudaMemLocationTypeHost, 0};
@@ -564,7 +564,13 @@ void batch_build(
     }
   }
 
+  reset_global_matrices(handle, params.metric, global_neighbors.view(), global_distances.view());
+
   if (raft::resource::is_multi_gpu(handle)) {
+    // For multi-GPU: sync the stream to ensure fill completes before other GPUs access
+    // the managed memory.
+    raft::resource::sync_stream(handle);
+
     multi_gpu_batch_build(handle,
                           params,
                           dataset,

From 1fa54d5b81d6ec6c9dea38aa76a3e840b19f4275 Mon Sep 17 00:00:00 2001
From: vic
Date: Fri, 23 Jan 2026 11:33:42 +0100
Subject: [PATCH 3/3] Add debug log

---
 cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
index 9a99e2ca22..1ec968f945 100644
--- a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
+++ b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh
@@ -552,6 +552,7 @@ void batch_build(
     RAFT_CUDA_TRY(cudaSetDevice(original_device));
 
     if (needs_workaround) {
+      RAFT_LOG_DEBUG("Applying managed memory workaround for pre-Ampere GPU architecture");
       cudaMemLocation cpu_location = {cudaMemLocationTypeHost, 0};
       RAFT_CUDA_TRY(cudaMemAdvise(global_neighbors.data_handle(),
                                   num_rows * k * sizeof(IdxT),
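
For reference, the core of the workaround can be sketched outside cuVS. The snippet below is a hypothetical, standalone illustration (the CUDA_CHECK macro, the any_pre_ampere_gpu helper, the buffer name, and the sizes are invented for the example): it checks whether any visible GPU is pre-Ampere (compute capability < 8.0), advises the managed allocation to stay host-resident, and synchronizes before the buffer is read elsewhere. It uses the classic cudaMemAdvise overload that takes cudaCpuDeviceId; the patches pass a cudaMemLocation, the newer form of the same advice.

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t err_ = (call);                                            \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",                    \
                   cudaGetErrorString(err_), __FILE__, __LINE__);         \
      std::exit(EXIT_FAILURE);                                            \
    }                                                                     \
  } while (0)

// Returns true if any visible GPU is pre-Ampere (compute capability < 8.0),
// i.e. Turing or older, where concurrent multi-GPU writes to managed memory
// are better served by host-resident pages.
bool any_pre_ampere_gpu()
{
  int device_count = 0;
  CUDA_CHECK(cudaGetDeviceCount(&device_count));
  for (int dev = 0; dev < device_count; dev++) {
    int major = 0;
    CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev));
    if (major < 8) { return true; }  // Turing is 7.x, Ampere is 8.x
  }
  return false;
}

int main()
{
  const size_t n = 1 << 20;
  float* buf     = nullptr;

  // Managed allocation that several GPUs would write to concurrently.
  CUDA_CHECK(cudaMallocManaged(&buf, n * sizeof(float)));

  if (any_pre_ampere_gpu()) {
    // Keep the pages host-resident so every GPU accesses them through host
    // memory instead of bouncing pages between devices.
    CUDA_CHECK(cudaMemAdvise(buf, n * sizeof(float),
                             cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId));
  }

  // ... per-GPU work (e.g. one kernel launch per device) would go here ...

  // Make sure all pending work has finished before another device or the host
  // reads the managed buffer.
  CUDA_CHECK(cudaDeviceSynchronize());
  CUDA_CHECK(cudaFree(buf));
  return 0;
}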