H. Posit GEMM¶
分数:200分
用 int8 实现 fp64 乘法,老黄刀法能耐我何?
背景¶
【提醒:附件 Tab 中有 主程序和数据生成器 】 【1.20提醒:附件已更新,请重新下载】
Posit (Type III Unum) 是一种新的浮点数格式,对于同样的位数,Posit 和 IEEE754 浮点数具有类似的动态范围,同时可以更高精度表示接近 1 的数,没有 IEEE754 的逐渐下溢(gradual underflow) 和 ±0 问题,更加容易实现。这里是一个关于 posit numbers 的介绍。
下面介绍本题中可能用到的关于 posit numbers 的知识,其中略去了一些细节:
一个 posit number 由四个部分组成:
- 最高位是 sign bit,s \in \{0,1\},当 s=1 时,我们在解码时要把后面的二进制串取负(在补码意义下)(换言之,x 和 -x 的 posit 二进制表示恰也是补码意义下的相反数)
- 然后是至少两位(长度不固定)regime bits,其表示的值记为 r(见后文)
- 然后是 es 位的指数 (exponent bits),记为 e \in [0,2^{es}),es 是一个固定参数
- 剩余位数全是尾数 (fraction bits),表示一个 [0,1) 之间的二进制小数 f,最高位的权重是 1/2
这个数表示实数
其中 r 的表示方式如下:regime bits 段包括第一位和。如果第一位是 1,则 r 为 1 的个数减 1,而若第一位是 0,表示 r 为 0 的个数的相反数。举例如下:
001: r=-210: r=0
其中还有如下的特殊情况(在本题中几乎不会遇到):
- 全
0:表示 0 1000...:表示NaR,即无法计算出实数结果- regime bits 过长,剩余不够 es 位:只记录 e 的高位,低位补 0,同时 f=0
- 在此定义中,我们不考虑
Inf
看一个 16 位,es=3 的例子:

不难看出,posit number 对于接近 1 的数有较高的精度,同时也可以表示绝对值极小或极大的数,但精度上有所损失。
说明¶
在本题中,你将用 CUDA 处理 64 位,es=3 的 posit number 的矩阵乘法。为了数值稳定性和实现方便,我们保证输入的矩阵是从 [-2+\epsilon, 2-\epsilon] 中均匀随机选取的。
你需要在 impl.cu 中实现函数 void cuda_posit_gemm_d(const uint64_t *dA, const uint64_t *dB, uint64_t *dC, int n, int m, int k),表示执行 C = AB,dA, dB, dC 是三个 GPU global memory 上的数组,按 row major 存储。A, B, C 的形状分别为 n x k, k x m, n x m。
评测环境¶
容器镜像:cuda,已安装 CUDA 12.6 和 cublas。保证nvcc在PATH中。
评测的 GPU 为完整独占的 NVIDIA L40,显存为 48G。
评测方式¶
提交一个impl.cu文件,其中包含一个实现了cuda_posit_gemm_d函数的CUDA程序。
我们用 nvcc driver.cpp your.cu -o driver -O3 -std=c++20 --expt-relaxed-constexpr --extended-lambda -arch=sm_89 来编译你的程序。
我们会用同样的数据调用你的函数多次,并取(预热后的)平均时间。请不要在多次调用间保存状态。在你的函数返回后,我们将调用 cudaDeviceSynchronize() 并计时。
要通过本题,你的程序和标准程序之间的绝对误差不能超过 4 \times 10^{-14}。
本题将在 Ada Lovelace 架构 GPU 上评测,其 compute capability 为 8.9。
| 编号 | 约束 | 分数百分比 | 时间限制 | 满分时间 |
|---|---|---|---|---|
| 1 | n,m,k=32 | 10% | 1s | 20ms |
| 2 | n,m,k=4096 | 20% | 3s | 40ms |
| 3 | n,m,k=10240 | 20% | 3s | 450ms |
| 4 | n,m,k=12000 | 50% | 20s | 750ms |
每个测试点得到正确结果获得基本分数,性能分数与运行时间倒数成正比。
final_score = correct ? perf_score * min(goal / time, 1) : 0
提示¶
- 下发文件中在
include文件夹提供了 universal 库,提供了 posit number 模板类。 - 学术上有很多用整形实现高精度浮点数乘法的方法,ozIMMU及其背后的 Ozaki scheme 是非常好的参考资料。
- 程序会被链接到cublas
- 我们提供了几个小工具:
matgen <file> <N> <M>:随机生成 N \times M 矩阵;matdiff <file1> <file2>:计算两个矩阵的最大绝对误差和误差最大的位置;matprint <file>:输出矩阵每个位置的的二进制形式和十进制近似值;baseline <input1> <input2> <output>:在 CPU 上计算 posit 矩阵乘法;driver <input1> <input2> <output>:用你实现的函数计算 posit 矩阵乘法;test.sh <N>:用 N \times N 的矩阵测试你的程序。