Numerical differences between shardings in random algorithm #21232

Open · shawnwang18 opened this issue May 14, 2024 · 3 comments
Labels: bug (Something isn't working)

shawnwang18 commented:
Description

We are seeing numerical differences between shardings in random number initialization on GPUs. For example, given a mesh with DP, FSDP, and TP axes, the numerical output of the initialization changes drastically depending on how many devices are allocated to each axis. As a result, when using TP we see divergences in the network.

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.27.dev20240514
jaxlib: 0.4.27.dev20240420
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ipp1-2023.nvidia.com', release='5.15.0-88-generic', version='#98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023', machine='x86_64')


$ nvidia-smi
Tue May 14 23:37:29 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A30                     On  |   00000000:01:00.0 Off |                    0 |
| N/A   27C    P0             31W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A30                     On  |   00000000:23:00.0 Off |                    0 |
| N/A   27C    P0             32W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A30                     On  |   00000000:41:00.0 Off |                    0 |
| N/A   28C    P0             33W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A30                     On  |   00000000:61:00.0 Off |                    0 |
| N/A   26C    P0             32W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A30                     On  |   00000000:81:00.0 Off |                    0 |
| N/A   27C    P0             34W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A30                     On  |   00000000:A1:00.0 Off |                    0 |
| N/A   28C    P0             33W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A30                     On  |   00000000:C1:00.0 Off |                    0 |
| N/A   28C    P0             32W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A30                     On  |   00000000:E1:00.0 Off |                    0 |
| N/A   28C    P0             33W /  165W |     234MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A        41      C   python                                          0MiB |
|    1   N/A  N/A        41      C   python                                          0MiB |
|    2   N/A  N/A        41      C   python                                          0MiB |
|    3   N/A  N/A        41      C   python                                          0MiB |
|    4   N/A  N/A        41      C   python                                          0MiB |
|    5   N/A  N/A        41      C   python                                          0MiB |
|    6   N/A  N/A        41      C   python                                          0MiB |
|    7   N/A  N/A        41      C   python                                          0MiB |
+-----------------------------------------------------------------------------------------+
```

The reproduction unit test is below; it must be run on a node with 8 GPUs.

```python
import jax.numpy as jnp
import jax
from jax.experimental import mesh_utils as jax_mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh

MESH_DATA_AXIS = 'data'
MESH_TENSOR_AXIS = 'tensor'
MESH_FSDP_AXIS = 'pipeline'  # note: the FSDP axis is named 'pipeline' here


# create a DP, FSDP mesh
ici_mesh = (2, 4, 1)  # DP, FSDP, TP
dcn_mesh = (1, 1, 1)  # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_mesh.shape)  # axis sizes (2, 4, 1)

# create an FSDP, TP mesh
ici_mesh = (1, 4, 2)  # DP, FSDP, TP
dcn_mesh = (1, 1, 1)  # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_tp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_tp_mesh.shape)  # axis sizes (1, 4, 2)

# create an FSDP, TP, DP mesh
ici_mesh = (2, 2, 2)  # DP, FSDP, TP
dcn_mesh = (1, 1, 1)  # DP, FSDP, TP
devices = jax_mesh_utils.create_hybrid_device_mesh(ici_mesh, dcn_mesh)
fsdp_tp_dp_mesh = Mesh(devices, (MESH_DATA_AXIS, MESH_FSDP_AXIS, MESH_TENSOR_AXIS))
print(fsdp_tp_dp_mesh.shape)  # axis sizes (2, 2, 2)

# generate the data
batch_size = 32
seq_len = 8192
n_heads = 32
head_dim = 128
emb_dim = 4096
DATA_SUBMESH = (MESH_DATA_AXIS, MESH_FSDP_AXIS)

def gen_data_fn():
    key = jax.random.PRNGKey(43)  # note: the same key is reused for both tensors below
    scale = 0.05
    activations = scale * jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16)
    weights = scale * jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16)
    return activations, weights

# data_fn is the same program in every case; only the mesh in scope (and thus
# the output sharding) changes between runs.
data_fn = pjit(
    gen_data_fn,
    out_shardings=(P(DATA_SUBMESH, None, MESH_TENSOR_AXIS), P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None)),
)

# fsdp outputs
with fsdp_mesh:
    act1, weights1 = data_fn()

with fsdp_tp_mesh:
    act2, weights2 = data_fn()

with fsdp_tp_dp_mesh:
    act3, weights3 = data_fn()

# diff statistics between two arrays
def get_diffs(x, y):
    abs_diff = jnp.abs(x - y)
    max_difference = round(jnp.max(abs_diff), 5)
    min_difference = round(jnp.min(abs_diff), 5)
    avg_difference = round(jnp.mean(abs_diff), 5)
    return max_difference, min_difference, avg_difference

max_diff, min_diff, avg_diff = jax.jit(get_diffs)(act1, act2)
print(f"Differences b/w FSDP and FSDP,TP: Max -- {max_diff}, Min -- {min_diff}, Average -- {avg_diff}")

max_diff, min_diff, avg_diff = jax.jit(get_diffs)(act1, act3)
print(f"Differences b/w FSDP and FSDP,TP,DP: Max -- {max_diff}, Min -- {min_diff}, Average -- {avg_diff}")
```
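
As a small extension of the repro (our addition, not part of the original report), each sharded output can also be diffed against an unsharded baseline generated from the same key, which makes the sharding dependence explicit:

```python
# Hypothetical extra check, reusing gen_data_fn/get_diffs from the repro above.
# If random generation were sharding-invariant, every statistic would be 0.
baseline_act, _ = gen_data_fn()  # eager call, no mesh in scope

for name, act in [("fsdp", act1), ("fsdp,tp", act2), ("fsdp,tp,dp", act3)]:
    max_d, min_d, avg_d = jax.jit(get_diffs)(act, baseline_act)
    print(f"{name} vs baseline: Max -- {max_d}, Min -- {min_d}, Average -- {avg_d}")
```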
shawnwang18 added the bug label on May 14, 2024.
froystig (Member) commented:

This is fixed by upgrading to partitionable threefry, e.g. by adding the following line to the top of the file (after imports):

```python
jax.config.update('jax_threefry_partitionable', True)
```

See #18480 for more on the upgrade (which was delayed a bit, but is still planned).
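
For concreteness, here is a minimal self-contained sketch (our illustration, not from the thread) of the invariant the flag is meant to provide. It assumes 8 visible devices (e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU):

```python
# Sketch: with partitionable threefry enabled, the same PRNG key should yield
# identical values regardless of how the output is sharded across the mesh.
import jax
jax.config.update('jax_threefry_partitionable', True)  # set before any RNG use

import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def sample(mesh_shape, spec):
    devices = mesh_utils.create_device_mesh(mesh_shape)
    mesh = Mesh(devices, ('x', 'y'))
    f = jax.jit(lambda: jax.random.normal(jax.random.PRNGKey(0), (64, 64)),
                out_shardings=NamedSharding(mesh, spec))
    return f()

a = sample((8, 1), P('x', None))  # rows sharded 8 ways
b = sample((2, 4), P('x', 'y'))   # rows and columns sharded
print(np.array_equal(jax.device_get(a), jax.device_get(b)))  # expected: True
```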

mattjj (Member) commented May 15, 2024:

IIUC this is a bug (unintended behavior) even with `jax_threefry_partitionable=False`, and we don't yet know what's causing it. Good to know that setting `jax_threefry_partitionable=True` fixes it, though!

froystig (Member) commented:

Yes, I consider it a bug as well; it remains undiagnosed.
