为什么 Attention 是 O(n²),能不能降?

  |   0 评论   |   0 浏览

一、为什么 Attention 天生是 O(n²)?

先看一句话本质:

**Attention 在做的事情是:

序列中“每一个 token”都要和“所有 token”算一次相关性。**


1️⃣ 数学来源(真正的根因)

标准 Self-Attention:

Q ∈ R^(n×d) K ∈ R^(n×d) V ∈ R^(n×d) Attention(Q,K,V) = softmax(QKᵀ / √d) V

关键在这一步:

QKᵀ → (n×d) × (d×n) = (n×n)

👉 n 个 query × n 个 key = n² 个相似度

不管你用什么硬件,这一步不可避免。


2️⃣ 直观理解

  • 每个词都要问一句:

    “我和序列里的每一个词像不像?”

  • 一共 n 个词

  • 每个问 n 次

👉 所以是 n² 次比较

📌 面试金句:

Attention 是“全连接图”,不是“链”或“网格”。


二、那这个 n² 是不是“必须的”?

答案是:理论上不是,但语义上是

解释这句话👇

1️⃣ 如果你要「全局任意依赖」

  • 任意 token ↔ 任意 token
  • 那就是一个完全图
  • 完全图的边数就是

👉 全局 Attention ≈ 完全图


2️⃣ CNN / RNN 为什么不是 n²?

  • CNN:局部连接(稀疏图)
  • RNN:链式连接(n 条边)

👉 你牺牲的是 全局感知能力


三、能不能把 O(n²) 降下来?

能,而且已经有一整代模型在干这件事。

我按思想流派给你分类(不是简单列名字)。


四、降复杂度的 4 大思路(重点)


方案一:限制“谁能看谁”(稀疏 Attention)

思路

不让每个 token 看所有 token。

每个 token 只看 k 个 复杂度:O(nk)

代表:

  • Longformer(滑动窗口)
  • Sparse Transformer
  • BigBird(局部 + 全局 + 随机)

✅ 适合:长文档

❌ 缺点:可能错过远程关键信息

📌 金句:

稀疏 Attention 是在“赌”重要信息是局部的。


方案二:低秩近似(矩阵数学)

思路

n×n 的 Attention 矩阵其实是低秩的

QKᵀ ≈ Q(EK)ᵀ 复杂度:O(nr)

代表:

  • Linformer
  • Nyströmformer

✅ 复杂度线性

❌ 近似误差,rank 选不好就翻车


方案三:核函数技巧(最聪明的一类)

思路

把 softmax Attention 变成 可交换的形式

softmax(QKᵀ)V ≈ φ(Q) [φ(K)ᵀ V]

先算 K、V,再算 Q

👉 避免 n×n 矩阵

代表:

  • Performer(FAVOR+)
  • Linear Attention

复杂度:

O(n d²)

✅ 理论优雅

❌ 数值稳定性、精度问题

📌 面试加分点:

这是目前“理论上最漂亮”的降复杂度方案。


方案四:工程优化(不降 n²,但跑更快)

注意 :这是面试陷阱点。

FlashAttention

  • 不改变复杂度
  • 改变 内存访问方式
  • 避免显存 IO 瓶颈

👉 结果:

  • 同样 n²
  • 快 2–4 倍

📌 金句:

FlashAttention 优化的是“显存”,不是“算法阶数”。


五、那 GPT / LLM 用了哪种?

现实答案(非常真实)👇

技术是否降复杂度
FlashAttention
Sliding Window✅(局部)
KV Cache(推理)✅(时间维度)
真正线性 Attention❌(主流)

👉 大模型训练阶段仍然是 O(n²)

👉 靠工程和硬件硬扛


六、终极总结

Attention 是 O(n²),因为它在建一个“全连接的语义图”。

想降复杂度,本质只有三条路:

1)让图变稀疏

2)承认它是低秩

3)把 softmax 拆掉

代价是:表达能力、稳定性或精度。


标题:为什么 Attention 是 O(n²),能不能降?
作者:guobing
地址:http://guobingwei.tech/articles/2025/12/26/1766737132447.html