mergeSort

发布时间 2023-12-08 16:08:23作者: 谋杀肚腩

本篇介绍 cuda samples 中的 mergeSort.

大体上来讲, mergeSort 分为两个阶段.

  1. 对含有 SHARED_SIZE_LIMIT (即 1024) 个元素的数组进行排序.
  2. 合并多个有序数组.

其中第一个阶段调用一次函数 mergeSortShared 结束. 而第二个阶段需要循环调用三个函数: generateSampleRanks, mergeRanksAndIndicesmergeElementaryIntervals. 我将分别讲述这两个阶段.

第一阶段: 段内排序

函数 mergeSortShared 比较简单, 就是对完整的数据执行 mergeSortSharedKernel 函数. 其函数声明为

template <uint sortDir>
__global__ void mergeSortSharedKernel(uint *d_DstKey, uint *d_DstVal,
                                      uint *d_SrcKey, uint *d_SrcVal,
                                      uint arrayLength);

该 kernel 函数对一个长为 arrayLength 的数组排序, 在代码中 arrayLength 接收的参数为宏变量 SHARED_SIZE_LIMIT (即 1024). 调用该函数的 grid 和 block 大小分别为 N / arrayLengtharrayLength / 2. 每个 block 排序一段长为 arrayLength 的数组.

在函数 mergeSortSharedKernel 中排序通过一个循环进行. 该循环在长度为 stride 的数组已经排序基础上, 将 2 个长为 stride 的数组进行合并. 因此 stride 的大小从 1 开始每次增大 1 倍直到 arrayLength / 2. 在循环开始之前, 长度 stride 的数组已经是排序了的状态.

循环中数量为 stride 的 thread 为一组, 排序相邻的两个长为 stride 的数组, 记前一个数组为 A, 后一个数组为 B. 因此需要计算每一组 thread 内部的 thread ID 以及这组 thread 对应数据的起始位置.

// 计算组内 thread ID, [0, thread)
uint lPos = thread & (stride - 1);

// 计算该组 thread 对应数据的起始位置
// 每组 thread 处理连续的两个 thread 数组
uint *baseKey = s_key + 2 * (threadIdx.x - lPos);

相对应的, 每个 thread 在两个数组中都有一个对应的数据, 分别为 keyAkeyB.

uint keyA = baseKey[lPos + 0];
uint keyB = baseKey[lPos + stride];

接下来找到 keyAkeyB 在两个数组合并后的位置. 因为数组 A 和数组 B 都是排序的, 因此在数组 A 中有 lPos 个数比 keyA 小, 只需要找到在数组 B 中有多少个数比 keyA 小, 两者相加就得到了 keyA 合并后的位置. 这个寻找的过程可以通过二分法查找. keyB 同理也能找到合并后的位置. 但是这里有个问题就是某个 thread 对应的 keyAkeyB 相同, 那么这样计算得到的最后的位置也是相同的, 产生了冲突.

代码中解决这个可能的冲突的方法是 keyA 在数组 B 中寻找小于 keyA 的数, keyB 在数组 A 中寻找小于等于 keyB 的数. 这样做的好处是保持了排序的稳定性. 数组 A 的元素总是排在数组 B 的相同元素之前.

uint posA = binarySearchExclusive<sortDir>(keyA, baseKey + stride, stride, stride) + lPos;
uint posB = binarySearchInclusive<sortDir>(KeyB, baseKey + 0, stride, stride) + lPos;

函数 binarySearchExclusivebinarySearchInclusive 有着相似的结构, 只有内部的微弱区别. 通过二分法查找 val 在数组中的位置, 即返回数组中小于 (binarySearchExclusive) 或小于等于 (binarySearchInclusive) val 的元素的数量.

// val 待查找的元素
// data 数组起始位置
// L 数组长度
// stride >= L, 并行是 2 的整数幂
template <uint sortDir>
static inline __device__ uint binarySearchExclusive(uint val, uint *data,
                                                    uint L, uint stride) {
  if (L == 0) {
    return 0;
  }

  uint pos = 0;

  for (; stride > 0; stride >>= 1) {
    uint newPos = umin(pos + stride, L);

	// binarySearchInclusive 这里的 < 变成 <=
    if ((sortDir && (data[newPos - 1] < val)) ||
        (!sortDir && (data[newPos - 1] > val))) {
      pos = newPos;
    }
  }

  return pos;
}

merge sort 的第一个阶段就完成了. 该阶段将每 SHARED_SIZE_LIMIT (即 1024) 个元素进行排序, 得到了 N / SHARED_SIZE_LIMIT 个有序数组.

第二阶段: 合并有序段

第二个阶段循环合并连续 2 个长为 stride 的有序数组. 因此 stride 的大小从 SHARED_SIZE_LIMIT (即 1024) 开始每次增大 1 倍直到 N. 这个阶段主要分成三个部分: 生成排序, 合并排序和合并. 也就是调用三个函数 generateSampleRanks, mergeRanksAndIndicesmergeElementaryIntervals.

