shard_map can cause misses in compilation_cache #21236
Labels
bug
Something isn't working
Comments
One option to fix the problem is to always sort the axis names in the
Fixed by #21278
Description
Starting with version 0.4.27, taking grad of a shard_map function can produce different MLIR code on different runs, so the compilation cache misses even though the program has not changed. See the bash invocation, example output, and Python code below for a repro.
Bash commands:
Output:
Code for `sm.py`:
The likely culprit is the code for sharding the residuals, where a recent change seems to have introduced set ordering for residual axis names, but set ordering is not guaranteed to be the same from run to run. In the MLIR code dumped by the example above, observe that `jax.result_info` for the residual is a random permutation of the `('a', 'b', 'c', 'd')` list. For instance, in the output below, the residual axis ordering (output of the `shmap_body` function) is `('d', 'c', 'b', 'a')`.
System info (python version, jaxlib version, accelerator, etc.)
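The set-ordering hazard described in this issue can be reproduced in plain Python, independently of JAX. The sketch below is only an illustration, not the code changed in #21278: `iteration_order` and `canonical_axis_names` are hypothetical helper names. It starts fresh interpreters with different `PYTHONHASHSEED` values to mimic separate runs, and shows that sorting the axis names yields one stable order.

```python
import os
import subprocess
import sys

AXES = ("a", "b", "c", "d")


def canonical_axis_names(names):
    """Hypothetical helper: return axis names in a run-independent order."""
    return tuple(sorted(names))


def iteration_order(seed, sort):
    """Print a set of axis names from a fresh interpreter and return the text.

    Each subprocess stands in for a separate run of the program; string
    hashing, and therefore set iteration order, depends on PYTHONHASHSEED.
    """
    expr = "sorted(s)" if sort else "s"
    child = f"s = set({AXES!r}); print(tuple({expr}))"
    env = {**os.environ, "PYTHONHASHSEED": str(seed)}
    result = subprocess.run(
        [sys.executable, "-c", child],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


if __name__ == "__main__":
    for seed in range(1, 6):
        raw = iteration_order(seed, sort=False)      # may differ per seed
        stable = iteration_order(seed, sort=True)    # same for every seed
        print(f"seed={seed} set order={raw} sorted order={stable}")
```

Any value derived from the raw set order, such as text embedded in a lowered module, can therefore change between runs and change a cache key computed from it, whereas the sorted tuple cannot.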