从零入门 cuda 编程?🦴深入理解 sgemmNN 算法!

深入理解 sgemmNN 算法

时隔近一年,CUDA 系列的第三篇终于来了。上篇讲访存优化时留了一个尾巴——说好要专门写一篇讲 SGEMM 的。今天我们就来填这个坑,深入分析一个经典的、手写的 sgemmNN 内核,看看 Volkov 在 2008 年的那篇著名论文 《Benchmarking GPUs to tune dense linear algebra》 中到底做了什么。

从问题说起:为什么 SGEMM 如此重要?

SGEMM 的全称是 Single-precision GEneral Matrix Multiply,即单精度通用矩阵乘法。它的数学定义是:

$$C = \alpha \cdot A \times B + \beta \cdot C$$

其中 $A$ 的维度为 $m \times k$,$B$ 的维度为 $k \times n$,$C$ 的维度为 $m \times n$。$\alpha$ 和 $\beta$ 是标量。

这个操作几乎是所有线性代数计算的基石——卷积神经网络可以转化为矩阵乘法、Transformer 的自注意力机制的核心也是矩阵乘法、各种科学计算更不必说。正因如此,NVIDIA 的 cuBLAS 库对其做了极致的优化,而 Volkov 2008 年的工作则向我们揭示了这些优化背后的核心思想。

Volkov 2008:打破”cuBLAS 不可战胜”的神话

2008 年,Vasily Volkov 在 USENIX Workshop 上发表了一篇极具影响力的论文。在当时的主流认知是:手写的 CUDA 内核不可能超过 NVIDIA 官方优化的 cuBLAS。Volkov 不仅证明了这是错的,还给出了一个系统性的方法论:

  1. 充分利用寄存器:把数据尽可能地留在寄存器中,而不是反复访问共享内存或全局内存。
  2. 减少同步开销:通过巧妙的线程组织,消除不必要的 __syncthreads()
  3. 以计算隐藏访存延迟:通过足够的并行度和指令级并行,让计算单元在等待数据时不至于空闲。

我们今天分析的 sgemmNN 内核,就是这些思想的直接体现。

内核代码全景

先贴出完整的核函数代码,然后我们逐行拆解:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
__global__ void sgemmNN(
const float *A, int lda,
const float *B, int ldb,
float *C, int ldc,
int k,
float alpha, float beta)
{
A += blockIdx.x * 64 + threadIdx.x + threadIdx.y * 16;
B += threadIdx.x + (blockIdx.y * 16 + threadIdx.y) * ldb;
C += blockIdx.x * 64 + threadIdx.x +
(threadIdx.y + blockIdx.y * ldc) * 16;

__shared__ float bs[16][17];
float c[16] = {
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0
};
const float *Blast = B + k;

do {
#pragma unroll
for (int i = 0; i < 16; i += 4)
bs[threadIdx.x][threadIdx.y + i] = B[i * ldb];
B += 16;
__syncthreads();

#pragma unroll
for (int i = 0; i < 16; i++, A += lda) {
c[0] += A[0] * bs[i][0]; c[1] += A[0] * bs[i][1];
c[2] += A[0] * bs[i][2]; c[3] += A[0] * bs[i][3];
c[4] += A[0] * bs[i][4]; c[5] += A[0] * bs[i][5];
c[6] += A[0] * bs[i][6]; c[7] += A[0] * bs[i][7];
c[8] += A[0] * bs[i][8]; c[9] += A[0] * bs[i][9];
c[10] += A[0] * bs[i][10]; c[11] += A[0] * bs[i][11];
c[12] += A[0] * bs[i][12]; c[13] += A[0] * bs[i][13];
c[14] += A[0] * bs[i][14]; c[15] += A[0] * bs[i][15];
}
__syncthreads();
} while (B < Blast);

for (int i = 0; i < 16; i++, C += ldc)
C[0] = alpha * c[i] + beta * C[0];
}

这段代码大约只有 40 行,但它所蕴含的优化思想极其丰富。接下来我们从 线程块组织线程映射共享内存布局主循环计算写回策略 五个维度来逐一解读。

线程块与线程的组织方式

每个线程块处理 16 × 64 的输出块

