Description
BF16 matmul appears to be slower than F32 matmul on a T4. In my test, BF16 runs at roughly half the speed of F32. I believe this is a bug: bf16 should be at least as fast as f32, and possibly faster.
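You can repro in a T4 colab with a script along the lines of the one below. (The original snippet was not preserved in this capture; this is a minimal sketch whose FLOP accounting, 2*m*K*K with K=4096 and m doubling each step, matches the TFLOP counts printed below. The exact shapes and timing details are my assumptions.)

import time
import jax
import jax.numpy as jnp

K = 4096  # assumed inner/output dimension; chosen so FLOP counts match the output below

def benchmark(dtype):
    print(dtype)  # header line; the exact repr may differ across jax versions
    matmul = jax.jit(jnp.matmul)
    for i in range(16):
        m = 2 ** i
        key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
        a = jax.random.normal(key_a, (m, K), dtype=dtype)
        b = jax.random.normal(key_b, (K, K), dtype=dtype)
        matmul(a, b).block_until_ready()  # compile and warm up
        start = time.perf_counter()
        matmul(a, b).block_until_ready()
        elapsed = time.perf_counter() - start
        tflop = 2 * m * K * K / 1e12  # 2*m*K*K FLOPs per (m, K) @ (K, K) matmul
        print(f"Total TFLOP Count: {tflop:.5f} | TFLOPS: {tflop / elapsed:.2f}")

for dt in (jnp.bfloat16, jnp.float16, jnp.float32):
    benchmark(dt)

This results in the following output: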
<class 'jax.numpy.bfloat16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.10
Total TFLOP Count: 0.00007 | TFLOPS: 0.04
Total TFLOP Count: 0.00013 | TFLOPS: 0.09
Total TFLOP Count: 0.00027 | TFLOPS: 0.16
Total TFLOP Count: 0.00054 | TFLOPS: 0.35
Total TFLOP Count: 0.00107 | TFLOPS: 0.61
Total TFLOP Count: 0.00215 | TFLOPS: 1.09
Total TFLOP Count: 0.00429 | TFLOPS: 1.22
Total TFLOP Count: 0.00859 | TFLOPS: 1.74
Total TFLOP Count: 0.01718 | TFLOPS: 2.27
Total TFLOP Count: 0.03436 | TFLOPS: 2.36
Total TFLOP Count: 0.06872 | TFLOPS: 2.36
Total TFLOP Count: 0.13744 | TFLOPS: 2.16
Total TFLOP Count: 0.27488 | TFLOPS: 2.19
Total TFLOP Count: 0.54976 | TFLOPS: 2.14
Total TFLOP Count: 1.09951 | TFLOPS: 2.09
<class 'jax.numpy.float16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.11
Total TFLOP Count: 0.00007 | TFLOPS: 0.22
Total TFLOP Count: 0.00013 | TFLOPS: 0.44
Total TFLOP Count: 0.00027 | TFLOPS: 0.92
Total TFLOP Count: 0.00054 | TFLOPS: 1.76
Total TFLOP Count: 0.00107 | TFLOPS: 3.53
Total TFLOP Count: 0.00215 | TFLOPS: 6.99
Total TFLOP Count: 0.00429 | TFLOPS: 14.04
Total TFLOP Count: 0.00859 | TFLOPS: 23.47
Total TFLOP Count: 0.01718 | TFLOPS: 25.02
Total TFLOP Count: 0.03436 | TFLOPS: 35.24
Total TFLOP Count: 0.06872 | TFLOPS: 37.16
Total TFLOP Count: 0.13744 | TFLOPS: 31.20
Total TFLOP Count: 0.27488 | TFLOPS: 24.41
Total TFLOP Count: 0.54976 | TFLOPS: 23.02
Total TFLOP Count: 1.09951 | TFLOPS: 22.13
<class 'jax.numpy.float32'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.08
Total TFLOP Count: 0.00007 | TFLOPS: 0.16
Total TFLOP Count: 0.00013 | TFLOPS: 0.31
Total TFLOP Count: 0.00027 | TFLOPS: 0.66
Total TFLOP Count: 0.00054 | TFLOPS: 1.34
Total TFLOP Count: 0.00107 | TFLOPS: 2.61
Total TFLOP Count: 0.00215 | TFLOPS: 4.18
Total TFLOP Count: 0.00429 | TFLOPS: 4.92
Total TFLOP Count: 0.00859 | TFLOPS: 5.32
Total TFLOP Count: 0.01718 | TFLOPS: 4.59
Total TFLOP Count: 0.03436 | TFLOPS: 4.31
Total TFLOP Count: 0.06872 | TFLOPS: 4.19
Total TFLOP Count: 0.13744 | TFLOPS: 4.04
Total TFLOP Count: 0.27488 | TFLOPS: 4.30
Total TFLOP Count: 0.54976 | TFLOPS: 4.31
Total TFLOP Count: 1.09951 | TFLOPS: 4.37
Note how bf16 is much slower than f32. (Side note: bf16 is also far slower than f16, but my understanding is that this is because the T4 doesn't support bf16 natively, so JAX rewrites the computation to use f32.)
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='063d876e5268', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')
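(The block above matches the output format of JAX's built-in environment reporter, which is presumably how it was collected:)

import jax
jax.print_environment_info()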
$ nvidia-smi
Mon May 13 18:04:34 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0              30W /  70W |  11457MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
T4 doesn't support bfloat16, but JAX (via the XLA GPU compiler) should be falling back to float32. The fact that the result is appreciably slower than native float32 may indicate a bug in the XLA GPU compiler.
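One way to verify what the compiler actually emits is to inspect the optimized HLO for a bf16 matmul (a minimal sketch using jax.jit's standard lowering/compilation inspection APIs; the exact HLO text varies across jaxlib versions):

import jax
import jax.numpy as jnp

a = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
b = jnp.ones((1024, 1024), dtype=jnp.bfloat16)

# Dump the optimized HLO; on a GPU without native bf16 support you would
# expect to see the dot rewritten to f32 (convert ops around an f32 dot).
compiled = jax.jit(jnp.matmul).lower(a, b).compile()
print(compiled.as_text())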