计算精度问题

如果计算精度误差比较大,那么可能为安培架构的GPU引入的TF32数值类型以及Torch等框架会自动启用TF32计算造成的。TF32可简单理解为FP16的精度,FP32的表示范围,带来了更强的性能但是可能更差的精度。


该问题可参考Torch官方文档:文档
一般来说TF32够用,但是如果权重值有比较大的异常数值(一般没有)时会出现较大误差。下面给一个简单结果对比:

code:

import torch
A = torch.tensor([[113.2017, 7.4502, 39.3118],
                  [-99.4285, 13.2169, 85.9321],
                  [194.0693, -4282.2979, 58.0138]]).float().cuda()
B = torch.tensor([[0.8673, -0.4966, 0.0337],
                  [0.0377, -0.0019, -0.9993],
                  [0.4963, 0.8680, 0.0171]]).float().cuda()
gpu = A @ B
cpu = A.cpu() @ B.cpu()
print('gpu:\n', gpu)
print('cpu:\n', cpu)
print('gpu-cpu:\n', gpu.cpu() - cpu)
print("-" * 10)
A = torch.rand(3, 3).float().cuda()
B = torch.rand(3, 3).float().cuda()
print("A:\n", A)
print("B:\n", B)
gpu = A @ B
cpu = A.cpu() @ B.cpu()
print('gpu:\n', gpu)
print('cpu:\n', cpu)
print('gpu-cpu:\n', gpu.cpu() - cpu)

output:

gpu:
 tensor([[ 1.1795e+02, -2.2091e+01, -2.9597e+00],
        [-4.3079e+01,  1.2396e+02, -1.5093e+01],
        [ 3.5670e+01, -3.7907e+01,  4.2894e+03]], device='cuda:0')
cpu:
 tensor([[ 1.1797e+02, -2.2107e+01, -2.9579e+00],
        [-4.3088e+01,  1.2394e+02, -1.5089e+01],
        [ 3.5666e+01, -3.7882e+01,  4.2868e+03]])
gpu-cpu:
 tensor([[-2.3331e-02,  1.6153e-02, -1.8351e-03],
        [ 9.2430e-03,  2.1469e-02, -3.5658e-03],
        [ 3.8834e-03, -2.4605e-02,  2.6079e+00]])
A:
 tensor([[0.2938, 0.5557, 0.5823],
        [0.7572, 0.8567, 0.8239],
        [0.1630, 0.3278, 0.0526]], device='cuda:0')
B:
 tensor([[0.6398, 0.1599, 0.5362],
        [0.6011, 0.3908, 0.5424],
        [0.5615, 0.7290, 0.6213]], device='cuda:0')
gpu:
 tensor([[0.8490, 0.6888, 0.8207],
        [1.4620, 1.0566, 1.3825],
        [0.3308, 0.1926, 0.2979]], device='cuda:0')
cpu:
 tensor([[0.8489, 0.6886, 0.8207],
        [1.4620, 1.0564, 1.3825],
        [0.3308, 0.1925, 0.2979]])
gpu-cpu:
 tensor([[ 4.5955e-05,  2.0498e-04, -6.3181e-06],
        [ 8.4162e-05,  1.1504e-04, -3.2902e-05],
        [ 1.8775e-06,  6.3717e-05,  2.8968e-05]])

从上面的结果可以发现第一组A和B计算出的结果,GPU对比CPU的误差比较大,主要原因在于这个A矩阵中有比较大的数字(绝对值),而第二组随机初始化的A和B,GPU对比CPU的误差小很多。
如何避免上述误差?可以禁止使用TF32计算。方法为:
torch.backends.cuda.matmul.allow_tf32 = False 禁止矩阵乘法使用tf32

torch.backends.cudnn.allow_tf32 = False 禁止卷积使用tf32

code:

import torch
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
A = torch.tensor([[113.2017, 7.4502, 39.3118],
                  [-99.4285, 13.2169, 85.9321],
                  [194.0693, -4282.2979, 58.0138]]).float().cuda()
B = torch.tensor([[0.8673, -0.4966, 0.0337],
                  [0.0377, -0.0019, -0.9993],
                  [0.4963, 0.8680, 0.0171]]).float().cuda()
gpu = A @ B
cpu = A.cpu() @ B.cpu()
print('gpu:\n', gpu)
print('cpu:\n', cpu)
print('gpu-cpu:\n', gpu.cpu() - cpu)
print("-" * 10)
A = torch.rand(3, 3).float().cuda()
B = torch.rand(3, 3).float().cuda()
print("A:\n", A)
print("B:\n", B)
gpu = A @ B
cpu = A.cpu() @ B.cpu()
print('gpu:\n', gpu)
print('cpu:\n', cpu)
print('gpu-cpu:\n', gpu.cpu() - cpu)

output:

gpu:
 tensor([[ 1.1797e+02, -2.2107e+01, -2.9579e+00],
        [-4.3088e+01,  1.2394e+02, -1.5089e+01],
        [ 3.5666e+01, -3.7882e+01,  4.2868e+03]], device='cuda:0')
cpu:
 tensor([[ 1.1797e+02, -2.2107e+01, -2.9579e+00],
        [-4.3088e+01,  1.2394e+02, -1.5089e+01],
        [ 3.5666e+01, -3.7882e+01,  4.2868e+03]])
gpu-cpu:
 tensor([[0.0000e+00, 0.0000e+00, 2.3842e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.6294e-06, 0.0000e+00, 0.0000e+00]])
A:
 tensor([[0.3775, 0.7031, 0.2857],
        [0.7453, 0.2000, 0.9838],
        [0.3098, 0.7035, 0.4328]], device='cuda:0')
B:
 tensor([[0.6860, 0.6289, 0.9266],
        [0.6632, 0.1984, 0.4418],
        [0.4027, 0.1074, 0.3741]], device='cuda:0')
gpu:
 tensor([[0.8404, 0.4076, 0.7673],
        [1.0401, 0.6141, 1.1470],
        [0.8534, 0.3809, 0.7598]], device='cuda:0')
cpu:
 tensor([[0.8404, 0.4076, 0.7673],
        [1.0401, 0.6141, 1.1470],
        [0.8534, 0.3809, 0.7598]])
gpu-cpu:
 tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.1921e-07,  0.0000e+00,  0.0000e+00],
        [ 5.9605e-08, -2.9802e-08,  0.0000e+00]])