1
A += blockIdx.x * 64 + threadIdx.x + threadIdx.y * 16;

这个内核启动的网格规模为 $\lceil m/16 \rceil \times \lceil n/64 \rceil$,即:

  • gridDim.y 等于输出矩阵 C 的行数除以 16(向上取整)
  • gridDim.x 等于输出矩阵 C 的列数除以 64(向上取整)

每个线程块负责计算 C 上的一个 16 行 × 64 列 的子块。选择 $(16, 64)$ 这个尺寸是有讲究的:

  • 16 是半个 warp 的大小,也是共享内存 bs 的行数
  • 64 列对应 4 个 warp 的连续访存宽度,刚好与 $4 \times 16$ 的线程组织对应

每个线程处理一个长度为 16 的列向量

输出块是 $16 \times 64$,而一个线程块只包含 64 个线程。这意味着每个线程负责输出块中的 一列 16 个元素,即上面代码中 float c[16] 这个数组。

64 个线程 × 每个线程 16 个输出元素 = 1024 个浮点数 = 一个线程块的完整输出。

线程以 4 × 16 的方式组织

1
dim3 blockDim(16, 4);  // threadIdx.x ∈ [0, 15], threadIdx.y ∈ [0, 3]

64 个线程被组织为 blockDim.x = 16blockDim.y = 4 的二维块。这个选择是经过深思熟虑的。我们可以换一个角度来看待这个问题:将 64 个线程看成 2 个 warp(每个 warp 32 线程),但实际上的组织是将 64 个线程映射到输出列的 64 个不同位置上。

指针初始化的精确解读

这是整个内核中最容易让人困惑的部分。让我们逐个拆解 A、B、C 三个指针的初始化过程。

C 矩阵:定位输出位置

1
2
C += blockIdx.x * 64 + threadIdx.x +
(threadIdx.y + blockIdx.y * ldc) * 16;

把它拆成两部分看:

  • 列偏移blockIdx.x * 64 + threadIdx.x —— 当前线程负责输出矩阵中的哪一列
  • 行偏移(threadIdx.y + blockIdx.y * ldc) * 16 —— 注意这里乘的是 ldc(leading dimension of C),而不是 16。这是因为 C 是 行主序 存储的(虽然 CUDA 中通常用列主序的 BLAS 约定,但这里内核实现是以行主序处理的),ldc * 16 表示跳过了 16 行(每行 ldc 个元素),而 threadIdx.y * 16 表示在这个块的 16 行中,当前线程负责第 threadIdx.y * 16 行开始的连续 16 个元素。

所以每个线程通过 C += ldc 循环 16 次,写回 C[0]C[ldc]C[2*ldc]、…… 共 16 个元素,即 C 中 同一列、连续 16 行 的列向量。

A 矩阵:与 C 列对齐

1
A += blockIdx.x * 64 + threadIdx.x + threadIdx.y * 16;

A 和 C 有着相同的列偏移计算方式。这是因为在矩阵乘法 $C = A \times B$ 中,C 的每一列都是由 A 的 所有列 乘以 B 的 对应行 累加而来。当线程固定负责 C 的第 col 列时,它需要从 A 中读取的也正是第 col 列的所有行。

这里 A += lda 是在主循环中逐行移动 A 的指针,对应 A 的不同行(即 k 维度上的不同位置)。

B 矩阵:与 C 行对齐

1
B += threadIdx.x + (blockIdx.y * 16 + threadIdx.y) * ldb;
  • blockIdx.y * 16 —— 当前线程块负责的输出块的首行在 B 中的对应位置(注意是 * ldb,跳到对应的行)
  • threadIdx.y —— 这是细粒度的行偏移,即这个块内第几个 16 行组
  • threadIdx.x —— 这个偏移是 B 矩阵的列偏移,对应 $C(:,col) = A(:,0) \times B(0,col) + A(:,1) \times B(1,col) + \dots$ 中的 B(row_of_B, col_of_C) 中的 col_of_C

主循环中的 B += 16 表示每次迭代跳过 B 的 16 行(即沿 k 维度前进 16 步),读取下一个 16 × 64 的 B 子块。

