When using fmmax to simulate a system with a vector RCWA formulation, `jax.numpy.linalg.solve` produces an array of NaNs on the GPU backend but not on the CPU backend. That is why I suspect the error lies with JAX. This is the error message produced with the NaN debugging flag enabled:
FloatingPointError: invalid value (nan) encountered in jit(triangular_solve). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
In my use case the error occurred here. The input arrays had shapes `(490, 490)` and `(490,)`, and the dtype was `complex64`.
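Here is a minimal sketch for narrowing this down. It uses random stand-in inputs of the reported shape and dtype, because the actual fmmax matrices are not included here. `random_system` is hypothetical and should be replaced by the real `A` and `b`.

The sketch runs `jnp.linalg.solve` on each available backend. The error above is reported in `triangular_solve`, which suggests the solve is LU-based. On that assumption, the sketch also splits the solve into `jax.scipy.linalg.lu_factor` and `jax.scipy.linalg.lu_solve`. This shows whether the NaNs first appear in the LU factors or in the triangular solves. `jax_debug_nans` is deliberately left off so that both stages can be inspected instead of stopping at the first NaN.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def random_system(n=490, seed=0):
    """Hypothetical stand-in for the fmmax inputs: complex64 A (n, n) and b (n,)."""
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(seed), 4)
    a = jax.random.normal(k1, (n, n)) + 1j * jax.random.normal(k2, (n, n))
    b = jax.random.normal(k3, (n,)) + 1j * jax.random.normal(k4, (n,))
    # Shift the diagonal so the random stand-in matrix is well conditioned.
    a = a + n * jnp.eye(n)
    return a.astype(jnp.complex64), b.astype(jnp.complex64)


@jax.jit
def staged_solve(a, b):
    """Solve A x = b in two visible stages: LU factorization, then triangular solves."""
    lu, piv = jsl.lu_factor(a)
    x = jsl.lu_solve((lu, piv), b)
    return lu, x


def check_on(device, a, b):
    """Run the direct and the staged solve with both operands committed to `device`."""
    a_d, b_d = jax.device_put(a, device), jax.device_put(b, device)
    x_direct = jax.jit(jnp.linalg.solve)(a_d, b_d)
    lu, x_staged = staged_solve(a_d, b_d)
    return {
        "platform": device.platform,
        "nan_in_solve": bool(jnp.isnan(x_direct).any()),
        "nan_in_lu": bool(jnp.isnan(lu).any()),
        "nan_in_lu_solve": bool(jnp.isnan(x_staged).any()),
        "x": x_direct,
    }


a, b = random_system()
devices = jax.devices("cpu")[:1]
try:
    devices += jax.devices("gpu")[:1]
except RuntimeError:
    pass  # No GPU backend available; only the CPU path is checked.

results = [check_on(d, a, b) for d in devices]
for r in results:
    print(r["platform"], "solve:", r["nan_in_solve"],
          "lu:", r["nan_in_lu"], "lu_solve:", r["nan_in_lu_solve"])
```

On a machine with a GPU, a `True` in the GPU row but not the CPU row would match the behaviour described here. Comparing the `lu` and `lu_solve` columns indicates which stage introduces the NaNs.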
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.27
jaxlib: 0.4.27
numpy: 1.26.4
python: 3.11.5 (main, Oct 25 2023, 16:19:59) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='4.18.0-513.24.1.el8_9.x86_64', version='#1 SMP Thu Apr 4 18:13:02 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed May 8 00:34:18 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 On | 00000000:9D:00.0 Off | 0 |
| N/A 36C P0 82W / 700W | 534MiB / 95830MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 157152 C python 524MiB |
+-----------------------------------------------------------------------------------------+
This may be related to #20047 or lineax/#79.