Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions cpp/include/raft/matrix/detail/columnWiseSort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,12 @@ void sortColumnsPerRow(const InType* in,
// Device segmented radix sort.
// Columns are capped at 2^18 to restrict the workspace size to ~512 MB, even though
// this path would give better perf than the deviceWide sort below for even larger dims.
int numSegments = n_rows + 1;

// n_rows CSR-style segments aliased onto an n_rows + 1 offsets array,
// with begin = &offsets[0] and end = &offsets[1]. cub's `num_segments`
// is the actual segment count (n_rows).
int numSegments = n_rows;
int numOffsets = n_rows + 1;

// Auxiliary storage is needed for: cub sorting + keys (if the user is not passing them) +
// staging for the output values + segment partition
Expand All @@ -247,8 +252,8 @@ void sortColumnsPerRow(const InType* in,
// the value in each KV pair needs to be passed in; the out buffer is separate
workspaceSize += raft::alignTo(sizeof(OutType) * (size_t)totalElements, memAlignWidth);

// for segment offsets
workspaceSize += raft::alignTo(sizeof(int) * (size_t)numSegments, memAlignWidth);
// for segment offsets (numOffsets = numSegments + 1, see above)
workspaceSize += raft::alignTo(sizeof(int) * (size_t)numOffsets, memAlignWidth);
} else {
size_t workspaceOffset = 0;

Expand All @@ -263,14 +268,14 @@ void sortColumnsPerRow(const InType* in,
workspacePtr = (void*)((size_t)workspacePtr + workspaceOffset);

int* dSegmentOffsets = reinterpret_cast<int*>(workspacePtr);
workspaceOffset = raft::alignTo(sizeof(int) * (size_t)numSegments, memAlignWidth);
workspaceOffset = raft::alignTo(sizeof(int) * (size_t)numOffsets, memAlignWidth);
workspacePtr = (void*)((size_t)workspacePtr + workspaceOffset);

// layout idx
RAFT_CUDA_TRY(layoutIdx(dValuesIn, n_rows, n_columns, stream));

// layout segment lengths - spread out column length
RAFT_CUDA_TRY(layoutSortOffset(dSegmentOffsets, n_columns, numSegments, stream));
RAFT_CUDA_TRY(layoutSortOffset(dSegmentOffsets, n_columns, numOffsets, stream));

RAFT_CUDA_TRY(cub::DeviceSegmentedRadixSort::SortPairs(workspacePtr,
workspaceSize,
Expand Down
Loading