共享内存布局:bs[16][17] 的玄机

1
__shared__ float bs[16][17];

注意第二维是 17 而不是 16。这是一个经典技巧——pad 一列以避免 bank conflict

回顾一下共享内存的 bank 映射规则:连续的 4 字节字映射到连续的 bank,即 bs[i][j] 的 bank 编号为 (j * 17 + i) % 32。如果没有 padding,即使用 bs[16][16],那么 bs[i][0]bs[i+1][0] 之间的偏移是 16 个字 = 64 字节,对应 bank 偏移为 16 % 32 = 16。但问题在于,当同一个 warp 的线程同时访问 bs[threadIdx.x][j] 时(其中 threadIdx.x 在 [0, 15] 范围内变化),访问的是同一行中的不同列:

线程 (tid.x) 访问地址
0 bs[0][j] → bank (j*17+0) % 32
1 bs[1][j] → bank (j*17+1) % 32
15 bs[15][j] → bank (j*17+15) % 32

如果使用 bs[16][16],则 j 相同的列上的元素会映射到相同 bank,从而引发冲突。而 +1 padding 让列间偏移变成了 17 × 4 = 68 字节,使得 连续的行映射到连续的 bank,从而避免了冲突。

主循环:分块计算的核心

主循环的结构是一个 do-while

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
do {
// 阶段 1:加载 B 的子块到共享内存
#pragma unroll
for (int i = 0; i < 16; i += 4)
bs[threadIdx.x][threadIdx.y + i] = B[i * ldb];
B += 16;
__syncthreads();

// 阶段 2:从共享内存中读取 B,从全局内存中读取 A,累积计算
#pragma unroll
for (int i = 0; i < 16; i++, A += lda) {
c[0] += A[0] * bs[i][0];
// ... 共 16 个乘加
c[15] += A[0] * bs[i][15];
}
__syncthreads();
} while (B < Blast);

阶段 1:加载 B

1
2
for (int i = 0; i < 16; i += 4)
bs[threadIdx.x][threadIdx.y + i] = B[i * ldb];

这 64 个线程共同协作,将 B 的一个 16 × 64 的子块加载到共享内存中。每个线程负责加载 i += 4 循环的 4 次迭代,共 4 个元素。64 个线程 × 4 个元素 = 256 个浮点数 = 一个完整的 16 × 16 的 B 子块。

每次循环加载一个 16 × 16 的 B 子块,4 次迭代共加载 16 × 64 的 B 子块。

有人可能会问:每次加载的 B[i*ldb] 中的 ldb 表示的是 B 中一行的跨度,而 B 在每次 do-while 循环中会前进 16(B += 16),这里为什么有 B[i*ldb]?因为在 k 维度上的步进是 16 一行一行地在前进,而 i*ldb 是在当前这个 16 × 64 的块中,访问 B 的不同行。

阶段 2:计算

1
2
3
4
5
for (int i = 0; i < 16; i++, A += lda) {
c[0] += A[0] * bs[i][0]; c[1] += A[0] * bs[i][1];
// ...
c[15] += A[0] * bs[i][15];
}

这是整个内核的计算核心。对于每个 i(0 到 15):

  • **A[0]**:从全局内存读取 A 的当前元素,即 A 矩阵的第 col 列的第 row_in_block * 16 + i
  • **bs[i][0..15]**:从共享内存读取 B 子块的对应行

对于固定的 i,bs[i][0..15] 这 16 个数字将被当前 block 中每行(也就是有相同 threadIdx.x 的线程)的 16 个线程共享。但当前线程只需要读取 bs[i][j] 中的 16 个数字,jthreadIdx.x 决定。

每个 do-while 迭代内,内层循环执行 16 次,每次执行 16 个乘加运算(1 个 A[0] 乘以 16 个不同的 bs[i][j]),共 256 次乘加运算。同时,这 256 次运算中:

  • 从全局内存读取了 16 个 A 的元素A[0] 在每次内层循环中被读取)
  • 从共享内存读取了 256 个 B 的元素bs[i][0..15] 在每次内层循环中被读取)

