为什么 Attention 是 O(n²),能不能降?
一、为什么 Attention 天生是 O(n²)?
先看一句话本质:
**Attention 在做的事情是:
序列中“每一个 token”都要和“所有 token”算一次相关性。**
1️⃣ 数学来源(真正的根因)
标准 Self-Attention:
关键在这一步:
👉 n 个 query × n 个 key = n² 个相似度
不管你用什么硬件,这一步不可避免。
2️⃣ 直观理解
-
每个词都要问一句:
“我和序列里的每一个词像不像?”
-
一共 n 个词
-
每个问 n 次
👉 所以是 n² 次比较
📌 面试金句:
Attention 是“全连接图”,不是“链”或“网格”。
二、那这个 n² 是不是“必须的”?
答案是:理论上不是,但语义上是
解释这句话👇
1️⃣ 如果你要「全局任意依赖」
- 任意 token ↔ 任意 token
- 那就是一个完全图
- 完全图的边数就是 n²
👉 全局 Attention ≈ 完全图
2️⃣ CNN / RNN 为什么不是 n²?
- CNN:局部连接(稀疏图)
- RNN:链式连接(n 条边)
👉 你牺牲的是 全局感知能力
三、能不能把 O(n²) 降下来?
能,而且已经有一整代模型在干这件事。
我按思想流派给你分类(不是简单列名字)。
四、降复杂度的 4 大思路(重点)
方案一:限制“谁能看谁”(稀疏 Attention)
思路 :
不让每个 token 看所有 token。
代表:
- Longformer(滑动窗口)
- Sparse Transformer
- BigBird(局部 + 全局 + 随机)
✅ 适合:长文档
❌ 缺点:可能错过远程关键信息
📌 金句:
稀疏 Attention 是在“赌”重要信息是局部的。
方案二:低秩近似(矩阵数学)
思路 :
n×n 的 Attention 矩阵其实是低秩的
代表:
- Linformer
- Nyströmformer
✅ 复杂度线性
❌ 近似误差,rank 选不好就翻车
方案三:核函数技巧(最聪明的一类)
思路 :
把 softmax Attention 变成 可交换的形式
先算 K、V,再算 Q
👉 避免 n×n 矩阵
代表:
- Performer(FAVOR+)
- Linear Attention
复杂度:
✅ 理论优雅
❌ 数值稳定性、精度问题
📌 面试加分点:
这是目前“理论上最漂亮”的降复杂度方案。
方案四:工程优化(不降 n²,但跑更快)
注意 :这是面试陷阱点。
FlashAttention
- 不改变复杂度
- 改变 内存访问方式
- 避免显存 IO 瓶颈
👉 结果:
- 同样 n²
- 但 快 2–4 倍
📌 金句:
FlashAttention 优化的是“显存”,不是“算法阶数”。
五、那 GPT / LLM 用了哪种?
现实答案(非常真实)👇
| 技术 | 是否降复杂度 |
|---|---|
| FlashAttention | ❌ |
| Sliding Window | ✅(局部) |
| KV Cache(推理) | ✅(时间维度) |
| 真正线性 Attention | ❌(主流) |
👉 大模型训练阶段仍然是 O(n²)
👉 靠工程和硬件硬扛
六、终极总结
Attention 是 O(n²),因为它在建一个“全连接的语义图”。
想降复杂度,本质只有三条路:
1)让图变稀疏
2)承认它是低秩
3)把 softmax 拆掉
代价是:表达能力、稳定性或精度。
标题:为什么 Attention 是 O(n²),能不能降?
作者:guobing
地址:http://guobingwei.tech/articles/2025/12/26/1766737132447.html