跳转至

L. 稀疏注意力

分数:200 分

任务背景

注意力机制是大模型中的关键组件,在理解和生成任务中起着重要作用。然而,传统的平方级全注意力机制在处理长序列时面临计算复杂度高和内存消耗大的问题。在 LLM 处理超长序列的趋势下,为了解决这些问题,我们仍需要通过稀疏化来进一步降低计算开销,即通过仅关注输入序列中的一部分元素来降低计算复杂度。特别地,本题中我们将考虑一种基于分块的 Top-K 稀疏注意力机制。

本题要求在 Ampere 和 Hopper GPU 架构上,使用 Triton 或 TileLang 这两种 Tile-based DSL 之一实现一个高性能的分块 Top-K 稀疏注意力算子,如果算子能同时通过 Ascend 的测例、性能超过 Torch 实现 40%,可以拿到另外 100 分的 bonus 。

形式化定义

全注意力

给出三个矩阵:查询矩阵 Q \in \mathbb{R}^{N \times d},键矩阵 K \in \mathbb{R}^{M \times d},值矩阵 V \in \mathbb{R}^{M \times d},其中 N 是查询序列长度,M 是键值序列长度,d 是嵌入维度。

注意力分数矩阵 A \in \mathbb{R}^{N \times M} 通过以下公式计算:

A = \text{softmax}\left(\frac{Q K^T}{\sqrt{d}}\right)

其中 softmax 函数在每一行上应用,

\text{softmax}(x_{ij}) = \frac{e^{x_{ij}}}{\sum_k e^{x_{ik}}}

输出矩阵 O \in \mathbb{R}^{N \times d} 为注意力分数矩阵与值矩阵的乘积:O = A V

在实际应用中,为了保持因果性,通常会对注意力分数矩阵进行掩码处理,使得查询位置只能关注其之前的位置。

具体而言,掩码矩阵 Mask \in \mathbb{R}^{N \times M} 定义为:

Mask_{ij} = \begin{cases} 0, & \text{if } j \leq i \\ -\inf, & \text{otherwise} \end{cases}

应用掩码后,注意力分数矩阵的计算变为:

A = \text{softmax}\left(\frac{Q K^T}{\sqrt{d}} + Mask\right)

稀疏注意力

容易看出,全注意力的计算复杂度为 O(N \times M \times d),当 NM 较大时,计算和内存开销显著增加。为了解决这个问题,稀疏注意力机制通过限制每个查询位置只能关注部分键位置。

为了降低复杂度,我们将 M 维度的键值序列划分为大小为 bs 的块。对于每个查询位置 i,其仅关注由索引矩阵 Index 指定的 top_K 个块,其中每一行包含对应查询位置所关注的键值块的索引。掩码条件演变为:

Mask_{ij} = \begin{cases} 0, & \text{if } \lfloor \frac{j}{bs} \rfloor \in Index[i,:] \text{ and } j \leq i \\ -\inf, & \text{otherwise} \end{cases}

多头注意力和分组注意力

现代大模型通常采用多头注意力机制(Multi-Head Attention, MHA),将查询、键、值矩阵分为多个头进行并行计算。假设有 h 个头,每个头的维度为 d_h = d / h

对于每个头 i,我们有对应的查询、键、值矩阵 Q_i, K_i, V_i。每个头独立计算注意力输出 O_i,最后将所有头的输出拼接起来:

O = \text{Concat}(O_1, O_2, \ldots, O_h)

此外,现代大模型中还引入了分组注意力机制(GQA)。在分组注意力中,多个查询头形成一组,共享同一组键值头(和稀疏索引)。我们假设有 h_q 个查询头和 h_k 个键值头,且查询头数量 h_q 是键值头数量 h_k 的整数倍 g,即 h_q = g \times h_k

实践中,引入这一机制在保持模型能力的同时,为推理带来了内存效率的提升。对于本题的稀疏注意力场景,分组大小也包含了工程上的考虑,这里我们可以考虑较大的 GQA 分组大小的情况。

最后,我们引入批次维度 B,不同批次的数据计算完全并行与独立。

输入与输出

输入包含:

  • 查询张量 Q:形状 [B, h_q, N, d_h],类型float16
  • 键张量 K:形状 [B, h_k, M, d_h],类型float16
  • 值张量 V:形状 [B, h_k, M, d_h],类型float16
  • 稀疏索引张量 Index:形状 [B, h_k, N, top_K],类型int32
  • 超参数:
    • 块大小 block_size:整数,表示键值序列划分的块大小 bs
    • 缩放因子 sm_scale:浮点数,表示注意力分数的缩放因子 \frac{1}{\sqrt{d_h}}

Index 中的索引值可能为 -1,用于填充关注的键值块数量不足 top_K 的情况。

输出一个张量:

  • 输出张量 O:形状 [B, h_q, N, d_h],类型float16

你需要在 solution.py 中实现以下接口:

def sparse_attention(
    q: torch.Tensor,  # [B, H_Q, N, D_H]
    k: torch.Tensor,  # [B, H_K, M, D_H]
    v: torch.Tensor,  # [B, H_K, M, D_H]
    index: torch.Tensor,  # [B, H_K, N, TOP_K]
    block_size: int,
    sm_scale: float,
) -> torch.Tensor:  # returns [B, H_Q, N, D_H]

测试框架会调用你的实现,在预热一定轮次后进行结果验证和性能测量(考虑到首次 JIT 编译开销等)。

Handout 说明

下发的 Handout 包含以下文件:

  • benchmark.py: 主评测脚本。
  • solution.py: 你的工作区。你需要在此文件中实现sparse_attention 函数。目前文件中有 pytorch 实现的 ref
  • requirements.txt: 依赖库列表。

运行评测

使用 benchmark.py 来测试正确性和性能:

# 小规模测试 (用于快速校验正确性)
python benchmark.py --size small
# 完整性能测试
python benchmark.py --size all

可选参数:

  • --size: 选择测试规模 (small,medium,large,all)。
  • --no-check: 跳过正确性校验(仅测速)。

提交要求

你需要提交 solution.py文件。评测系统将替换环境中的solution.py并运行benchmark.py 。请确保你的实现自包含或只引用了标准库/Triton/TileLang 。

测试点说明

测试点将覆盖不同的参数组合。主要评测场景如下:

场景 Batch H_Q H_K N M D_H Top_K Block Size
Small 1 16 1 128 128 64 4 32
Medium 2 64 4 1024 1024 64 16 64
Large-1 4 64 4 4096 4096 128 32 64
Large-2 8 64 4 8192 8192 128 32 64

注意:

  • 只有通过正确性检查的实现才能获得性能分。

选手测试环境说明

本次比赛平台没有Hopper卡,选手可以使用我们分发的Autodl子账号做题,提交Z-1题可以获得账号。每个人有60元额度。因正常使用导致的余额不够可以找我们再加。不得用于比赛之外的用途。

评分方法

  1. 正确性评分会运行在 NVIDIA H20-96G、 NVIDIA A100和华为 Ascend 910B-2上,必须全部通过
  2. 性能评分会运行在 NVIDIA H20-96G 上。记第一名的成绩为X,你的得分为Y,那本题你的排行榜得分为 Y/X * 200
  3. 对于华为 Ascend 910B-2,如果性能 Torch 实现 40% 以上,可以获得组委会定制晶圆钥匙扣一个

附件