Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add boundary='wrap' support for jax.scipy.signal.convolve2d and jax.scipy.signal.correlate2d #21241

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

rajasekharporeddy
Copy link
Contributor

Currently jax.scipy.signal.convolve2d and jax.scipy.signal.correlate2d functions support only boundary='fill'. This PR will add boundary='wrap' support for both the functions.

Fixes #7276

jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
@jakevdp jakevdp self-assigned this May 15, 2024
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
jax/_src/scipy/signal.py Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented May 16, 2024

Looks great! The last thing we need is to squash the changes into a single commit (see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests). Let me know if you need help with that process

@rajasekharporeddy
Copy link
Contributor Author

Squashed all the commits into a single one.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

@jakevdp
Copy link
Collaborator

jakevdp commented May 16, 2024

We're seeing errors when running a larger number of test cases. You can reproduce them locally using e.g.

JAX_NUM_GENERATED_CASES=50 pytest -n auto tests/scipy_signal_test.py -k testConvolutions

It looks like it's an issue with boundary='wrap' combined with mode='full'.

@jakevdp
Copy link
Collaborator

jakevdp commented May 16, 2024

Example failure:

testConvolutions2DNotValidMode7 (mode='full', boundary='wrap', op='convolve2d', dtype=<class 'numpy.float64'>, xshape=(3, 4), yshape=(2, 5)) (shard 21)content_copy
---
  File "jax/tests/scipy_signal_test.py", line 163, in testConvolutions2DNotValidMode
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
  File "jax/_src/public_test_util.py", line 127, in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
  File "numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-12, atol=1e-12

Mismatched elements: 4 / 32 (12.5%)
Max absolute difference: 11.06572735
Max relative difference: 2.87831292
 x: array([[ -7.221209, -14.248762, -25.984588, -41.801186,  -7.221209,
        -14.248762, -25.984588, -41.801186],
       [-47.816622,  -3.279326, -16.912216,  42.092961, -47.816622,...
 y: array([[  3.844519, -14.248762, -25.984588, -41.801186,  -7.221209,
        -14.248762, -25.984588, -41.801186],
       [-45.311671,  -3.279326, -16.912216,  42.092961, -47.816622,...

@rajasekharporeddy
Copy link
Contributor Author

It is fixed now. It is the problem with padding in1. It appears that the pad_width should have same value in all dimensions which is the maximum shape value among the input shapes.

@jakevdp
Copy link
Collaborator

jakevdp commented May 16, 2024

That's surprising to me – say you are convolving two (1000, 2) arrays. Padding it out to (2000, 1002) seems unnecessary, because we really only need 4 elements in the second dimension.

Can we fix this without blowing up the memory requirements in cases like this?

@rajasekharporeddy
Copy link
Contributor Author

I will check and let you know, if it is possible or not.

@rajasekharporeddy
Copy link
Contributor Author

I could not find any other way to fix the issue without padding it with maximum value in all dimensions.

@jakevdp
Copy link
Collaborator

jakevdp commented May 16, 2024

I would consider this a blocker to merging this PR: we need to figure out how to implement this without excessive padding.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

convolve2D not supporting "boundary=wrap"
3 participants