BF16 matmul slower than F32 matmul on T4 GPU #21212

Open
sagelywizard opened this issue May 13, 2024 · 2 comments

sagelywizard commented May 13, 2024

Description

BF16 matmul appears to be slower than F32 matmul on T4. In my test, BF16 runs at roughly half the speed of F32. I believe this is a bug: bf16 should be at least as fast as f32, and possibly faster.

You can reproduce this in a T4 Colab with the following:

import jax
import jax.numpy as jnp
import timeit

def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
  key = jax.random.PRNGKey(0)
  x_i = 2**exponent
  x_j = 4096
  y_j = 4096
  # An (x_i, x_j) @ (x_j, y_j) matmul does 2 * x_i * x_j * y_j FLOPs.
  flop_count = x_i * x_j * y_j * 2
  x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
  y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
  matmul = jax.jit(lambda a, b: a @ b)
  matmul(x, y).block_until_ready()  # warm-up call so compilation isn't timed
  seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
  flops = flop_count / seconds_per_iter
  return flop_count, flops

def flops_to_tflops(flops):
  return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
  print(dtype)
  for i in range(16):
    op_count, flops = flops_calc(exponent=i, dtype=dtype)
    print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
  print()

This results in the following output:

<class 'jax.numpy.bfloat16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.10
Total TFLOP Count: 0.00007 | TFLOPS: 0.04
Total TFLOP Count: 0.00013 | TFLOPS: 0.09
Total TFLOP Count: 0.00027 | TFLOPS: 0.16
Total TFLOP Count: 0.00054 | TFLOPS: 0.35
Total TFLOP Count: 0.00107 | TFLOPS: 0.61
Total TFLOP Count: 0.00215 | TFLOPS: 1.09
Total TFLOP Count: 0.00429 | TFLOPS: 1.22
Total TFLOP Count: 0.00859 | TFLOPS: 1.74
Total TFLOP Count: 0.01718 | TFLOPS: 2.27
Total TFLOP Count: 0.03436 | TFLOPS: 2.36
Total TFLOP Count: 0.06872 | TFLOPS: 2.36
Total TFLOP Count: 0.13744 | TFLOPS: 2.16
Total TFLOP Count: 0.27488 | TFLOPS: 2.19
Total TFLOP Count: 0.54976 | TFLOPS: 2.14
Total TFLOP Count: 1.09951 | TFLOPS: 2.09

<class 'jax.numpy.float16'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.11
Total TFLOP Count: 0.00007 | TFLOPS: 0.22
Total TFLOP Count: 0.00013 | TFLOPS: 0.44
Total TFLOP Count: 0.00027 | TFLOPS: 0.92
Total TFLOP Count: 0.00054 | TFLOPS: 1.76
Total TFLOP Count: 0.00107 | TFLOPS: 3.53
Total TFLOP Count: 0.00215 | TFLOPS: 6.99
Total TFLOP Count: 0.00429 | TFLOPS: 14.04
Total TFLOP Count: 0.00859 | TFLOPS: 23.47
Total TFLOP Count: 0.01718 | TFLOPS: 25.02
Total TFLOP Count: 0.03436 | TFLOPS: 35.24
Total TFLOP Count: 0.06872 | TFLOPS: 37.16
Total TFLOP Count: 0.13744 | TFLOPS: 31.20
Total TFLOP Count: 0.27488 | TFLOPS: 24.41
Total TFLOP Count: 0.54976 | TFLOPS: 23.02
Total TFLOP Count: 1.09951 | TFLOPS: 22.13

<class 'jax.numpy.float32'>
Total TFLOP Count: 0.00003 | TFLOPS: 0.08
Total TFLOP Count: 0.00007 | TFLOPS: 0.16
Total TFLOP Count: 0.00013 | TFLOPS: 0.31
Total TFLOP Count: 0.00027 | TFLOPS: 0.66
Total TFLOP Count: 0.00054 | TFLOPS: 1.34
Total TFLOP Count: 0.00107 | TFLOPS: 2.61
Total TFLOP Count: 0.00215 | TFLOPS: 4.18
Total TFLOP Count: 0.00429 | TFLOPS: 4.92
Total TFLOP Count: 0.00859 | TFLOPS: 5.32
Total TFLOP Count: 0.01718 | TFLOPS: 4.59
Total TFLOP Count: 0.03436 | TFLOPS: 4.31
Total TFLOP Count: 0.06872 | TFLOPS: 4.19
Total TFLOP Count: 0.13744 | TFLOPS: 4.04
Total TFLOP Count: 0.27488 | TFLOPS: 4.30
Total TFLOP Count: 0.54976 | TFLOPS: 4.31
Total TFLOP Count: 1.09951 | TFLOPS: 4.37

Note how bf16 is much slower than f32. (Side note: bf16 is also much slower than f16, but my understanding is that this is because the T4 doesn't support bf16 natively, so JAX rewrites the computation to use f32.)
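
For what it's worth, one way to check what the compiler actually does with the bf16 matmul is to dump the optimized HLO for the jitted function (a quick sketch, separate from the benchmark above; the operand and result dtypes of the dot/gemm op in the printed text show what actually runs):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (4096, 4096), dtype=jnp.bfloat16)
y = jax.random.uniform(key, (4096, 4096), dtype=jnp.bfloat16)
matmul = jax.jit(lambda a, b: a @ b)
# Print the optimized HLO that XLA compiled for this GPU; inspect the
# dot/gemm op to see whether it runs in bf16 or in an f32 fallback.
print(matmul.lower(x, y).compile().as_text())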

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='063d876e5268', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')


$ nvidia-smi
Mon May 13 18:04:34 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0              30W /  70W |  11457MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
@sagelywizard sagelywizard added the bug Something isn't working label May 13, 2024
jakevdp commented May 13, 2024

Thanks for the report!

Here's a ref to the T4 architecture spec: https://images.nvidia.com/aem-dam/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf

T4 doesn't support bfloat16, but JAX (via the XLA GPU compiler) should be falling back to float32. The fact that the result is appreciably slower than native float32 may indicate a bug in the XLA GPU compiler.

I'd suggest reporting this at https://github.com/openxla/xla
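
In the meantime, as a rough point of comparison (a sketch, not a fix), you could do the f32 upcast explicitly inside the jitted function and see whether that recovers the native float32 numbers:

import jax
import jax.numpy as jnp

# Keep bf16 inputs and outputs, but compute the product in f32 explicitly,
# mimicking the fallback the compiler should be applying on its own.
matmul_via_f32 = jax.jit(
    lambda a, b: (a.astype(jnp.float32) @ b.astype(jnp.float32)).astype(jnp.bfloat16)
)

If that version benchmarks like the float32 case while the plain bf16 matmul doesn't, it points at how the bf16 dot is being lowered rather than at the benchmark itself.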

cheshire commented

Replied on the OpenXLA bug.