思路解析

要合并 2 个连续的有序数组, 当然可以像第一阶段一样启动足够多的线程查找每个元素合并后的位置. 但是这样的方法在后期的时延会增加很多, 因为 stride 长度倍增. 这里采用的方法是先分组再查找. 要理解这部分的代码, 需要理解数组 d_RanksA, d_RanksB, d_LimitsAd_LimitsB. 这个会在后面详细讲解.

先把 stride 分为大小为 SAMPLE_STRIDE 的子数组. d_RanksA 表示 2 个 stride 中所有的子数组的起始元素在 stride A 中的位置, 即 A 中小于 (或小于等于) 某个起始元素的元素数. 这一部分对应的就是函数 generateSampleRanks 的部分.

接着把 d_RanksA 排序, 将排序后的值填入对应的 d_Limits 中. 这一部分对应的就是函数 mergeRanksAndIndices 的部分.

经过前两步, 就把 AB 每个都分为了 2 \times stride / SAMPLE\_STRIDE 个范围. 之后合并对应的范围就得到 2 个 stride 合并后的结果. 这一部分对应的就是函数 mergeElementaryIntervals.

代码解析

generateSamplesRanks

这一步生成 d_RanksAd_RanksB. 通过 kernel 函数 generateSampleRanksKernel 实现的. 无论 stride 大小, 该 kernel 函数的总线程数都不变. 就是 \(N / (2 \times SAMPLE\_STRIDE)\).

// 计算总线程数量
uint lastSegmentElements = N % (2 * stride);
uint threadCount =
	(lastSegmentElements > stride)
	    ? (N + 2 * stride - lastSegmentElements) / (2 * SAMPLE_STRIDE)
        : (N - lastSegmentElements) / (2 * SAMPLE_STRIDE);

调用函数 generateSampleRanksKernel 的 grid 大小为 threadCount / 256, block 大小为 256. 在该 kernel 函数中, 把所有的线程分组. 每 stride / SAMPLE_STRIDE 个线程组成一组. 因此, 每个线程的全局 ID 和组内 ID 可以通过如下方式计算.

// 线程的全局 ID
uint pos = blockIdx.x * blockDim.x + threadIdx.x;

// stride / SAMPLE_STRIDE 个线程为一组
// 线程的组内 ID
const uint i = pos & ((stride / SAMPLE_STRIDE) - 1);

每个 thread 对应两个长为 SAMPLE_STRIDE 的数组以及 d_RanksAd_RanksB 中的两个元素, 就可以计算每个线程组的数据偏移量.

// 计算每个线程组的数据偏移量
const uint segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
d_SrcKey += segmentBase;
d_RanksA += segmentBase / SAMPLE_STRIDE;
d_RanksB += segmentBase / SAMPLE_STRIDE;

每个线程组对应连续 2 个长为 stride 的数组, 前一个数组称为数组 A, 后一个称为 B. 每个线程在 AB 中都有对应的一个长为 SAMPLE_STRIDE 的子数组. 就可以计算得到 AB 的长度以及子数组的个数.

// A 和 B 的长度
const uint segmentElementsA = stride;
const uint segmentElementsB = umin(stride, N - segmentBase - stride);

// 子数组的个数
const uint segmentSamplesA = getSampleCount(segmentElementsA);
const uint segmentSamplesB = getSampleCount(segmentElementsB);

最后求在 AB 两个数组中, 小于 (或者小于等于) 每个子数组起始元素的元素数. 这也是通过 binarySearchExclusivebinarySearchInclusive 两个 kernel 函数完成.

if (i < segmentSamplesA) {
	d_RanksA[i] = i * SAMPLE_STRIDE;
    d_RanksB[i] = binarySearchExclusive<sortDir>(
        d_SrcKey[i * SAMPLE_STRIDE], d_SrcKey + stride, segmentElementsB,
        nextPowerOfTwo(segmentElementsB));
}

if (i < segmentSamplesB) {
    d_RanksB[(stride / SAMPLE_STRIDE) + i] = i * SAMPLE_STRIDE;
    d_RanksA[(stride / SAMPLE_STRIDE) + i] = binarySearchInclusive<sortDir>(
        d_SrcKey[stride + i * SAMPLE_STRIDE], d_SrcKey + 0, segmentElementsA,
        nextPowerOfTwo(segmentElementsA));
}

mergeRanksAndIndices

这一步把 d_RanksAd_RanksB 中的内容排序后赋值给 d_LimitsAd_LimitsB.

AB 分别调用 kernel 函数 mergeRanksAndIndicesKernel, 通过 d_RanksA 生成 d_LimitsA. 该 kernel 函数有大小 <<<threadCount / 256, 256>>>. 其中 threadCount = N / (2 * SAMPLE_STRIDE).

