shard_map implementation of pmap only works for the first signature #21225

Open
jheek opened this issue May 14, 2024 · 1 comment
Labels
bug Something isn't working

Comments

jheek (Member) commented May 14, 2024

Description

from jax.experimental.shard_map import pmap
from jax import numpy as jnp

def f(x):
  return x * x

f = pmap(f, axis_name='batch')
f(jnp.ones((8,)))     # first input signature: works
f(jnp.ones((8, 2)))   # second input signature: raises StoreException: Store occupied

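Calling the same pmapped function with more than one input signature (here, two different input shapes) fails with the shard_map implementation of pmap: only the first signature works.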
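The exception name and message appear to match the write-once `Store` in JAX's internal `linear_util` module, which raises "Store occupied" when a value is stored twice. One plausible reading (an assumption, not confirmed in this thread) is that the shard_map-based pmap allocates such a store once per wrapped function rather than once per trace, so the retrace triggered by the second input shape tries to write it again. The pure-Python sketch below mimics that pattern; `Store`, `StoreException`, `make_buggy_wrapper`, `make_fixed_wrapper` and `trace_square` are hypothetical stand-ins for illustration, not JAX's actual code.

```python
# Hypothetical illustration, not JAX source: a write-once store shared
# across traces fails as soon as a second input signature is traced.

class StoreException(Exception):
    pass


class Store:
    """Write-once slot used to pass auxiliary output out of a 'trace'."""

    _EMPTY = object()

    def __init__(self):
        self._val = Store._EMPTY

    def store(self, val):
        if self._val is not Store._EMPTY:
            raise StoreException("Store occupied")
        self._val = val

    @property
    def val(self):
        if self._val is Store._EMPTY:
            raise StoreException("Store empty")
        return self._val


def make_buggy_wrapper(trace):
    out_info = Store()  # BUG (hypothetical): one store shared by every trace
    cache = {}

    def wrapper(shape):
        if shape not in cache:            # new input signature -> retrace
            out_info.store(trace(shape))  # raises on the 2nd distinct signature
            cache[shape] = out_info.val
        return cache[shape]               # same signature -> cache hit

    return wrapper


def make_fixed_wrapper(trace):
    cache = {}

    def wrapper(shape):
        if shape not in cache:
            out_info = Store()            # fresh store for each trace
            out_info.store(trace(shape))
            cache[shape] = out_info.val
        return cache[shape]

    return wrapper


def trace_square(shape):
    # Stand-in for tracing x * x: output shape equals input shape.
    return shape


buggy = make_buggy_wrapper(trace_square)
buggy((8,))   # first signature: fine
buggy((8,))   # same signature again: served from the cache
try:
    buggy((8, 2))
except StoreException as e:
    print(e)  # Store occupied

fixed = make_fixed_wrapper(trace_square)
print(fixed((8,)), fixed((8, 2)))  # (8,) (8, 2)
```

Keying the cache on the input signature and allocating the store inside the cache-miss branch is what lets each new signature be traced independently.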

System info (python version, jaxlib version, accelerator, etc.)

Latest version of JAX at the time of reporting (May 2024).

jheek added the "bug" (Something isn't working) label May 14, 2024
yashk2810 (Member) commented:

This is not polished or fully tested yet. It's purely experimental right now.
