从零入门 cuda 编程?🦴深入理解 sgemmNN 算法!
深入理解 sgemmNN 算法
时隔近一年,CUDA 系列的第三篇终于来了。上篇讲访存优化时留了一个尾巴——说好要专门写一篇讲 SGEMM 的。今天我们就来填这个坑,深入分析一个经典的、手写的 sgemmNN 内核,看看 Volkov 在 2008 年的那篇著名论文 《Benchmarking GPUs to tune dense linear algebra》 中到底做了什么。
从问题说起:为什么 SGEMM 如此重要?
SGEMM 的全称是 Single-precision GEneral Matrix Multiply,即单精度通用矩阵乘法。它的数学定义是:
$$C = \alpha \cdot A \times B + \beta \cdot C$$
其中 $A$ 的维度为 $m \times k$,$B$ 的维度为 $k \times n$,$C$ 的维度为 $m \times n$。$\alpha$ 和 $\beta$ 是标量。
这个操作几乎是所有线性代数计算的基石——卷积神经网络可以转化为矩阵乘法、Transformer 的自注意力机制的核心也是矩阵乘法、各种科学计算更不必说。正因如此,NVIDIA 的 cuBLAS 库对其做了极致的优化,而 Volkov 2008 年的工作则向我们揭示了这些优化背后的核心思想。
Volkov 2008:打破”cuBLAS 不可战胜”的神话
2008 年,Vasily Volkov 在 USENIX Workshop 上发表了一篇极具影响力的论文。在当时的主流认知是:手写的 CUDA 内核不可能超过 NVIDIA 官方优化的 cuBLAS。Volkov 不仅证明了这是错的,还给出了一个系统性的方法论:
- 充分利用寄存器:把数据尽可能地留在寄存器中,而不是反复访问共享内存或全局内存。
- 减少同步开销:通过巧妙的线程组织,消除不必要的
__syncthreads()。 - 以计算隐藏访存延迟:通过足够的并行度和指令级并行,让计算单元在等待数据时不至于空闲。
我们今天分析的 sgemmNN 内核,就是这些思想的直接体现。
内核代码全景
先贴出完整的核函数代码,然后我们逐行拆解:
1 | |
这段代码大约只有 40 行,但它所蕴含的优化思想极其丰富。接下来我们从 线程块组织、线程映射、共享内存布局、主循环计算 和 写回策略 五个维度来逐一解读。
线程块与线程的组织方式
每个线程块处理 16 × 64 的输出块
1 | |
这个内核启动的网格规模为 $\lceil m/16 \rceil \times \lceil n/64 \rceil$,即:
gridDim.y等于输出矩阵 C 的行数除以 16(向上取整)gridDim.x等于输出矩阵 C 的列数除以 64(向上取整)
每个线程块负责计算 C 上的一个 16 行 × 64 列 的子块。选择 $(16, 64)$ 这个尺寸是有讲究的:
- 16 是半个 warp 的大小,也是共享内存
bs的行数 - 64 列对应 4 个 warp 的连续访存宽度,刚好与 $4 \times 16$ 的线程组织对应
每个线程处理一个长度为 16 的列向量
输出块是 $16 \times 64$,而一个线程块只包含 64 个线程。这意味着每个线程负责输出块中的 一列 16 个元素,即上面代码中 float c[16] 这个数组。
64 个线程 × 每个线程 16 个输出元素 = 1024 个浮点数 = 一个线程块的完整输出。
线程以 4 × 16 的方式组织
1 | |
64 个线程被组织为 blockDim.x = 16、blockDim.y = 4 的二维块。这个选择是经过深思熟虑的。我们可以换一个角度来看待这个问题:将 64 个线程看成 2 个 warp(每个 warp 32 线程),但实际上的组织是将 64 个线程映射到输出列的 64 个不同位置上。
指针初始化的精确解读
这是整个内核中最容易让人困惑的部分。让我们逐个拆解 A、B、C 三个指针的初始化过程。
C 矩阵:定位输出位置
1 | |
把它拆成两部分看:
- 列偏移:
blockIdx.x * 64 + threadIdx.x—— 当前线程负责输出矩阵中的哪一列 - 行偏移:
(threadIdx.y + blockIdx.y * ldc) * 16—— 注意这里乘的是ldc(leading dimension of C),而不是 16。这是因为 C 是 行主序 存储的(虽然 CUDA 中通常用列主序的 BLAS 约定,但这里内核实现是以行主序处理的),ldc * 16表示跳过了 16 行(每行ldc个元素),而threadIdx.y * 16表示在这个块的 16 行中,当前线程负责第threadIdx.y * 16行开始的连续 16 个元素。
所以每个线程通过 C += ldc 循环 16 次,写回 C[0]、C[ldc]、C[2*ldc]、…… 共 16 个元素,即 C 中 同一列、连续 16 行 的列向量。
A 矩阵:与 C 列对齐
1 | |
A 和 C 有着相同的列偏移计算方式。这是因为在矩阵乘法 $C = A \times B$ 中,C 的每一列都是由 A 的 所有列 乘以 B 的 对应行 累加而来。当线程固定负责 C 的第 col 列时,它需要从 A 中读取的也正是第 col 列的所有行。
这里 A += lda 是在主循环中逐行移动 A 的指针,对应 A 的不同行(即 k 维度上的不同位置)。
B 矩阵:与 C 行对齐
1 | |
blockIdx.y * 16—— 当前线程块负责的输出块的首行在 B 中的对应位置(注意是* ldb,跳到对应的行)threadIdx.y—— 这是细粒度的行偏移,即这个块内第几个 16 行组threadIdx.x—— 这个偏移是 B 矩阵的列偏移,对应 $C(:,col) = A(:,0) \times B(0,col) + A(:,1) \times B(1,col) + \dots$ 中的B(row_of_B, col_of_C)中的col_of_C。
主循环中的 B += 16 表示每次迭代跳过 B 的 16 行(即沿 k 维度前进 16 步),读取下一个 16 × 64 的 B 子块。
共享内存布局:bs[16][17] 的玄机
1 | |
注意第二维是 17 而不是 16。这是一个经典技巧——pad 一列以避免 bank conflict。
回顾一下共享内存的 bank 映射规则:连续的 4 字节字映射到连续的 bank,即 bs[i][j] 的 bank 编号为 (j * 17 + i) % 32。如果没有 padding,即使用 bs[16][16],那么 bs[i][0] 和 bs[i+1][0] 之间的偏移是 16 个字 = 64 字节,对应 bank 偏移为 16 % 32 = 16。但问题在于,当同一个 warp 的线程同时访问 bs[threadIdx.x][j] 时(其中 threadIdx.x 在 [0, 15] 范围内变化),访问的是同一行中的不同列:
| 线程 (tid.x) | 访问地址 |
|---|---|
| 0 | bs[0][j] → bank (j*17+0) % 32 |
| 1 | bs[1][j] → bank (j*17+1) % 32 |
| … | … |
| 15 | bs[15][j] → bank (j*17+15) % 32 |
如果使用 bs[16][16],则 j 相同的列上的元素会映射到相同 bank,从而引发冲突。而 +1 padding 让列间偏移变成了 17 × 4 = 68 字节,使得 连续的行映射到连续的 bank,从而避免了冲突。
主循环:分块计算的核心
主循环的结构是一个 do-while:
1 | |
阶段 1:加载 B
1 | |
这 64 个线程共同协作,将 B 的一个 16 × 64 的子块加载到共享内存中。每个线程负责加载 i += 4 循环的 4 次迭代,共 4 个元素。64 个线程 × 4 个元素 = 256 个浮点数 = 一个完整的 16 × 16 的 B 子块。
每次循环加载一个 16 × 16 的 B 子块,4 次迭代共加载 16 × 64 的 B 子块。
有人可能会问:每次加载的 B[i*ldb] 中的 ldb 表示的是 B 中一行的跨度,而 B 在每次 do-while 循环中会前进 16(B += 16),这里为什么有 B[i*ldb]?因为在 k 维度上的步进是 16 一行一行地在前进,而 i*ldb 是在当前这个 16 × 64 的块中,访问 B 的不同行。
阶段 2:计算
1 | |
这是整个内核的计算核心。对于每个 i(0 到 15):
- **
A[0]**:从全局内存读取 A 的当前元素,即 A 矩阵的第col列的第row_in_block * 16 + i行 - **
bs[i][0..15]**:从共享内存读取 B 子块的对应行
对于固定的 i,bs[i][0..15] 这 16 个数字将被当前 block 中每行(也就是有相同 threadIdx.x 的线程)的 16 个线程共享。但当前线程只需要读取 bs[i][j] 中的 16 个数字,j 由 threadIdx.x 决定。
每个 do-while 迭代内,内层循环执行 16 次,每次执行 16 个乘加运算(1 个 A[0] 乘以 16 个不同的 bs[i][j]),共 256 次乘加运算。同时,这 256 次运算中:
- 从全局内存读取了 16 个 A 的元素(
A[0]在每次内层循环中被读取) - 从共享内存读取了 256 个 B 的元素(
bs[i][0..15]在每次内层循环中被读取)
等等,这里有个细节:实际上 A[0] 在 16 次内层循环迭代中是从全局内存读取的,每个 A 的元素被重用了 16 次(对应 bs[i][0..15] 的 16 个值)。反过来,B 的元素通过共享内存被加载到 bs 后,在随后的内层循环中被 block 中具有相同 threadIdx.x 的所有线程(即同一列的所有线程)访问。
共享内存在这里起的关键作用是:将 B 的全局内存访问从非合并的转变为合并的,并且 让同一个 warp 的线程能够共享 B 的数据。
数据复用比率
在一个 do-while 迭代中:
- 从全局内存读取 A:16 个浮点数(
A[0]被连续读 16 次,每次 i 迭代读一次) - 从共享内存读取 B:16 × 16 = 256 个浮点数(但已经在阶段 1 加载到共享内存)
- 从全局内存写 C:0(写回在最后 epilogue 中)
对于全局内存而言,每个 do-while 迭代只读 256 个浮点数(16 个 A + 通过共享内存加载的 256 个 B),却完成了 256 次 FMA,数据复用比率为 1:1。
实际上,更精确的计算是:每 256 次 FMA,只从全局内存读取 A 的 16 个元素(B 通过共享内存中转,但最终还是从全局内存读取了 256 个元素到共享内存)。所以每 256 次 FMA,产生了 272 个全局内存读取。然而,B 的子块一旦加载到共享内存,可以被后续循环中的多个线程访问,所以实际的数据复用率远高于表面数字。
写回阶段:alpha 和 beta 的应用
1 | |
这是 epilogue 阶段。每个线程将其寄存器中的 16 个部分和 c[0..15] 写回全局内存:
c[i]乘以alpha- 加上原来的
C[0]乘以beta - 结果写回
C[0]
这是一个 融合的乘加更新(fused multiply-add update),在一次操作中同时完成缩放和累加。
与 Volkov 论文的对应
回顾 Volkov 论文的核心观点,看看这个内核是如何体现的:
| 论文观点 | 内核中的体现 |
|---|---|
| 最大化计算/访存比 | 每个线程有 16 个寄存器累加器,FMA 与访存的比值很高 |
| 减少同步 | 每个 do-while 迭代只需要两次 __syncthreads() |
| 合并访存 | A 的访问通过 A[0] 是合并的;B 通过共享内存转储再访问 |
| 利用寄存器,避免共享内存 | 16 个 c[i] 全部在寄存器中,没有中间结果的共享内存存储 |
| Bank conflict 避免 | bs[16][17] 的 padding 技巧 |
进一步的思考
这个内核虽然已经非常高效,但它并不是 SGEMM 优化的终点。后续的改进方向包括:
双缓冲(Double Buffering):使用两份共享内存,让加载和计算可以重叠,消除
__syncthreads()带来的等待。更大的 tile 尺寸:现代 GPU(如 Volta、Ampere、Hopper)有更多的寄存器和更大的共享内存,可以使用 32 × 128、64 × 64 甚至更大的 tile 尺寸。
Warp-level 矩阵分块(WMMA):从 Volta 架构开始,NVIDIA 引入了 Tensor Core 和 warp-level 矩阵乘法指令,可以一个 warp 在一条指令中完成 16 × 16 × 16 的矩阵乘法。
自动化调优:使用 Auto-Tuning 框架(如 Ansor、AutoTVM)自动搜索最优的 tile 尺寸、循环展开因子等参数。
但无论如何,这个 40 行的内核已经抓住了 SGEMM 优化的灵魂——通过精心设计的线程组织和数据流,最大化计算吞吐量、最小化片外内存访问。理解了它,你就理解了高性能计算中矩阵乘法优化的精髓。
这一篇就到这里,下一篇我们可能会讲 执行配置优化(occupancy、grid size、block size 的选择),也可能是 双缓冲技术 或者 WMMA/Tensor Core 编程。随缘吧。鼓掌👏👏👏