等等,这里有个细节:实际上 A[0] 在 16 次内层循环迭代中是从全局内存读取的,每个 A 的元素被重用了 16 次(对应 bs[i][0..15] 的 16 个值)。反过来,B 的元素通过共享内存被加载到 bs 后,在随后的内层循环中被 block 中具有相同 threadIdx.x 的所有线程(即同一列的所有线程)访问。

共享内存在这里起的关键作用是:将 B 的全局内存访问从非合并的转变为合并的,并且 让同一个 warp 的线程能够共享 B 的数据

数据复用比率

在一个 do-while 迭代中:

  • 从全局内存读取 A:16 个浮点数(A[0] 被连续读 16 次,每次 i 迭代读一次)
  • 从共享内存读取 B:16 × 16 = 256 个浮点数(但已经在阶段 1 加载到共享内存)
  • 从全局内存写 C:0(写回在最后 epilogue 中)

对于全局内存而言,每个 do-while 迭代只读 256 个浮点数(16 个 A + 通过共享内存加载的 256 个 B),却完成了 256 次 FMA,数据复用比率为 1:1

实际上,更精确的计算是:每 256 次 FMA,只从全局内存读取 A 的 16 个元素(B 通过共享内存中转,但最终还是从全局内存读取了 256 个元素到共享内存)。所以每 256 次 FMA,产生了 272 个全局内存读取。然而,B 的子块一旦加载到共享内存,可以被后续循环中的多个线程访问,所以实际的数据复用率远高于表面数字。

写回阶段:alpha 和 beta 的应用

1
2
for (int i = 0; i < 16; i++, C += ldc)
C[0] = alpha * c[i] + beta * C[0];

这是 epilogue 阶段。每个线程将其寄存器中的 16 个部分和 c[0..15] 写回全局内存:

  1. c[i] 乘以 alpha
  2. 加上原来的 C[0] 乘以 beta
  3. 结果写回 C[0]

这是一个 融合的乘加更新(fused multiply-add update),在一次操作中同时完成缩放和累加。

与 Volkov 论文的对应

回顾 Volkov 论文的核心观点,看看这个内核是如何体现的:

论文观点 内核中的体现
最大化计算/访存比 每个线程有 16 个寄存器累加器,FMA 与访存的比值很高
减少同步 每个 do-while 迭代只需要两次 __syncthreads()
合并访存 A 的访问通过 A[0] 是合并的;B 通过共享内存转储再访问
利用寄存器,避免共享内存 16 个 c[i] 全部在寄存器中,没有中间结果的共享内存存储
Bank conflict 避免 bs[16][17] 的 padding 技巧

进一步的思考

这个内核虽然已经非常高效,但它并不是 SGEMM 优化的终点。后续的改进方向包括:

  1. 双缓冲(Double Buffering):使用两份共享内存,让加载和计算可以重叠,消除 __syncthreads() 带来的等待。

  2. 更大的 tile 尺寸:现代 GPU(如 Volta、Ampere、Hopper)有更多的寄存器和更大的共享内存,可以使用 32 × 128、64 × 64 甚至更大的 tile 尺寸。

  3. Warp-level 矩阵分块(WMMA):从 Volta 架构开始,NVIDIA 引入了 Tensor Core 和 warp-level 矩阵乘法指令,可以一个 warp 在一条指令中完成 16 × 16 × 16 的矩阵乘法。

  4. 自动化调优:使用 Auto-Tuning 框架(如 Ansor、AutoTVM)自动搜索最优的 tile 尺寸、循环展开因子等参数。

但无论如何,这个 40 行的内核已经抓住了 SGEMM 优化的灵魂——通过精心设计的线程组织和数据流,最大化计算吞吐量、最小化片外内存访问。理解了它,你就理解了高性能计算中矩阵乘法优化的精髓。

这一篇就到这里,下一篇我们可能会讲 执行配置优化(occupancy、grid size、block size 的选择),也可能是 双缓冲技术 或者 WMMA/Tensor Core 编程。随缘吧。鼓掌👏👏👏


从零入门 cuda 编程?🦴深入理解 sgemmNN 算法!
http://example.com/2026/06/28/cuda-sgemmnn/
Author
LazyPool
Posted on
June 28, 2026
Licensed under