shard_map can cause misses in compilation_cache #21236

Closed
jaro-sevcik opened this issue May 15, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@jaro-sevcik
Contributor

Description

Starting with version 0.4.27, grad of shard_map can produce different MLIR code on different runs, so the persistent compilation cache misses. See the bash invocation, example output, and Python code below for a repro.
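
Roughly speaking, the persistent cache keys on a hash of the lowered computation (together with compile options, backend, and library versions), so any byte-level difference in the serialized module yields a new cache entry such as jit_f-<digest>. A toy sketch of that keying (not JAX's actual implementation):

import hashlib

def toy_cache_key(module_name: str, mlir_text: str) -> str:
  # Any byte-level difference in the module text changes the digest and
  # therefore produces a new file in the cache directory.
  digest = hashlib.sha256(mlir_text.encode('utf-8')).hexdigest()
  return f"{module_name}-{digest}"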

Bash commands:

mkdir -p /jaxcache
rm /jaxcache/*
for i in {0..10}; do python3 sm.py; done
ls -1 /jaxcache

Output:

jit__multi_slice-71ebc5369b07c301e42e24ed1cdbd37e22fd425ff2b2ade831e6e42e64202941
jit_f-5db70a97edfbe721726b707afe0a483180f203ddb3a1c33b1069bcf25e8d20f1
jit_f-62e5467e3e3571ccd74523e6d41b663d152115d768fed7c3bde4789852c42066
jit_f-77404e9c95598acb3e1f983292c6fb44977150a9094b045a9fcbe23cea2d8c2e
jit_f-979aeff8e64675e9f7e8717c3c8c3c2c62d3ae0feffd08a2b78e0274efbd48c5
jit_f-ac9227f0d824cc94658d1c1af469e74df2b141dab49084eac90eb9d63b67537d
jit_f-b7e277b16ebf4a975df01eaac341077401b11aa583152907bafc929bb1575d5f
jit_f-df6c161b29e11544741cb6fbdc96fb5846447706881b0a021e6a9e7d0d557107
jit_f-e21f0d8b815e109de84081dd56209ee96e884cffc8f61dc35e0bbd3a909c2b16
jit_iota-84480c2d7065f65c1130244dfeb95dc97707b557937b0baba3fb02449f5e18d8
jit_reshape-ed18f2632c2a9023ece9f33d38d2f0e2da442805ae8d9ed6b65c8ee830d7a400

Code for sm.py:

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

# Persist compiled executables to disk so cache hits and misses are visible across runs.
jax.experimental.compilation_cache.compilation_cache.set_cache_dir('/jaxcache')

# A four-axis mesh ('a', 'b', 'c', 'd'); only axis 'a' spans the devices.
devices = mesh_utils.create_device_mesh((jax.device_count(), 1, 1, 1))
mesh = Mesh(devices, axis_names=('a', 'b', 'c', 'd'))

a = jnp.arange(float(jax.device_count())).reshape(jax.device_count(), 1)

@partial(shard_map, mesh=mesh, in_specs=(P('a', 'b'),),
         out_specs=P(None, None))
def g(a_block):
  return jax.lax.psum(jnp.sin(a_block), ('a', 'b'))

def f(x):
  return jnp.sum(g(x))

# grad of the shard_map'd function introduces residuals that cross the shard_map boundary.
f = jax.jit(jax.grad(f))

f(a)
print(f.lower(a).as_text())

The likely culprit is the code for sharding the residuals: a recent change appears to have made the residual axis names depend on Python set iteration order, which is not guaranteed to be the same from run to run. In the MLIR dumped by the example above, the jax.result_info attribute for the residual is an arbitrary permutation of ('a', 'b', 'c', 'd'). For instance, in the output below, the residual axis ordering (the output of the shmap_body function) is ('d', 'c', 'b', 'a').

module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<8x1xf32> {mhlo.layout_mode = "default"}) -> (tensor<8x1xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8x1xf32>) -> tensor<8x1xf32>
    %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x1xf32>) -> tensor<1x1xf32>
    %2 = call @shmap_body(%1) : (tensor<1x1xf32>) -> tensor<1x1xf32>
    %3 = stablehlo.custom_call @Sharding(%2) {mhlo.sharding = "{manual}"} : (tensor<1x1xf32>) -> tensor<1x1xf32>
    %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<1x1xf32>) -> tensor<8x1xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x1xf32>
    %6 = stablehlo.custom_call @Sharding(%5) {mhlo.sharding = "{replicated}"} : (tensor<1x1xf32>) -> tensor<1x1xf32>
    %7 = stablehlo.custom_call @SPMDFullToShardShape(%6) {mhlo.sharding = "{manual}"} : (tensor<1x1xf32>) -> tensor<1x1xf32>
    %8 = stablehlo.custom_call @Sharding(%4) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8x1xf32>) -> tensor<8x1xf32>
    %9 = stablehlo.custom_call @SPMDFullToShardShape(%8) {mhlo.sharding = "{manual}"} : (tensor<8x1xf32>) -> tensor<1x1xf32>
    %10 = call @shmap_body_0(%7, %9) : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
    %11 = stablehlo.custom_call @Sharding(%10) {mhlo.sharding = "{manual}"} : (tensor<1x1xf32>) -> tensor<1x1xf32>
    %12 = stablehlo.custom_call @SPMDShardToFullShape(%11) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<1x1xf32>) -> tensor<8x1xf32>
    return %12 : tensor<8x1xf32>
  }
  func.func private @shmap_body(%arg0: tensor<1x1xf32>) -> (tensor<1x1xf32> {jax.result_info = "[('d', 'c', 'b', 'a'), None]"}) {
    %0 = stablehlo.cosine %arg0 : tensor<1x1xf32>
    return %0 : tensor<1x1xf32>
  }
  func.func private @shmap_body_0(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1xf32>) -> (tensor<1x1xf32> {jax.result_info = "[('a',), ('b',)]"}) {
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<1x1xf32>
    return %0 : tensor<1x1xf32>
  }
}
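
The run-to-run variation in the residual axis ordering is consistent with Python's per-process string hash randomization (PYTHONHASHSEED): the iteration order of a set of strings can differ between interpreter invocations. A standalone illustration, not JAX code:

names = {'a', 'b', 'c', 'd'}
print(tuple(names))          # iteration order can change from one python3 run to the next
print(tuple(sorted(names)))  # always ('a', 'b', 'c', 'd')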

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28.dev20240510+f21e3e82c
jaxlib: 0.4.28.dev20240510
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-1029-nvidia', version='#29-Ubuntu SMP Mon Jul 17 15:02:31 UTC 2023', machine='x86_64')


$ nvidia-smi
Wed May 15 10:04:12 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:1B:00.0 Off |                    0 |
| N/A   35C    P0             101W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:43:00.0 Off |                    0 |
| N/A   35C    P0             126W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:52:00.0 Off |                    0 |
| N/A   37C    P0             120W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:61:00.0 Off |                    0 |
| N/A   36C    P0             119W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:9D:00.0 Off |                    0 |
| N/A   36C    P0             116W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:C3:00.0 Off |                    0 |
| N/A   34C    P0             116W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:D1:00.0 Off |                    0 |
| N/A   36C    P0             114W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:DF:00.0 Off |                    0 |
| N/A   38C    P0             116W / 700W |    538MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
jaro-sevcik added the bug label on May 15, 2024
@jaro-sevcik
Contributor Author

One option to fix the problem is to have the _all_mesh_names function return the axis names as a sorted tuple (I verified that this fixes the repro). Something like:

diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index a88571a90..3235b8e9d 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1322,7 +1322,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
     in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
     _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
     num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
-    return (*out_known_names, *({0: (*all_names,)},) * num_res)
+    return (*out_known_names, *({0: all_names},) * num_res)
 
   known_params = dict(mesh=mesh, in_names=(*known_in_names,),
                       out_names_thunk=known_out_names, check_rep=check_rep,
@@ -1337,7 +1337,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
   res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
   res_names = [known_in_names[f1] if f1 is not None else
                known_out_names_[f2] if f2 is not None else
-               {0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+               {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
   unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
   const_tracers = map(trace.new_instantiated_const, res)
   env_tracers = map(trace.full_raise, env)
@@ -1380,7 +1380,7 @@ def _shard_map_partial_eval_post_process(
     const_tracers = map(trace.new_instantiated_const, res_)
     env_tracers = map(trace.full_raise, env)
 
-    staged_in_names = ({0: (*all_names,)},) * len(res_) + ({},) * len(env)
+    staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env)
     staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
                          out_names=(*out_names_unknown,), check_rep=False,
                          rewrite=rewrite, auto=auto)
@@ -1399,7 +1399,7 @@ def _shard_map_partial_eval_post_process(
   def out_names_transform(out_names):
     nonlocal out_names_unknown
     out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
-    return (*out_names_known,) + ({0: (*all_names,)},) * len(res)
+    return (*out_names_known,) + ({0: all_names},) * len(res)
   out_names_unknown: list | None = None
 
   return out, (todo, out_names_transform)
@@ -1512,7 +1512,7 @@ def _partial_eval_jaxpr_custom_rule(
   params_known, params_staged, all_names = _pe_custom_params(
       unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
       dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
-  residuals = [newvar(_unshard_aval(mesh, {0: (*all_names,)}, var.aval))
+  residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
                for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
   eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                                eqn.primitive, params_known, jaxpr_known.effects,
@@ -1564,7 +1564,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
   all_names = _all_mesh_names(mesh)
   in_names_known, _ = partition_list(unks_in, params_known['in_names'])
   _, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
-  out_names_known = out_names_known + [{0: (*all_names,)}] * sum(which)
+  out_names_known = out_names_known + [{0: all_names}] * sum(which)
   new_params_known = dict(params_known, in_names=tuple(in_names_known),
                           out_names=tuple(out_names_known))
 
@@ -1572,7 +1572,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
   _, in_names_staged = partition_list(inst_in, params_staged['in_names'])
   res_names = [in_names_known[f1] if f1 is not None else
                out_names_known[f2] if f2 is not None else
-               {0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+               {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
   in_names_staged = res_names + in_names_staged
   _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
   new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
@@ -1586,7 +1586,7 @@ def _all_mesh_names(mesh: Mesh) -> set[AxisName]:
   names = {n for frame in stack
            if (ns := frame.payload.get('spmd_axis_name', ())) is not None
            for n in ns}
-  return set(mesh.axis_names) - names
+  return tuple(sorted(set(mesh.axis_names) - names))
 
 
 # DCE
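
For reference, the effect of the patch on the residual names, as a simplified standalone sketch (hardcoded axis names, ignoring the spmd_axis_name subtraction in the real function):

def all_mesh_names_sorted(axis_names):
  # A sorted tuple gives the same ordering on every run, so the residual
  # out_names entries {0: all_names}, the jax.result_info strings, and the
  # serialized MLIR (and hence the cache key) become deterministic.
  return tuple(sorted(set(axis_names)))

print({0: all_mesh_names_sorted(('a', 'b', 'c', 'd'))})  # {0: ('a', 'b', 'c', 'd')} on every run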

@jaro-sevcik
Contributor Author

Fixed by #21278
