diff --git a/cpp/include/raft/matrix/detail/columnWiseSort.cuh b/cpp/include/raft/matrix/detail/columnWiseSort.cuh index e809e0a0ee..4bb9007b90 100644 --- a/cpp/include/raft/matrix/detail/columnWiseSort.cuh +++ b/cpp/include/raft/matrix/detail/columnWiseSort.cuh @@ -221,7 +221,12 @@ void sortColumnsPerRow(const InType* in, // device Segmented radix sort // 2^18 column cap to restrict size of workspace ~512 MB // will give better perf than below deviceWide Sort for even larger dims - int numSegments = n_rows + 1; + + // n_rows CSR-style segments aliased onto an n_rows + 1 offsets array, + // with begin = &offsets[0] and end = &offsets[1]. cub's `num_segments` + // is the actual segment count (n_rows). + int numSegments = n_rows; + int numOffsets = n_rows + 1; // need auxiliary storage: cub sorting + keys (if user not passing) + // staging for values out + segment partition @@ -248,8 +253,8 @@ void sortColumnsPerRow(const InType* in, // value in KV pair need to be passed in, out buffer is separate workspaceSize += raft::alignTo(sizeof(OutType) * (size_t)totalElements, memAlignWidth); - // for segment offsets - workspaceSize += raft::alignTo(sizeof(int) * (size_t)numSegments, memAlignWidth); + // for segment offsets (numOffsets = numSegments + 1, see above) + workspaceSize += raft::alignTo(sizeof(int) * (size_t)numOffsets, memAlignWidth); } else { size_t workspaceOffset = 0; @@ -264,14 +269,14 @@ void sortColumnsPerRow(const InType* in, workspacePtr = (void*)((size_t)workspacePtr + workspaceOffset); int* dSegmentOffsets = reinterpret_cast(workspacePtr); - workspaceOffset = raft::alignTo(sizeof(int) * (size_t)numSegments, memAlignWidth); + workspaceOffset = raft::alignTo(sizeof(int) * (size_t)numOffsets, memAlignWidth); workspacePtr = (void*)((size_t)workspacePtr + workspaceOffset); // layout idx RAFT_CUDA_TRY(layoutIdx(dValuesIn, n_rows, n_columns, stream)); // layout segment lengths - spread out column length - RAFT_CUDA_TRY(layoutSortOffset(dSegmentOffsets, n_columns, numSegments, stream)); + RAFT_CUDA_TRY(layoutSortOffset(dSegmentOffsets, n_columns, numOffsets, stream)); RAFT_CUDA_TRY(cub::DeviceSegmentedRadixSort::SortPairs(workspacePtr, workspaceSize,