同样 stride / SAMPLE_STRIDE 个 thread 为一组. 计算组内 ID 和数组偏移同之前一样.

// thread ID
uint pos = blockIdx.x * blockDim.x + threadIdx.x;

// 组内 ID
const uint i = pos & ((stride / SAMPLE_STRIDE) - 1);

// 数组偏移
const uint segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
d_Ranks += (pos - i) * 2;
d_Limits += (pos - i) * 2;

需要将 d_RanksA 排序. 对 A 中的每个 d_RanksA 查找其在 B 中的位置.

if (i < segmentSamplesA) {
    uint dstPos = binarySearchExclusive<1U>(
                      d_Ranks[i], d_Ranks + segmentSamplesA, segmentSamplesB,
                      nextPowerOfTwo(segmentSamplesB)) +
                  i;
    d_Limits[dstPos] = d_Ranks[i];
}

同样, 对 B 中的每个 d_RanksA 查找其在 A 中的位置.

if (i < segmentSamplesA) {
    uint dstPos = binarySearchExclusive<1U>(
                      d_Ranks[i], d_Ranks + segmentSamplesA, segmentSamplesB,
                      nextPowerOfTwo(segmentSamplesB)) +
                  i;
    d_Limits[dstPos] = d_Ranks[i];
}

经过排序, 实际上是把一个 stride 分为了 2 * stride / SAMPLE_STRIDE 个部分. 划分的标准是每个 SAMPLE_STRIDE 的起始元素在 AB 中的位置.

mergeElementaryIntervals

这一步合并 2 个 stride. 调用 kernel 函数 mergeElementaryIntervalsKernel, 其大小为 <<<N / SAMPLE_STRIDE, SAMPLE_STRIDE>>>. 之前已经把每个 stride 分为了 2 * SAMPLE_STRIDE 个部分. 每个 block 负责 2 个 stride 中对应的一部分, 将他们合并.

首先是吧所有的 block 分组, 每组有 2 * stride / SAMPLE_STRIDE 个 block. 再计算组内 ID 和数据偏移.

// 组内 ID
const uint intervalI = blockIdx.x & ((2 * stride) / SAMPLE_STRIDE - 1);

// 数据偏移
const uint segmentBase = (blockIdx.x - intervalI) * SAMPLE_STRIDE;
d_SrcKey += segmentBase;
d_SrcVal += segmentBase;
d_DstKey += segmentBase;
d_DstVal += segmentBase;

还需要计算合并后的起始位置, 该 block 的合并长度等数据.

__shared__ uint startSrcA, startSrcB, lenSrcA, lenSrcB, startDstA, startDstB;

// 本 block 读取数据的起始位置
startSrcA = d_LimitsA[blockIdx.x];
startSrcB = d_LimitsB[blockIdx.x];

// 本 block 读取数据的终止位置
// 下一个 block 读取数据的开始位置或者全 stride 长度
uint endSrcA = (intervalI + 1 < segmentSamples) ? d_LimitsA[blockIdx.x + 1]
                                                    : segmentElementsA;
uint endSrcB = (intervalI + 1 < segmentSamples) ? d_LimitsB[blockIdx.x + 1]
                                                    : segmentElementsB;

// 本 block 读取的数据长度
lenSrcA = endSrcA - startSrcA;
lenSrcB = endSrcB - startSrcB;
// 本 block 合并完需要写入的起始位置
// A 的起始位置就是之前 block 的合并的长度和
// B 的起始位置就是 A 的起始位置 + 长度
startDstA = startSrcA + startSrcB;
startDstB = startDstA + lenSrcA;

把 global 内存中的数据加载到 shared 内存.

if (threadIdx.x < lenSrcA) {
    s_key[threadIdx.x + 0] = d_SrcKey[0 + startSrcA + threadIdx.x];
    s_val[threadIdx.x + 0] = d_SrcVal[0 + startSrcA + threadIdx.x];
}

if (threadIdx.x < lenSrcB) {
    s_key[threadIdx.x + SAMPLE_STRIDE] =
        d_SrcKey[stride + startSrcB + threadIdx.x];
    s_val[threadIdx.x + SAMPLE_STRIDE] =
        d_SrcVal[stride + startSrcB + threadIdx.x];
}

调用 merge 函数合并 s_key 中的数据. merge 函数比较简单, 找到每个元素合并后的位置.

最后把数据存到 global 内存的对应位置.

if (threadIdx.x < lenSrcA) {
    d_DstKey[startDstA + threadIdx.x] = s_key[threadIdx.x];
    d_DstVal[startDstA + threadIdx.x] = s_val[threadIdx.x];
}

if (threadIdx.x < lenSrcB) {
    d_DstKey[startDstB + threadIdx.x] = s_key[lenSrcA + threadIdx.x];
    d_DstVal[startDstB + threadIdx.x] = s_val[lenSrcA + threadIdx.x];
}