convolve2D not supporting "boundary=wrap" #7276
Hi - thanks for the report! This is a known issue, though it is not tracked anywhere (aside from the explicit NotImplementedError). Regarding the documentation, the reason it's mentioned there is that the docstring is copied verbatim from scipy.signal.convolve2d. I'm going to change this from … Thanks!
One more thing: the reason this has not yet been implemented is that convolutions are computed via XLA's ConvWithGeneralPadding, and I'm not certain whether it can compute the equivalent of scipy's wrapped convolutions.
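One way to sidestep the question of what ConvWithGeneralPadding can express: a wrapped ("periodic") convolution is exactly a circular convolution, which the discrete Fourier transform turns into an elementwise product. The sketch below swaps in that FFT-based technique instead of an XLA convolution. The name convolve2d_wrap_fft is hypothetical, it only covers mode='same', and it assumes the kernel is no larger than the image along each axis and that the 'same' output is centred with an offset of (k - 1) // 2 per axis (worth cross-checking against scipy.signal.convolve2d, especially for even-sized kernels).

```python
import jax.numpy as jnp


def convolve2d_wrap_fft(image, kernel):
    """Hypothetical helper (not a JAX API): periodic 2-D convolution,
    mode='same', computed with FFTs.

    Assumes kernel.shape <= image.shape along each axis and a 'same'
    centring offset of (k - 1) // 2 per axis.
    """
    H, W = image.shape
    kh, kw = kernel.shape
    # fft2(..., s=(H, W)) zero-pads the kernel to the image size, so the
    # spectral product is the circular convolution
    #   circ[p, q] = sum_{m, n} k[m, n] * a[(p - m) % H, (q - n) % W].
    spectrum = jnp.fft.fft2(image) * jnp.fft.fft2(kernel, s=(H, W))
    circ = jnp.fft.ifft2(spectrum).real
    # 'same' output: shift so that out[i, j] = circ[(i + cr) % H, (j + cc) % W].
    cr, cc = (kh - 1) // 2, (kw - 1) // 2
    return jnp.roll(circ, shift=(-cr, -cc), axis=(0, 1))
```

For large kernels this route can be cheaper than a direct convolution, but it works in complex arithmetic, so expect float32 round-off rather than bit-exact agreement with a direct convolution.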
The best way to implement this is probably to form the padding values explicitly as part of the input to the convolution. Note that …
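A minimal sketch of that suggestion at the lax level, assuming mode='same' and a centring offset of (k - 1) // 2 per axis: the input is padded with jnp.pad(..., mode="wrap") so the periodic neighbours are real data, and jax.lax.conv_general_dilated is then called with padding="VALID", so XLA never inserts fill values of its own. convolve2d_wrap_lax is a hypothetical name, not part of JAX.

```python
import jax
import jax.numpy as jnp


def convolve2d_wrap_lax(image, kernel):
    """Hypothetical helper (not a JAX API): mode='same' periodic convolution
    built from a wrap pad followed by an unpadded XLA convolution.

    Assumes a 'same' centring offset of (k - 1) // 2 per axis.
    """
    # lax ops do not promote dtypes, so give both operands one float dtype.
    dtype = jnp.promote_types(jnp.result_type(image, kernel), jnp.float32)
    image = jnp.asarray(image, dtype)
    kernel = jnp.asarray(kernel, dtype)
    kh, kw = kernel.shape
    cr, cc = (kh - 1) // 2, (kw - 1) // 2
    # Copy neighbours from the opposite edges: k - 1 - c before, c after.
    padded = jnp.pad(image, ((kh - 1 - cr, cr), (kw - 1 - cc, cc)), mode="wrap")
    # lax convolutions are cross-correlations, so flip the kernel.
    # Default layouts are NCHW (input) and OIHW (kernel), hence [None, None].
    out = jax.lax.conv_general_dilated(
        padded[None, None],
        kernel[::-1, ::-1][None, None],
        window_strides=(1, 1),
        padding="VALID",  # no fill values: every input pixel was supplied above
        precision=jax.lax.Precision.HIGHEST,
    )
    return out[0, 0]
```

The padded array is exactly (k - 1) larger than the image along each axis, so the "VALID" convolution returns an array with the image's original shape.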
@hawkinsp, great call. Here is a working example:
```python
from jax import random
import jax.numpy as jnp
from jax.scipy.signal import convolve2d as convolve_jax
from scipy.signal import convolve2d as convolve_scipy

key = random.PRNGKey(0)


def convolve_wrap(a1, a2):
    # Pad the image periodically by its own size on every side, convolve
    # with the default zero fill, then crop back to the original region.
    # Assumes the filter is no larger than the image.
    H, W = a1.shape
    a1 = jnp.pad(a1, ((H, H), (W, W)), mode='wrap')
    return convolve_jax(a1, a2, mode='same')[H:2 * H, W:2 * W]


# Filter: N x N box (moving-average) kernel
N = 5
a2 = jnp.ones((N, N)) / N / N

# Image
N = 10
a1 = random.normal(key, (N, N))

# Scipy
a_scipy = convolve_scipy(a1, a2, mode='same', boundary='wrap')

# Jax
a_jax = convolve_wrap(a1, a2)

print(jnp.allclose(a_scipy, a_jax, atol=1e-6))
```
@romanodev Awesome! Would you be interested in contributing a PR that adds support for this (see Line 73 in 9450b8f)?
@hawkinsp, I can definitely take a look at that. While the above case is for mode='same', I guess it must be implemented for all modes, as required in https://github.com/google/jax/blob/main/tests/scipy_signal_test.py#L50
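Extending the padding idea to every mode could look roughly like the sketch below. convolve2d_wrap is a hypothetical helper, not part of jax.scipy.signal. For 'full' it borrows k - 1 rows/columns from the opposite edge on both sides; for 'same' it borrows only what the centred window needs (again assuming a (k - 1) // 2 offset per axis); 'valid' needs no padding because no valid output touches the boundary. In every case the padded array goes to jax.scipy.signal.convolve2d with mode='valid', so the library's own zero fill never comes into play.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def convolve2d_wrap(in1, in2, mode="full"):
    """Hypothetical helper (not a JAX API) emulating boundary='wrap'.

    in1 is padded periodically and then passed to convolve2d with
    mode='valid', so the library's zero fill is never used. The 'same'
    branch assumes a centring offset of (k - 1) // 2 per axis.
    """
    kh, kw = in2.shape
    if mode == "full":
        # Every output pixel may reach k - 1 samples past either edge.
        pads = ((kh - 1, kh - 1), (kw - 1, kw - 1))
    elif mode == "same":
        cr, cc = (kh - 1) // 2, (kw - 1) // 2
        pads = ((kh - 1 - cr, cr), (kw - 1 - cc, cc))
    elif mode == "valid":
        # Valid outputs never touch the boundary, so no padding is needed.
        pads = ((0, 0), (0, 0))
    else:
        raise ValueError(f"mode must be 'full', 'same' or 'valid', got {mode!r}")
    return convolve2d(jnp.pad(in1, pads, mode="wrap"), in2, mode="valid")
```

For instance, with a 10x10 image and a 5x5 kernel, mode='full' returns a 14x14 array, 'same' 10x10 and 'valid' 6x6, the same output shapes scipy.signal.convolve2d uses for those modes.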
Hi @romanodev, don't worry about implementing everything at once if it's blocking you. Even just implementing …
@jakevdp, great! I will resume this next week.
Hi @romanodev, a fix for this issue is included in pull request #21241.
@rajasekharporeddy, awesome!
Hi JAX developers,
I am trying to filter periodic images with jax.scipy.signal.convolve2d, but it seems that boundary='wrap' is not supported, even though it is mentioned in the documentation.
Here is an MVE:
Here is the error:
```
raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
NotImplementedError: convolve2d() only supports boundary='fill', fillvalue=0
```
Thanks!