CS336 Resource Accounting Problems

2025/12/06

I have recently been working through Stanford CS336. Since I am not very familiar with this kind of resource accounting, I am writing my calculations down here. The problems below refer to the Transformer architecture used in the course: a token Embedding, a stack of TransformerBlocks (each with RMSNorm, CausalMultiHeadAttention with RotaryPositionalEmbedding, and a SwiGLU feed-forward network), a final RMSNorm, and an output Linear layer.

Problem 1

Consider GPT-2 XL with the following configuration:

vocab_size : 50,257
context_length : 1,024
num_layers : 48
d_model : 1,600
num_heads : 25
d_ff : 6,400

Assume we instantiate the architecture above with this configuration.

1) How many trainable parameters does the model have?

The token Embedding weight has shape $\text{num\_embeddings}\times \text{embedding\_dim}$, so its parameter count is

$$\text{num\_embeddings}\times \text{embedding\_dim}=\text{vocab\_size}\times \text{d\_model}=50,257\times 1,600=80,411,200$$

Each TransformerBlock contains RMSNorm, CausalMultiHeadAttention, RotaryPositionalEmbedding, and SwiGLU, so let us go through these components first. RMSNorm has $\text{d\_model}$ parameters (one learned scale per dimension), i.e.

$$\text{d\_model}=1,600$$

RoPE has no trainable parameters, but its precomputed sin and cos tables still occupy memory during computation; together they hold

$$\text{seq\_len} \times \frac{d_k}{2} \times 2$$

elements, where $d_k=d_v=\text{d\_model}/\text{num\_heads}=64$, so the (maximum) number of buffered elements in RoPE is:

$$\text{seq\_len} \times \frac{d_k}{2} \times 2=\text{context\_length}\times\frac{d_k}{2}\times 2=65,536$$
SwiGLU is implemented as follows:

class SwiGLU(torch.nn.Module):
    def __init__(self, d_model: int, d_ff: int | None = None, device=None, dtype=None):
        super().__init__()
        if d_ff is None:
            d_ff = int(8 * d_model / 3)  # keep d_ff an integer when it is not given explicitly
        self.w1 = Linear(d_model, d_ff, device, dtype)  # gate projection
        self.w2 = Linear(d_ff, d_model, device, dtype)  # down projection
        self.w3 = Linear(d_model, d_ff, device, dtype)  # up projection

    def forward(self, x: torch.Tensor):  # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
        _x = self.w1(x)
        x_silu = _x * torch.sigmoid(_x)  # SiLU(w1(x))
        return self.w2(self.w3(x) * x_silu)

It is easy to see from the code that the total SwiGLU parameter count is:

$$3\times \text{d\_model}\times \text{d\_ff} = 3 \times 1,600 \times 6,400=30,720,000$$

For CausalMultiHeadAttention, the $Q$, $K$, $V$, and $O$ projections are each $\text{d\_model}\times \text{d\_model}$ matrices, where $O$ is the output projection. The total number of trainable parameters is:

$$4\times \text{d\_model}\times \text{d\_model}=4\times 1,600\times 1,600=10,240,000$$

Each TransformerBlock contains 2 RMSNorms, 1 CausalMultiHeadAttention, and 1 SwiGLU, so its total parameter count is:

$$2\times1,600+10,240,000+30,720,000=40,963,200$$

The full Transformer LM consists of 1 Embedding, 1 RotaryPositionalEmbedding (no trainable parameters), 48 TransformerBlocks, 1 final RMSNorm, and 1 output Linear layer of shape $\text{d\_model}\times \text{vocab\_size}$. The total parameter count is therefore:

$$ 1\times 80,411,200 + 48\times 40,963,200 + 1\times 1,600 + 1\times 1,600\times 50,257=2,127,057,600 $$
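As a sanity check, here is a small script that recomputes this number from the configuration; the variable names are just local choices for this post, not identifiers from the course code.

# Recompute the GPT-2 XL parameter count from the configuration above.
vocab_size = 50_257
num_layers = 48
d_model = 1_600
d_ff = 6_400

embedding = vocab_size * d_model          # token embedding
per_block = (
    2 * d_model                           # 2 RMSNorms
    + 4 * d_model * d_model               # Q/K/V/O projections
    + 3 * d_model * d_ff                  # SwiGLU w1/w2/w3
)
final_norm = d_model                      # final RMSNorm
lm_head = d_model * vocab_size            # output Linear

total = embedding + num_layers * per_block + final_norm + lm_head
print(total)  # 2127057600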

2) Assuming every parameter is a single-precision float, how much memory is needed to load the model?

Each single-precision float occupies 4 bytes, so the total memory is:

$$ 4\times 2,127,057,600=8,508,230,400~\text{bytes} \approx 7.92~\text{GiB} $$
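Or, as a quick check in code (assuming the parameter count above):

n_params = 2_127_057_600
n_bytes = 4 * n_params          # float32: 4 bytes per parameter
print(n_bytes)                  # 8508230400 bytes
print(n_bytes / 2**30)          # ≈ 7.92 GiB
print(n_bytes / 1e9)            # ≈ 8.51 GB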

Problem 2

1) Assuming the input has length $\text{context\_length}$, how many FLOPs does one forward pass require? Count only matrix multiplications.

For matrices $A\ (m\times n)$ and $B\ (n\times p)$, the product $AB$ costs $2mnp$ FLOPs: each of the $mp$ output entries needs $n$ multiplications and $n-1$ additions, i.e. $mp(2n-1)\approx 2mnp$. Once we know which matrix multiplications are performed, we know the total FLOPs. As before, we go through the model component by component.

The Embedding layer only selects word vectors by token id; it performs no matrix multiplication.

The RMSNorm forward pass is as follows:

def forward(
    self, x: torch.Tensor
) -> torch.Tensor:  # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
    in_dtype = x.dtype
    x = x.to(torch.float32)  # compute the norm in float32 for numerical stability

    # einops.reduce: mean of x**2 over d_model -> [batch_size, seq_len, 1]
    rms = torch.sqrt(reduce(x**2, "... d_model -> ... 1", "mean") + self.eps)
    result = self.w * x / rms  # elementwise scale, broadcast over d_model

    return result.to(in_dtype)

As we can see, there is no matrix multiplication here (note that self.w * x is an elementwise product, not a matrix multiplication).

For rms = torch.sqrt(reduce(x**2, "... d_model -> ... 1", "mean") + self.eps), there are $\text{batch\_size}\times \text{seq\_len}\times \text{d\_model}$ multiplications (the squares), $\text{batch\_size}\times \text{seq\_len}\times (\text{d\_model}-1)$ additions for the sum plus $\text{batch\_size}\times \text{seq\_len}$ additions of eps, and $\text{batch\_size}\times \text{seq\_len}$ square roots (ignoring the division by d_model inside the mean).

For result = self.w * x / rms, self.w and rms are broadcast to $\text{batch\_size}\times \text{seq\_len}\times \text{d\_model}$ and applied elementwise, so there are $\text{batch\_size}\times \text{seq\_len}\times \text{d\_model}$ multiplications and the same number of divisions.

The total number of operations is:

$$ \begin{aligned} FLOPS_{RMSNorm} =& \text{batch\_size}\times \text{seq\_len}\times \text{d\_model} + \\ & \text{batch\_size}\times \text{seq\_len}\times (\text{d\_model}-1)+ \text{batch\_size}\times \text{seq\_len} + \\ & \text{batch\_size}\times \text{seq\_len} + \\ & 2\times \text{batch\_size}\times \text{seq\_len}\times \text{d\_model} \\ =&4\times \text{batch\_size}\times \text{seq\_len}\times \text{d\_model} + \text{batch\_size}\times \text{seq\_len} \end{aligned} $$

The RotaryPositionalEmbedding forward pass is as follows:

def forward(
    self, x: torch.Tensor, token_positions: torch.Tensor
) -> torch.Tensor:  # ([..., seq_len, d_k], [..., seq_len]) -> [..., seq_len, d_k]
    pos_sin = self.sin[token_positions]  # [..., seq_len, d_k/2]
    pos_cos = self.cos[token_positions]  # [..., seq_len, d_k/2]

    x_even = x[..., 0::2]  # [..., seq_len, d_k/2]
    x_odd = x[..., 1::2]  # [..., seq_len, d_k/2]

    # rotate each (even, odd) pair by its position-dependent angle
    x_even_rot = x_even * pos_cos - x_odd * pos_sin  # [..., seq_len, d_k/2]
    x_odd_rot = x_even * pos_sin + x_odd * pos_cos  # [..., seq_len, d_k/2]

    # stack and re-interleave the rotated pairs back into [..., seq_len, d_k]
    x_rot = rearrange([x_even_rot, x_odd_rot], "two ... -> ... two")
    x_rot = rearrange(x_rot, "... d1 d2 -> ... (d1 d2)")

    return x_rot

Only these two lines involve floating-point operations:

x_even_rot = x_even * pos_cos - x_odd * pos_sin  # [..., seq_len, d_k/2]
x_odd_rot = x_even * pos_sin + x_odd * pos_cos  # [..., seq_len, d_k/2]

Each line performs 2 multiplications and 1 addition (or subtraction) per element of a [..., seq_len, d_k/2] tensor (inside multi-head attention the leading dimensions also include num_heads, which scales this count further; RoPE contributes no matrix multiplications in any case). With $d_k = \text{d\_model} / \text{num\_heads} = 64$, one application of RoPE to a [..., seq_len, d_k] tensor therefore costs

$$ \begin{aligned} FLOPS_{RotaryPositionalEmbedding} &= 2\times\left(2\times \text{batch\_size}\times \text{seq\_len}\times \frac{d_k}{2} + \text{batch\_size}\times \text{seq\_len}\times \frac{d_k}{2}\right) \\ &= 3\times \text{batch\_size}\times \text{seq\_len}\times d_k \\ &= 192\times \text{batch\_size}\times \text{seq\_len} \end{aligned} $$
Next, the SwiGLU feed-forward network. Its forward pass, repeated from above:

def forward(self, x: torch.Tensor):  # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_model]
    _x = self.w1(x)
    x_silu = _x * torch.sigmoid(_x)  # SiLU(w1(x))
    return self.w2(self.w3(x) * x_silu)

The first matrix multiplication, _x = self.w1(x), costs:

$$ \text{batch\_size}\times 2\times \text{seq\_len}\times \text{d\_model}\times \text{d\_ff} $$

The second matrix multiplication, self.w3(x), costs the same:

$$ \text{batch\_size}\times 2\times \text{seq\_len}\times \text{d\_model}\times \text{d\_ff} $$

The third matrix multiplication, inside self.w2(self.w3(x) * x_silu), again costs:

$$ \text{batch\_size}\times 2\times \text{seq\_len}\times \text{d\_model}\times \text{d\_ff} $$

For GPT-2 XL the configuration sets $\text{d\_ff}=6,400=4\times\text{d\_model}$ (rather than the code's default of $\frac{8}{3}\text{d\_model}$), so the total SwiGLU FLOPs are:

$$ 3\times 2\times \text{batch\_size}\times \text{seq\_len}\times \text{d\_model}\times \text{d\_ff} = 24\times \text{batch\_size}\times \text{seq\_len}\times \text{d\_model}^2 $$
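A quick check of these three matrix multiplications at batch_size = 1; matmul_flops is a small helper introduced here for illustration, not part of the assignment code.

def matmul_flops(m: int, n: int, p: int) -> int:
    # FLOPs of multiplying an (m, n) matrix by an (n, p) matrix
    return 2 * m * n * p

seq_len, d_model, d_ff = 1_024, 1_600, 6_400
swiglu = (
    matmul_flops(seq_len, d_model, d_ff)    # w1
    + matmul_flops(seq_len, d_model, d_ff)  # w3
    + matmul_flops(seq_len, d_ff, d_model)  # w2
)
print(swiglu)  # 62914560000 = 24 * seq_len * d_model**2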
scaled_dot_product_attention is implemented as follows:

def scaled_dot_product_attention(
    Q: torch.Tensor,  # (batch_size, ..., seq_len, d_k)
    K: torch.Tensor,  # (batch_size, ..., seq_len, d_k)
    V: torch.Tensor,  # (batch_size, ..., seq_len, d_v)
    mask: torch.Tensor | None = None,  # (seq_len, seq_len)
) -> torch.Tensor:  # (batch_size, ..., seq_len, d_v)
    d_k = Q.shape[-1]
    score_attention = einsum(Q, K, "... seq_len_q d_k, ... seq_len_k d_k -> ... seq_len_q seq_len_k") / math.sqrt(d_k)
    if mask is not None:
        score_attention = torch.where(mask, score_attention, float("-inf"))
    score_attention = softmax(score_attention, dim=-1)  # [... seq_len_q seq_len_k]
    return einsum(score_attention, V, "... seq_len_q seq_len_k, ... seq_len_k d_v -> ... seq_len_q d_v")

The $QK^\top$ matrix multiplication costs, per attention head:

$$ \text{batch\_size}\times 2 \times \text{seq\_len}\times d_k\times \text{seq\_len} $$

Multiplying the attention weights by $V$ costs, again per head:

$$ \text{batch\_size}\times 2 \times \text{seq\_len}\times \text{seq\_len}\times d_v $$

With $d_k=d_v=\text{d\_model}/\text{num\_heads}=64$, the per-head FLOPs of scaled_dot_product_attention are therefore:

$$ \begin{aligned} FLOPS_{\text{scaled\_dot\_product\_attention}}&= \text{batch\_size}\times 4 \times \text{seq\_len}\times \text{seq\_len}\times 64 \\ &= 256\times \text{batch\_size}\times \text{seq\_len}\times \text{seq\_len} \end{aligned} $$
Putting the pieces together, the CausalMultiHeadAttention forward pass is:

def forward(self, x: torch.Tensor) -> torch.Tensor:  # [batch_size, seq_len, d_model]
    q = self.w_q(x)
    k = self.w_k(x)
    v = self.w_v(x)

    q = rearrange(
        q, "batch_size seq_len (num_heads d_k) -> batch_size num_heads seq_len d_k", num_heads=self.num_heads
    )
    k = rearrange(
        k, "batch_size seq_len (num_heads d_k) -> batch_size num_heads seq_len d_k", num_heads=self.num_heads
    )
    v = rearrange(
        v, "batch_size seq_len (num_heads d_v) -> batch_size num_heads seq_len d_v", num_heads=self.num_heads
    )

    if self.token_positions is None:
        self.token_positions = torch.arange(x.shape[-2], device=x.device)
    if self.rope is not None:
        q = self.rope(q, self.token_positions)
        k = self.rope(k, self.token_positions)

    mask = ~torch.triu(torch.ones((x.shape[-2], x.shape[-2]), device=x.device, dtype=torch.bool), diagonal=1)

    y = scaled_dot_product_attention(q, k, v, mask)
    y = rearrange(y, "batch_size num_heads seq_len d_v -> batch_size seq_len (num_heads d_v)")
    y = self.w_o(y)
    return y

As in the parameter count, the $Q$, $K$, $V$, and $O$ projections are each $\text{d\_model}\times \text{d\_model}$, where $O$ is the output projection. The four projection matmuls have identical shapes, for a total of:

$$ 4\times \text{batch\_size}\times 2\times \text{seq\_len}\times \text{d\_model} \times \text{d\_model} $$

scaled_dot_product_attention is called on tensors whose leading dimensions include num_heads, so its per-head cost of $256\times \text{batch\_size}\times \text{seq\_len}^2$ is incurred $\text{num\_heads}=25$ times. Therefore

$$ \begin{aligned} FLOPS_{CausalMultiHeadAttention}=&4\times \text{batch\_size}\times 2\times \text{seq\_len}\times \text{d\_model} \times \text{d\_model} + \\ &\text{num\_heads}\times 256\times \text{batch\_size}\times \text{seq\_len}^2 \\ =& 8\times \text{batch\_size} \times \text{seq\_len} \times \text{d\_model}^2 + 6,400 \times \text{batch\_size}\times \text{seq\_len}^2 \end{aligned} $$
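The same kind of check for a single attention layer at batch_size = 1, again with an illustrative matmul_flops helper:

def matmul_flops(m: int, n: int, p: int) -> int:
    return 2 * m * n * p  # (m, n) @ (n, p)

seq_len, d_model, num_heads = 1_024, 1_600, 25
d_k = d_model // num_heads  # 64

proj = 4 * matmul_flops(seq_len, d_model, d_model)  # Q, K, V, O projections
attn = num_heads * (
    matmul_flops(seq_len, d_k, seq_len)             # Q @ K^T, per head
    + matmul_flops(seq_len, seq_len, d_k)           # weights @ V, per head
)
print(proj + attn)  # 27682406400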
Each TransformerBlock combines attention and the feed-forward network with pre-norm residual connections:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    y = x + self.attn(self.norm1(x))
    y = y + self.ff(self.norm2(y))
    return y

Since only matrix multiplications are counted, the RMSNorms and residual additions contribute nothing here, and

$$ FLOPS_{TransformerBlock}= FLOPS_{CausalMultiHeadAttention} + FLOPS_{SwiGLU} $$
Finally, the full TransformerLM forward pass:

def forward(self, x: torch.Tensor):  # [batch_size, seq_len]
    assert x.shape[-1] <= self.context_length, "Input sequence length is longer than the context length"
    y = self.embedding(x)  # [batch_size, seq_len, d_model]
    for transformer_block in self.transformer_blocks:  # [batch_size, seq_len, d_model]
        y = transformer_block(y)
    y = self.norm(y)  # [batch_size, seq_len, d_model]
    y = self.linear(y)  # [batch_size, seq_len, vocab_size]
    return y

Putting everything together for the full model:

$$ \begin{aligned} FLOPS_{TransformerLM} =& \text{num\_layers}\times FLOPS_{TransformerBlock} + FLOPS_{RMSNorm} + FLOPS_{Linear}\\ =& \text{num\_layers}\times (FLOPS_{CausalMultiHeadAttention} + FLOPS_{SwiGLU}) + FLOPS_{RMSNorm} + FLOPS_{Linear}\\ =& 48\times(8\times \text{batch\_size} \times 1024\times 1600^2+6400\times \text{batch\_size}\times 1024^2+24\times \text{batch\_size}\times 1024\times 1600^2)+\\ &0+\\ &2\times \text{batch\_size}\times 1024\times 50257\times 1600 \\ =& 48\times 90,596,966,400\times \text{batch\_size} + 164,682,137,600\times \text{batch\_size}\\ =&4,513,336,524,800\times \text{batch\_size} \end{aligned} $$
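The whole accounting can be collected into one function. This is a sketch under the stated assumptions (only matrix-multiplication FLOPs are counted; the function name and arguments are mine, not the assignment's):

def transformer_lm_matmul_flops(
    vocab_size: int, seq_len: int, num_layers: int,
    d_model: int, num_heads: int, d_ff: int, batch_size: int = 1,
) -> int:
    # Count 2*m*n*p FLOPs for every (m, n) @ (n, p) matrix multiplication.
    d_k = d_model // num_heads
    proj = 4 * 2 * seq_len * d_model * d_model         # Q/K/V/O projections
    attn = num_heads * (2 * seq_len * d_k * seq_len    # attention scores, per head
                        + 2 * seq_len * seq_len * d_k) # attention values, per head
    ffn = 3 * 2 * seq_len * d_model * d_ff             # SwiGLU w1/w2/w3
    lm_head = 2 * seq_len * d_model * vocab_size       # output Linear
    return batch_size * (num_layers * (proj + attn + ffn) + lm_head)

print(transformer_lm_matmul_flops(50_257, 1_024, 48, 1_600, 25, 6_400))
# 4513336524800, i.e. about 4.5 TFLOPs for one 1,024-token sequence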

Problem 3

1) Which component accounts for the most FLOPs?

The largest share is $FLOPS_{SwiGLU}$, followed by $FLOPS_{CausalMultiHeadAttention}$. Per layer at batch_size = 1, the SwiGLU matrix multiplications cost about 62.9 GFLOPs, versus about 27.7 GFLOPs for attention (projections plus score/value products); the output Linear adds about 164.7 GFLOPs once at the end.

Problem 4

1) For GPT-2 small (12 layers, 768 d_model, 12 heads), GPT-2 medium (24 layers, 1024 d_model, 16 heads), and GPT-2 large (36 layers, 1280 d_model, 20 heads), which variable change has the largest effect on the model's FLOPs?

From the formulas above, the dominant contributions come from the quadratic terms, i.e. $\text{d\_model}$ and $\text{context\_length}$. Since these three models share the same context length, the answer here is $\text{d\_model}$; $\text{num\_layers}$ only scales the per-layer cost linearly.

Problem 5

1) If GPT-2 XL increases its context window to 16,384, how does the total number of FLOPs change?

Plugging the new length into the same formulas: the projection, feed-forward, and output-layer terms are linear in $\text{seq\_len}$ and grow by $16\times$, while the attention score/value terms are quadratic in $\text{seq\_len}$ and grow by $256\times$. The total FLOPs therefore increase sharply, and attention takes up a much larger share of the budget.
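Reusing the transformer_lm_matmul_flops sketch from Problem 2 makes this concrete:

print(transformer_lm_matmul_flops(50_257, 16_384, 48, 1_600, 25, 6_400))
# 149522795724800, i.e. about 1.5e14 matmul FLOPs,
# roughly 33x the cost of the 1,024-token forward pass

The attention score/value terms also go from roughly 7% of the per-layer matrix-multiplication FLOPs at 1,024 tokens to more than half at 16,384 tokens.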