DashAttention:可微分与自适应稀疏分层注意力

arXiv: 2605.18753v1

论文信息

标题: DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention

作者: Yuxiang Huang, Nuno M. T. Gonçalves, Federico Alvetreti, et al.

发布日期: 2026-05-18

arXiv ID: 2605.18753v1

PDF 链接: 下载 PDF

背景与动机

在大语言模型(LLM)处理长达数万令牌的上下文时,标准 softmax 注意力会为每一个键值对分配非零权重,导致信息分散(dispersion),计算开销急剧增长。近年来,层次稀疏注意力成为缓解这一问题的核心思路,其代表方法如 NSA 和 InfLLMv2 通常采用“粗选—精算”两阶段策略:先将上下文分割为块,通过压缩表示计算块级相关性分数,再用 top‑k 选出固定数量的块,最后在这些块内执行细粒度的 softmax 注意力。

然而,固定预算的 top‑k 操作存在三个根本缺陷。第一,它假定每个查询所需的令牌数量是恒定的,忽视了不同查询、不同注意力头对信息量的实际需求差异。第二,top‑k 本身不可微,切断了从最终损失到块选择阶段的梯度通路,使块路由参数无法端到端训练,往往需要引入额外的辅助损失或舍弃对摘要层的优化。第三,基于 softmax 的块分数聚合在多查询头分组注意力(GQA)中会重新引入分散性,因为 softmax 输出的密集概率在进行头聚合时仍会向整个上下文蔓延,削弱了稀疏机制在长程任务中的集中优势。

论文提出的 DashAttention,全称 Differentiable and Adaptive Sparse Hierarchical Attention,正是要打破以上局限,构建一个全可微查询自适应的稀疏层次注意力机制。其核心理念是:将粗粒度路由从硬性的 top‑k 替换为 α‑entmax 变换,让模型根据当前查询动态决定关注哪些块,并使路由选择保持可微,从而让梯度可以一路贯穿至块摘要层,实现彻底的端到端学习。

方法剖析

DashAttention 包含三个阶段,分别完成块摘要、稀疏路由与精细化注意力计算。

阶段 0:局部块摘要

与传统方法采用平均池化或 MLP 不同,DashAttention 引入了一种可学习的局部注意力摘要头。具体而言,为每一个键值头设置一个可学习的摘要查询向量 qˉ\mathbf{\bar{q}},初始化为零。对于每个分块 cc,块摘要通过对块内键向量执行局部缩放点积注意力得到:

kˉc(r)=tCcexp(qˉ(r),kt(r)/dh)uCcexp(qˉ(r),ku(r)/dh)kt(r).\mathbf{\bar{k}}_c^{(r)} = \sum_{t \in C_c} \frac{\exp(\langle \mathbf{\bar{q}}^{(r)}, \mathbf{k}_t^{(r)} \rangle / \sqrt{d_h})}{\sum_{u \in C_c} \exp(\langle \mathbf{\bar{q}}^{(r)}, \mathbf{k}_u^{(r)} \rangle / \sqrt{d_h})} \mathbf{k}_t^{(r)}.

由于初始时 qˉ\mathbf{\bar{q}} 为零,该机制会自然退化为均匀平均池化,使得从全注意力预训练模型向稀疏模型的过渡十分平滑。随着训练推进,模型逐渐学会根据内容对块内的键进行加权,生成更具表达力的摘要。

阶段 1:α‑entmax 块路由

这是 DashAttention 区别于其他方法的关键。对于每个查询头 hh 与对应的键值头 rr,计算块级相关性分数 zˉi,c=qi(h),kˉc(r)/dh\bar{z}_{i,c} = \langle \mathbf{q}_i^{(h)}, \mathbf{\bar{k}}_c^{(r)} \rangle / \sqrt{d_h},然后通过 α‑entmax 变换得到稀疏的块概率分布:

w^i(h)=α-entmax ⁣(γzˉi).\mathbf{\hat{w}}_i^{(h)} = \alpha\text{-entmax}\!\left(\gamma \bar{\mathbf{z}}_i\right).

α‑entmax 是广义的稀疏注意力映射,当 α>1\alpha > 1 时可以产生恰好为零的概率,其实际支撑集(非零块)由分数本身的几何关系决定,而非预先指定的 kk。这意味着信息密集的查询可能分配到更多块,而针对性强的查询可能只保留极少量块,实现了真正的动态稀疏分配

在 GQA 场景下,需要保持同一查询组内各头的支撑一致。DashAttention 采用 α‑entmax 概率的直接平均,再根据联合支撑生成统一的块掩码,从而确保既满足硬件对齐要求,又避免 softmax 头聚合带来的额外分散性。

阶段 2:先验诱导稀疏 softmax

被选中的块需要展开回令牌分辨率,进行精细注意力计算。为避免重新引入密集 softmax 的分散性,同时保留端到端可微,DashAttention 从 softmax 的变分形式出发,将 KL 正则项中的均匀先验替换为由阶段 1 的路由分布构造的先验 g\mathbf{g}。具体地,给定块级分数 w\mathbf{w},令先验 g\mathbf{g} 在被选块上正比于 w1/σ\mathbf{w}^{1/\sigma},在对角块(当前块附近)上分配剩余概率。该先验通过一个偏置项 d\mathbf{d} 优雅地融入到 softmax 计算中:

