Puzzle 8: Long Softmax
Puzzle 8 computes a softmax over a batch. The problem statement is:
Softmax of a batch of logits.
Uses one program block axis. Block size B0 represents the batch of x of length N0.
Block logit length T. Process it B1 < T elements at a time.
.. math::
z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0
Note softmax needs to be computed in numerically stable form as in Python. In addition in Triton
they recommend not using exp but instead using exp2. You need the identity
.. math::
\exp(x) = 2^{\log_2(e) x}
Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever.
Hint: you will find this identity useful:
.. math::
\exp(x_i - m) = \exp(x_i - m/2 - m/2) = \exp(x_i - m/2) / \exp(m/2)
def softmax_spec(x: Float32[4, 200]) -> Float32[4, 200]:
x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
x_exp = x.exp()
return x_exp / x_exp.sum(1, keepdim=True)
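The exp2 identity above can be sanity-checked in PyTorch; a minimal sketch (not part of the puzzle code):

```python
import torch

LOG2_E = 1.4426950408889634  # log2(e)

x = torch.randn(4, 200)
# exp(x) == 2 ** (log2(e) * x), up to floating-point rounding
assert torch.allclose(torch.exp(x), torch.exp2(LOG2_E * x), atol=1e-6)
```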
This puzzle asks for two solutions: a brute-force one with 3 loops, and a cleverer one with 2 loops. Let's start with the brute-force version.
Brute-force approach
- One loop to find the maximum of each row.
- Subtract the row maximum from every element of the row and exponentiate.
- One loop to accumulate the sum of the exponentiated values.
- One loop to compute the final softmax.
(A PyTorch sketch of this three-pass strategy follows this list.)
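As a reference for the three-pass idea, here is a minimal blocked PyTorch sketch (illustrative only; the function name and block size B1 are my own, not part of the puzzle framework):

```python
import torch

def blocked_softmax_3pass(x: torch.Tensor, B1: int) -> torch.Tensor:
    """Reference for the 3-loop strategy: max pass, sum pass, normalize pass."""
    N0, T = x.shape
    # Pass 1: running max over blocks of B1 columns
    row_max = torch.full((N0, 1), float("-inf"))
    for j in range(0, T, B1):
        row_max = torch.maximum(row_max, x[:, j:j + B1].amax(dim=1, keepdim=True))
    # Pass 2: sum of exp(x - row_max) over blocks
    row_sum = torch.zeros((N0, 1))
    for j in range(0, T, B1):
        row_sum += torch.exp(x[:, j:j + B1] - row_max).sum(dim=1, keepdim=True)
    # Pass 3: write the normalized values block by block
    z = torch.empty_like(x)
    for j in range(0, T, B1):
        z[:, j:j + B1] = torch.exp(x[:, j:j + B1] - row_max) / row_sum
    return z

# x = torch.randn(4, 200)
# assert torch.allclose(blocked_softmax_3pass(x, 32), torch.softmax(x, dim=1), atol=1e-6)
```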
Relevant Triton APIs
- tl.full(shape, value, dtype) initializes a tensor of the given shape filled with value in the given dtype. It can be used to initialize the running maximum to a very small value; it turns out tl.zeros also works here.
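For what it's worth, a minimal sketch of the two accumulator initializations inside a @triton.jit helper (the helper name and the -1e9 sentinel are illustrative choices, not required by the puzzle):

```python
import triton
import triton.language as tl

@triton.jit
def init_row_accumulators(B0: tl.constexpr):
    # Running maximum: start very low so the first tl.maximum keeps real data.
    row_max = tl.full((B0, 1), -1.0e9, dtype=tl.float32)
    # Running sum of exponentials: start at 0.
    row_sum_exp = tl.zeros((B0, 1), dtype=tl.float32)
    return row_max, row_sum_exp
```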
Solution
@triton.jit
def softmax_kernel_brute_force(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """3 loops ver."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0
    # Per-row accumulators; exp_approx (an exp built on tl.exp2) is defined below.
    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)
    # Loop 1: running maximum of each row.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        row_max = tl.maximum(row_max, tl.max(block_value, axis=1, keep_dims=True))
    # Loop 2: sum of exp(x - row_max) for each row.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        block_value -= row_max
        row_sum_exp += tl.sum(exp_approx(block_value), axis=1, keep_dims=True)
    # Loop 3: normalize and store.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        softmax_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, softmax_value, mask_xy)
    return
- The code is a bit verbose, but the core idea is exactly the three loops described above.
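A hypothetical harness to launch the kernel on the 4 x 200 case and compare against torch.softmax (this is not the puzzle's own test runner; it assumes a CUDA device and that exp_approx, defined in the next section, is already in scope):

```python
import torch
import triton

x = torch.randn(4, 200, device="cuda")
z = torch.empty_like(x)
B0, B1 = 1, 32
grid = (triton.cdiv(4, B0),)
softmax_kernel_brute_force[grid](x, z, 4, 200, 200, B0=B0, B1=B1)
assert torch.allclose(z, torch.softmax(x, dim=1), atol=1e-5)
```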
Two-loop approach
- The idea is essentially online softmax:
$$
\begin{aligned}
& m_0 \leftarrow -\infty \\
& d_0 \leftarrow 0 \\
& \text{for } j \leftarrow 1, V \text{ do} \\
& \quad m_j \leftarrow \max(m_{j-1}, x_j) \\
& \quad d_j \leftarrow d_{j-1} \times e^{m_{j-1}-m_j} + e^{x_j-m_j} \quad \text{(update row\_sum\_exp within the loop)} \\
& \text{end for} \\
& \text{for } i \leftarrow 1, V \text{ do} \\
& \quad y_i \leftarrow \frac{e^{x_i-m_V}}{d_V} \\
& \text{end for}
\end{aligned}
$$
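A minimal single-row sketch of this recurrence in plain Python/PyTorch (illustrative, not part of the puzzle code):

```python
import math
import torch

def online_softmax_row(x: torch.Tensor) -> torch.Tensor:
    """Single-row online softmax following the recurrence above."""
    m = float("-inf")  # m_0
    d = 0.0            # d_0
    for xj in x.tolist():
        m_new = max(m, xj)
        # Rescale the old sum to the new max, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(xj - m_new)
        m = m_new
    return torch.exp(x - m) / d

# x = torch.randn(200)
# assert torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-6)
```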
Solution
@triton.jit
def exp_approx(x):
    # exp(x) = 2 ** (log2(e) * x); Triton recommends tl.exp2 over tl.exp.
    return tl.exp2(1.44269504 * x)
@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """2 loops ver."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0
    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)
    # Loop 1: online update of the running max and the running sum of exponentials.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        new_row_max = tl.maximum(tl.max(block_value, axis=1, keep_dims=True), row_max)
        # Rescale the old sum to the new max before adding the new block's terms.
        row_sum_exp = row_sum_exp * exp_approx(row_max - new_row_max) + tl.sum(
            exp_approx(block_value - new_row_max), axis=1, keep_dims=True
        )
        row_max = new_row_max
    # Loop 2: normalize and store.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        z_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, z_value, mask_xy)
    return
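The two-loop kernel can be checked with the same kind of harness as the brute-force one; only the kernel name changes (again hypothetical, not the puzzle's own test runner):

```python
# Reusing x, z, B0, B1, grid from the earlier harness sketch
softmax_kernel[grid](x, z, 4, 200, 200, B0=B0, B1=B1)
assert torch.allclose(z, torch.softmax(x, dim=1), atol=1e-5)
```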
Puzzle 9: Simple FlashAttention
A scalar version of FlashAttention.
Uses zero programs. Block size B0 represents the batches of q to process out of N0. Sequence length is T. Process it B1 < T elements (k, v) at a time for some B1.
.. math::
z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_j \text{ for } i = 1\ldots N_0
This can be done in 1 loop using a similar trick from the last puzzle.
Hint: Use tl.where to mask q dot k to -inf to avoid overflow (NaN).
This is similar to FlashAttention v1: a single pass over the sequence.
The complete FlashAttention v1 recurrence (a scalar Python sketch follows the symbol list below):
$$
\begin{aligned}
x_i &\leftarrow Q[k,:] \cdot K^T[:,i] \\
m_i &\leftarrow \max(m_{i-1}, x_i) \\
d_i' &\leftarrow d_{i-1}' \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \\
O_i' &\leftarrow O_{i-1}' \cdot \frac{d_{i-1}'}{d_i'} \cdot e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} \cdot V[i,:]
\end{aligned}
$$
Final output:
$$
O[k,:] \leftarrow O_N'
$$
where:
- $Q[k,:]$ is the $k$-th row vector of $Q$.
- $K^T[:,i]$ is the $i$-th column vector of $K^T$.
- $x_i$ is the pre-softmax logit.
- $m_i$ is the running maximum.
- $d_i'$ is the running sum of exponentials.
- $O_i'$ is the running partial output.
- $V[i,:]$ is the $i$-th row vector of $V$.
- $O[k,:]$ is the $k$-th row vector of the output matrix.
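A minimal scalar Python sketch of this recurrence for a single query (illustrative only; the function name and the commented check are mine, not the puzzle's test code):

```python
import math
import torch

def flash_attention_scalar(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """One-pass recurrence for a single scalar query q against all (k_j, v_j)."""
    m = float("-inf")  # running max of the logits
    d = 0.0            # running sum of exponentials
    o = 0.0            # running (already normalized) output
    for kj, vj in zip(k.tolist(), v.tolist()):
        xj = q.item() * kj  # logit q * k_j
        m_new = max(m, xj)
        d_new = d * math.exp(m - m_new) + math.exp(xj - m_new)
        # Rescale the old output to the new max and normalizer, then add the new term.
        o = o * (d / d_new) * math.exp(m - m_new) + math.exp(xj - m_new) / d_new * vj
        m, d = m_new, d_new
    return torch.tensor(o)

# q, k, v = torch.randn(()), torch.randn(200), torch.randn(200)
# ref = (torch.softmax(q * k, dim=0) * v).sum()
# assert torch.allclose(flash_attention_scalar(q, k, v), ref, atol=1e-5)
```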
Solution
@triton.jit
def myexp(x):
return tl.exp2(1.44269504 * x)
@triton.jit
def flashatt_kernel(
q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
block_id_i = tl.program_id(0)
    off_i = block_id_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0
    # Use a large finite value in place of infinity to avoid inf - inf -> NaN.
    inf = 1.0e6
    # No `other` needed here: lanes with mask_i == False are never stored,
    # and their logits are masked out below via tl.where.
    q = tl.load(q_ptr + off_i, mask=mask_i)
    # The variable names from Triton's official FlashAttention tutorial are
    # noted in comments for reference; our own names follow Puzzle 8.
# l_i
exp_sum = tl.zeros((B0,), dtype=tl.float32)
# m_i
qk_max = tl.full((B0,), -inf, dtype=tl.float32)
z = tl.zeros((B0,), dtype=tl.float32)
for id_j in tl.range(0, T, B1):
off_j = id_j + tl.arange(0, B1)
mask_j = off_j < T
mask_ij = mask_i[:, None] & mask_j[None, :]
        # `other=0.0` keeps out-of-range k lanes harmless; they are also masked below.
        k = tl.load(k_ptr + off_j, mask=mask_j, other=0.0)
        # Mask invalid (i, j) pairs to a very negative logit so they vanish after exp.
        qk = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -inf)
# m_ij
new_max = tl.maximum(tl.max(qk, axis=1), qk_max)
qk_exp = myexp(qk - new_max[:, None])
# alpha
factor = myexp(qk_max - new_max)
# l_ij
new_exp_sum = exp_sum * factor + tl.sum(qk_exp, axis=1)
v = tl.load(v_ptr + off_j, mask=mask_j, other=0.0)
z = z * factor + tl.sum(qk_exp * v[None, :], axis=1)
qk_max = new_max
exp_sum = new_exp_sum
z = z / exp_sum
tl.store(z_ptr + off_i, z, mask=mask_i)
return
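A hypothetical harness to check the kernel against a plain PyTorch reference (flashatt_ref mirrors, but is not necessarily identical to, the puzzle's own spec; assumes a CUDA device):

```python
import torch
import triton

def flashatt_ref(q, k, v):
    # softmax over q_i * k_j along j, then a weighted sum of v_j
    attn = torch.softmax(q[:, None] * k[None, :], dim=1)
    return (attn * v[None, :]).sum(dim=1)

T = 200
q = torch.randn(T, device="cuda")
k = torch.randn(T, device="cuda")
v = torch.randn(T, device="cuda")
z = torch.empty(T, device="cuda")
B0, B1 = 64, 32
flashatt_kernel[(triton.cdiv(T, B0),)](q, k, v, z, T, T, B0=B0, B1=B1)
assert torch.allclose(z, flashatt_ref(q, k, v), atol=1e-4)
```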
References
- Online softmax
- FlashAttention
