Checkify exception only reports error from one device (not all) #21246

Open
billmark opened this issue May 15, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@billmark

Description

When running with four hosts and four devices on each host, the "errs" value returned by pmap of checkify looks like the following:

Error(at mapped index 0: before: pmean_input_ok failed step @12290 (`check` failed)
at mapped index 1: after:neg_delta_params has NaN @12290 (`check` failed)
at mapped index 2: after:neg_delta_params has NaN @12290 (`check` failed)
at mapped index 3: after:neg_delta_params has NaN @12290 (`check` failed))

However, an errs.throw() (as recommended in JAX docs) only shows one of these four errors:

 Top-level exception: after:neg_delta_params has NaN @12290 (`check` failed)
  ...
  jax._src.checkify.FailedCheckError: after:neg_delta_params has NaN @12290 (`check` failed)

I consider this behavior to be a bug. No reasonable person would expect the exception string to omit the errors from three out of four devices on that host. The exception string should contain all four errors.

System info (python version, jaxlib version, accelerator, etc.)

HEAD at google as of May 15, 2024. Running on TPU. (Four hosts, four devices per host).

@billmark added the bug label May 15, 2024
@billmark
Author

An additional clarification. The code looks essentially like the following:

errs, returns = jax.pmap(jax.experimental.checkify.checkify(step_fn, ..), ...)(state, batch, rng)
print(errs)  # shows all four errors (one error from each device)
errs.throw()  # shows only one error (from one device) -- not what is expected

@jakevdp
Collaborator

jakevdp commented May 15, 2024

Thanks for the report! This code is a few years old now and the author is no longer working on the JAX project. I took a look and found that the extra errors are removed in the _reduce_any_error function:

jax/jax/_src/checkify.py

Lines 445 to 457 in a820387

def _reduce_any_error(error: Error):
  out_error = init_error
  for error_effect in error._pred.keys():
    errs, codes, payloads = (error._pred[error_effect],
                             error._code[error_effect],
                             error._payload[error_effect])
    reduced_idx = jnp.argsort(errs)[-1]
    pred, code, payload = tree_map(lambda x, idx=reduced_idx: x[idx],
                                   (errs, codes, payloads))
    out_error = out_error._update(error_effect, pred, code, {}, payload)
  out_error = out_error._replace(_metadata=error._metadata)
  return out_error

If I turn this function into an identity:

def _reduce_any_error(error: Error):
  return error

then err.throw() shows all the errors instead of just the first.

I'm not sure of the other implications of that, but if someone wants to explore this further, then understanding the intent of that helper function is probably where to start.
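To make the effect of the argsort-based selection in the quoted helper concrete, here is a minimal pure-Python model of it. This is a sketch, not the JAX implementation: the function name `reduce_any_error` and the plain lists are illustrative stand-ins, and it assumes the sort is stable (as Python's `sorted` is). The messages are copied from the report at the top of this issue.

```python
def reduce_any_error(preds):
    """Return the one mapped index that a last-of-argsort reduction keeps.

    Simplified model only: `preds` holds one failure flag per device, and
    the sort is assumed to be stable (an assumption, not a JAX guarantee
    verified here).
    """
    # Equivalent of argsort: indices ordered by their flag, False before True.
    order = sorted(range(len(preds)), key=lambda i: preds[i])
    return order[-1]


# One message per mapped index, taken from the report in this issue.
messages = [
    "before: pmean_input_ok failed step @12290",
    "after:neg_delta_params has NaN @12290",
    "after:neg_delta_params has NaN @12290",
    "after:neg_delta_params has NaN @12290",
]
preds = [True, True, True, True]  # every device failed a check

kept = reduce_any_error(preds)
dropped = [i for i in range(len(preds)) if i != kept]
print(f"kept index {kept}: {messages[kept]}")
print(f"dropped indices: {dropped}")
```

Under this model, whenever more than one device fails, exactly one index survives the reduction and the rest are discarded without any record that they existed.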

@mattjj
Member

mattjj commented May 16, 2024

Thanks, Jake! @sharadmv and I worked closely with Lena on checkify, so I think we can debug this.

I think this was the intended behavior at one point; indeed that's what _reduce_any_error does! But we can re-evaluate it.

@billmark for prioritization purposes: is this blocking your work in some way? Or just a preference you wanted to surface?

@billmark
Author

It is not blocking my work any more, but I think it is critical to address this issue for other users, either by a code change or a documentation change.

In the meantime, to help others...

The idiom implicitly recommended by the jax documentation is the following:

errs, returns = jax.pmap(jax.experimental.checkify.checkify(step_fn, ..), ...)(state, batch, rng)
errs.throw()

That idiom doesn't work properly -- it ignores errors from all but one device. Instead, I use a variant of the following:

errs, returns = jax.pmap(jax.experimental.checkify.checkify(step_fn, ..), ...)(state, batch, rng)
if errs.get_exception() is not None:
  print(errs)
  raise RuntimeError("Checkify caught an exception; see output above for details")
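The same workaround can be wrapped in a small helper that puts the full report into the exception message instead of printing it separately. This is only a sketch: the name `raise_with_all_errors` is hypothetical, and it assumes nothing beyond what the snippet above already relies on, namely that `errs.get_exception()` returns None when no check failed and that `str(errs)` lists the error at each mapped index.

```python
def raise_with_all_errors(errs):
    """Raise one exception whose message carries every per-device error.

    Hypothetical helper. It uses only `errs.get_exception()` and
    `str(errs)`, the two operations used in the workaround above.
    """
    if errs.get_exception() is None:
        return  # no check failed on any device
    # str(errs) is assumed to contain one entry per mapped index.
    raise RuntimeError(f"Checkify caught errors on one or more devices:\n{errs}")
```

Calling `raise_with_all_errors(errs)` in place of `errs.throw()` keeps every per-device message attached to the traceback, which helps when logs and exceptions are collected separately.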

@jakevdp
Collaborator

jakevdp commented May 16, 2024

For what it's worth, I think the current behavior is defensible: e.g. if you have 64 shards that all error, it's not terrible to only see one copy of the error in the traceback.

@billmark
Author

I respectfully disagree. The current behavior is terrible, and just caused me to waste an enormous amount of time.

When trying to track down a NaN with checkify's "float" check, one device typically has the "original" error, but all devices have a NaN error, since the NaNs later propagate via collective operations. I was seeing the error from the collective operation without realizing that there was a "hidden" error from another device that was the original cause of the NaN.

At a bare minimum, the exception needs to state that it is reporting e.g. only one out of 64 errors and that the other errors may be different. The checkify documentation here would also need to discuss this case. I'll call this solution A.

The best solution ("solution B") would be to de-dup the errors, so that the message would say something like: 64 errors on 64 devices, of two different types. Error #1 (devices 0, 2, 3): XXX. Error #2 (device 1): YYY.

Solution "C" is to just dump all 64 of the errors.

Solution "D", currently implemented, is to arbitrarily choose one of the errors to display, without any indication that others are being suppressed. This solution is terrible, particularly since it is not documented.

I would rank the solutions as follows:
Best: Solution B (but it may be too complicated to implement)
Next best: Solution C
Next best: Solution A
Worst: Solution D, the current one.
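As a rough illustration of what Solution B might produce, the sketch below groups identical messages by parsing the printed form of the error shown at the top of this issue. It is user-side post-processing, not a proposed change to checkify: the names `group_errors` and `summarize` are hypothetical, and it assumes each message occupies a single line of the form "at mapped index N: message", optionally wrapped in "Error(...)".

```python
import re
from collections import defaultdict

# Assumed line format, taken from the printed report in this issue.
_LINE = re.compile(r"at mapped index (\d+): (.*)")


def group_errors(report):
    """Map each distinct message to the list of mapped indices reporting it."""
    text = report.strip()
    if text.startswith("Error(") and text.endswith(")"):
        text = text[len("Error("):-1]  # drop the surrounding Error(...)
    groups = defaultdict(list)
    for line in text.splitlines():
        match = _LINE.fullmatch(line.strip())
        if match:
            groups[match.group(2)].append(int(match.group(1)))
    return dict(groups)


def summarize(report):
    """Render a de-duplicated summary, one line per distinct message."""
    groups = group_errors(report)
    total = sum(len(indices) for indices in groups.values())
    lines = [f"{total} errors, of {len(groups)} different types."]
    for n, (message, indices) in enumerate(groups.items(), start=1):
        where = ", ".join(str(i) for i in indices)
        lines.append(f"Error #{n} (mapped indices {where}): {message}")
    return "\n".join(lines)


REPORT = """Error(at mapped index 0: before: pmean_input_ok failed step @12290 (`check` failed)
at mapped index 1: after:neg_delta_params has NaN @12290 (`check` failed)
at mapped index 2: after:neg_delta_params has NaN @12290 (`check` failed)
at mapped index 3: after:neg_delta_params has NaN @12290 (`check` failed))"""

print(summarize(REPORT))
```

Parsing a printed report is fragile and would break on multi-line messages; a built-in version would presumably group on the error codes and payloads directly.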

@billmark
Author

To clarify my example above: All devices had a NaN error, but there were two different source-code locations for the error. This is critical information!
