下面给你一段最小可读的 PyTorch 风格实现,直观展示 N(序列长度)与 d = head_dim × num_heads 在张量维
一段最小可读的 PyTorch 风格实现,直观展示 N(序列长度)与 d = head_dim × num_heads 在张量维度里的“落地位置”。我把关键形状都写在注释里
下面给你一段最小可读的 PyTorch 风格实现,直观展示 N(序列长度)与 d = head_dim × num_heads 在张量维度里的“落地位置”。我把关键形状都写在注释里。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------- 1) Scaled Dot-Product Attention(单头示意) ----------
def scaled_dot_product_attention(Q, K, V):
"""
Q: [B, N, d_k]
K: [B, N, d_k]
V: [B, N, d_v] —— 通常 d_v == d_k
"""
# scores = Q · K^T / sqrt(d_k)
# 形状:[B, N, N] —— 注意这里产生了 "N×N" 的注意力矩阵
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn = F.softmax(scores, dim=-1) # [B, N, N]
out = torch.matmul(attn, V) # [B, N, d_v]
return out, attn # attn 按元素数量 ~ O(B·N²)
# ---------- 2) Multi-Head Attention(标准结构) ----------
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能整除 num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads # <- 这就是 d_k
# d = head_dim × num_heads = d_model
# 线性投影到 Q/K/V 空间
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x):
"""
x: [B, N, d_model]
其中:
B = batch_size
N = seq_len(token 数)
d_model = num_heads × head_dim
"""
B, N, _ = x.shape
H, D = self.num_heads, self.head_dim # H×D = d_model
# 线性映射后 reshape 成多头形状
# [B, N, d_model] -> [B, H, N, D]
Q = self.W_q(x).view(B, N, H, D).transpose(1, 2)
K = self.W_k(x).view(B, N, H, D).transpose(1, 2)
V = self.W_v(x).view(B, N, H, D).transpose(1, 2)
# 注意力分数:Q·K^T / sqrt(D)
# 形状:[B, H, N, N] —— 这里就是 O(N^2) 的来源
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(D)
attn = scores.softmax(dim=-1) # [B, H, N, N]
# 乘以 V,得到各头输出
heads = torch.matmul(attn, V) # [B, H, N, D]
# 合并头: [B, H, N, D] -> [B, N, H·D=d_model]
out = heads.transpose(1, 2).contiguous().view(B, N, H * D)
out = self.W_o(out) # [B, N, d_model]
return out
关键点速记
- d_model = num_heads × head_dim 是源码里必然出现且必须满足整除关系的结构常量;它决定了投影权重矩阵的尺寸与张量 reshape 的形状。
- 注意力权重矩阵的形状是 [B, num_heads, N, N],元素数量随 N² 增长;因此计算与显存都会在长序列时急剧上升。
- 计算复杂度近似 O(N² × d_model): N² 来自 QK^T; 对每个头计算,再在头内用 head_dim (D) 做乘加,所有头相加等价于与 d_model 线性相关。
拓展(可选)
- FlashAttention 等优化能明显降低显存占用与常数因子(通过块化/在线 softmax),但算术交互仍旧覆盖 N×N,渐近复杂度仍是 O(N²)。
- 只有稀疏/近似注意力或图结构检索式路由(如你提出的 WAO/SKG 思路,或部分 linear/performer/LSH/路由注意力)才有机会把有效复杂度降到近似 O(N log N) 或更低。
如果你愿意,我可以把这段代码做成一个最小可运行示例(含随机输入与显存/时间打印),方便你在本地快速对比 N、d_model 的伸缩效应。
Comments (0)
No comments