oi= ⁣ ⁣ ⁣jRiDiexp(zi,j+di,j)texp(zi,t+di,t)vj,o_i = \!\!\!\sum_{j \in \mathcal{R}_i \cup \mathcal{D}_i} \frac{\exp(z_{i,j} + d_{i,j})}{\sum_{t} \exp(z_{i,t} + d_{i,t})} \mathbf{v}_j,

其中 di,jd_{i,j} 等于 (logwi,c(j)μi)/σ(\log w_{i,c(j)} - \mu_i)/\sigma 在选定块上,对角块上则为零。这一设计使得阶段 2 天然继承阶段 1 的稀疏性和可微性,且直接兼容 FlashAttention 等高效核函数。论文还证明了该偏置等价于引入以 gg 为先验的 KL 正则化,从而保持了严格的概率解释。

GPU 感知的实现

DashAttention 的三个阶段分别对应三个高度融合的 Triton 核。阶段 0 利用在线 softmax 完成块摘要,无额外 HBM 往返。阶段 1 将块分数与自适应阈值求解保持在寄存器中,并利用 AdaSplash‑2 的快速算法;掩码按 int32 位打包存储。阶段 2 则是在激活块上带偏置执行 FlashAttention,回避了对稀疏索引的二次扫描。整个实现兼顾了预填充和解码两种推断模式,能够充分发挥现代 GPU 的并行能力。

创新点与理论贡献

  1. 全可微的层次稀疏注意力:首次用 α‑entmax 替换 top‑k 块路由,重建了从损失函数到块摘要层的完整梯度链路,使块选择从离散硬性变为连续稀疏映射,训练更稳定且无需辅助损失。

  2. 自适应动态稀疏:放弃了固定的块预算,模型可以在不同层、不同头、甚至不同查询上自适应地分配稀疏性,实验表明早期层通常更密集、中间层更稀疏,自动涌现出类似预算分配策略的效果。

  3. 非分散性理论保证:论文从理论上证明,在 GQA 头聚合时,α‑entmax 头聚合保持非分散性,而 softmax 头聚合会导致分散性重新出现。这解释了 DashAttention 在长程检索任务(如 RULER 的多密钥查找)上对 NSA、InfLLMv2 的显著优势。

  4. 先验诱导的软性注意力整合:通过 KL 正则化形式将块级先验注入令牌级软性 softmax,使得稀疏粗选与细粒度精算融为一体,同时用 σ 参数灵活控制先验强度,便于与已有的全注意力推理流程兼容。

实验结果分析

实验选取了 MiniCPM‑4 的 1B、3B 和 8B 模型,在长上下文持续预训练后进行评估。在 16K 上下文长度下,DashAttention 在所有模型规模上均达到了与全注意力相当的精度,且稀疏度约 75% —— 比 NSA 和 InfLLMv2 略高的稀疏度下获得了更好的平均分数。尤其在 RULER 的多密钥检索任务(MK1–MK3)上,DashAttention 显著超越了两种基线,例如 8B 模型在 MK2 上达到 96%,而 NSA 仅为 34%,InfLLMv2 为 82%,充分展示了动态路由在需要精确定位信息时的优势。

在 HELMET 综合评测中,DashAttention 在召回、ICL、重排序等子任务上同样整体领先,且能在测试时切换为标准的全注意力 softmax,实现即插即用,无需重新训练。效率基准测试显示,DashAttention 相较 FlashAttention‑3 的预填充最高加速 3.1 倍,解码最高加速 3.36 倍;在同等稀疏度下也比 InfLLMv2 更快,尤其在解码场景中得益于掩码融合遍历避免了显式的 top‑k 和索引实例化。

成本效益的帕累托前沿分析进一步表明,随着稀疏度增加到 90%,DashAttention 仍能维持 39.4% 的 HELMET 综合准确率,而 InfLLMv2 下降至 30%,NSA 仅有 20%,凸显了其在极端压缩下的鲁棒性。

实践应用与未来方向

对于实际部署,DashAttention 提供了一个低摩擦的升级路径:可直接基于已有的全注意力预训练模型进行长上下文微调,无需新增大量参数;在推理时,可选择稀疏模式加速,亦可无缝切换至标准 softmax 以复用 vLLM、SGLang 等成熟服务框架的优化。开发者在集成时,需注意调整 α 和温度参数 γ 以平衡稀疏度与精度,通常采用 α=1.5、σ 很大(如 10810^8)的工作点,既能保留可微性又避免先验过强导致性能下降。

未来的优化工作可以从几个方向展开:将 DashAttention 正式集成到 vLLM 和 SGLang 框架中,实现端到端的长上下文服务加速;进一步探索在混合架构(如 Mamba‑Transformer)中的应用,以及将这种可微路由思想扩展到令牌级别的细粒度稀疏注意中;此外,引入自动学习 α 或动态调整 γ 可能进一步提升模型在不同任务下的适应能力。

总结与展望

DashAttention 通过 α‑entmax 驱动的层次注意力设计,成功解决了传统层次稀疏注意中固定预算和不可微两大核心痛点,实现了真正查询自适应的、全链路可微的稀疏注意机制。在多个长上下文评估集上,它不仅追平全注意力性能,更在极端稀疏场景下表现出色,且速度大幅领先同类方法。这项工作重新定义了层次稀疏注意的设计范式,为低成本、高效率的长上下文建模提供了可靠蓝图。随着对稀疏性理论理解的深化和硬件协同设计的推进,可微动态稀疏有望成为下一代大模型上下文扩展的基础组件。