convolve2D not supporting "boundary=wrap" #7276

Open
romanodev opened this issue Jul 13, 2021 · 10 comments · May be fixed by #21241
Assignees
Labels
contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. enhancement New feature or request

Comments

@romanodev

romanodev commented Jul 13, 2021

Hi JAX developers,

I am trying to filter periodic images with jax.scipy.signal.convolve2d, but the flag boundary='wrap' does not seem to be supported, even though it is mentioned in the documentation.

Here is a minimal verifiable example (MVE):

from jax.scipy.signal import convolve2d
import jax.numpy as jnp

a1 = jnp.ones((3,3))

a = convolve2d(a1, a1, mode='same', boundary='wrap')

Here is the error:
raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
NotImplementedError: convolve2d() only supports boundary='fill', fillvalue=0

Thanks!

@romanodev romanodev added the bug Something isn't working label Jul 13, 2021
@jakevdp
Collaborator

jakevdp commented Jul 13, 2021

Hi - thanks for the report! This is a known issue, though it is not tracked anywhere (aside from the explicit NotImplementedError you found in the code).

Regarding the documentation: the keyword is mentioned there because the docstring is copied verbatim from scipy.signal.convolve2d, as is customary for JAX's wrapped NumPy/SciPy functions. In many cases, if you dig hard enough, you'll find unimplemented keywords like this one.

I'm going to change this from bug to enhancement, because it's something we know is unimplemented and we hope will be implemented by a team member or community member in the future.

Thanks!

@jakevdp jakevdp added contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. enhancement New feature or request and removed bug Something isn't working labels Jul 13, 2021
@jakevdp
Collaborator

jakevdp commented Jul 13, 2021

One more thing: the reason this has not yet been implemented is that convolutions are computed via XLA's ConvWithGeneralPadding, and I'm not certain whether it can compute the equivalent of scipy's wrapped convolutions.
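To illustrate the constraint, here is a minimal sketch (an illustration, not JAX's actual `convolve2d` source) using `jax.lax.conv_general_dilated`, JAX's wrapper around XLA's general convolution. Its `padding` argument takes explicit `(low, high)` pairs per spatial axis, but the padded cells are always zeros, which corresponds to `boundary='fill', fillvalue=0`. Padding by `K - 1` on each side and flipping the kernel (XLA convolutions are cross-correlations) reproduces `mode='full'`. The helper name `full_conv_via_lax` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax, random
from jax.scipy.signal import convolve2d


def full_conv_via_lax(x, k):
    """Hypothetical helper: 2D 'full' convolution via lax.conv_general_dilated.

    The (low, high) padding pairs are filled with zeros by XLA; there is no
    option to wrap around to the opposite edge instead.
    """
    kh, kw = k.shape
    lhs = x[None, None, :, :]            # add batch and channel dims (NCHW)
    rhs = jnp.flip(k)[None, None, :, :]  # flip: XLA computes cross-correlation (OIHW)
    out = lax.conv_general_dilated(
        lhs, rhs,
        window_strides=(1, 1),
        padding=[(kh - 1, kh - 1), (kw - 1, kw - 1)],  # implicit zero padding
    )
    return out[0, 0]


x = random.normal(random.PRNGKey(0), (6, 7))
k = random.normal(random.PRNGKey(1), (3, 4))
print(jnp.allclose(full_conv_via_lax(x, k), convolve2d(x, k, mode='full'), atol=1e-5))

# Zero fill made visible: an all-ones image convolved with an all-ones kernel.
ones_out = full_conv_via_lax(jnp.ones((4, 4)), jnp.ones((3, 3)))
print(ones_out[0, 0], ones_out[2, 2])  # corner vs. interior
```

The corner of the all-ones example overlaps a single image pixel (value 1.0) while interior positions overlap nine, which is the zero fill showing through; a wrapped boundary would give 9.0 everywhere.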

@hawkinsp
Member

The best way to implement this is probably to explicitly form the padding values as part of the input to the convolution. Note that jnp.pad supports wrapping padding modes, so perhaps the implementation is as simple as composing the two?
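A minimal sketch of that composition for `mode='same'`, assuming odd kernel dimensions: wrap-pad the image by half the kernel size with `jnp.pad(..., mode='wrap')`, then keep only the `'valid'` part of the convolution, so each output pixel sees a periodic neighborhood. `convolve2d_wrap_same` is a hypothetical name, and scipy's `boundary='wrap'` is used only as a reference.

```python
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.scipy.signal import convolve2d
from scipy.signal import convolve2d as convolve2d_scipy


def convolve2d_wrap_same(x, k):
    """Hypothetical helper: circular 'same' convolution for odd-sized kernels.

    Wrap-pads x by half the kernel size on every side, then keeps only the
    'valid' part of the convolution, which has the same shape as x.
    """
    kh, kw = k.shape
    if kh % 2 == 0 or kw % 2 == 0:
        raise ValueError("this sketch assumes odd kernel dimensions")
    rh, rw = kh // 2, kw // 2
    xp = jnp.pad(x, ((rh, rh), (rw, rw)), mode='wrap')
    return convolve2d(xp, k, mode='valid')


x = random.normal(random.PRNGKey(0), (8, 12))  # non-square image
k = random.normal(random.PRNGKey(1), (3, 5))   # odd-sized kernel
out = convolve2d_wrap_same(x, k)
ref = convolve2d_scipy(np.asarray(x), np.asarray(k), mode='same', boundary='wrap')
print(out.shape, np.allclose(out, ref, atol=1e-4))
```

Padding by half the kernel rather than by the whole image keeps the intermediate array small and works for non-square images.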

@romanodev
Author

romanodev commented Jul 13, 2021

@hawkinsp, great call.

Here is a working example:

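from jax.scipy.signal import convolve2d as convolve_jax
from scipy.signal import convolve2d as convolve_scipy
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)

def convolve_wrap(a1, a2):
    # Assumes a square image a1: tile it periodically by its own size N on
    # every side (3N x 3N), convolve, then crop the central N x N block.
    N = a1.shape[0]
    a1 = jnp.pad(a1, N, mode='wrap')

    return convolve_jax(a1, a2, mode='same')[N:2*N, N:2*N]

# Filter: 5 x 5 box (averaging) kernel
N = 5
a2 = jnp.ones((N, N)) / N / N

# Image: 10 x 10 random field
N = 10
a1 = random.normal(key, (N, N))

# Scipy reference
a_scipy = convolve_scipy(a1, a2, mode='same', boundary='wrap')

# Jax
a_jax = convolve_wrap(a1, a2)

# The tolerance keyword is atol (a_tol raises a TypeError)
print(jnp.allclose(a_scipy, a_jax, atol=1e-6))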

@hawkinsp
Member

@romanodev Awesome! Would you be interested in contributing a PR that adds support for boundary="wrap" to convolve2d? The implementation is here:

def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,

@romanodev
Author

@hawkinsp, I can definitely take a look at that. While the example above uses mode='same', I guess it must be implemented for all modes, as required by the tests in https://github.com/google/jax/blob/main/tests/scipy_signal_test.py#L50
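Covering every mode mostly comes down to choosing the wrap-padding widths, since a `'valid'` convolution of the padded array does the rest. The sketch below is a hypothetical helper pair (`wrap_pad_widths`, `convolve2d_wrap`), not taken from PR #21241. `'full'` pads by `K - 1` on both sides. `'same'` pads by `(K // 2, (K - 1) // 2)` so the output is centered like scipy's (offset `(K - 1) // 2` into the full result), which also covers even-sized kernels. `'valid'` needs no padding because it never reads outside the image. It assumes each kernel dimension is at most the matching image dimension.

```python
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.scipy.signal import convolve2d
from scipy.signal import convolve2d as convolve2d_scipy


def wrap_pad_widths(mode, kernel_shape):
    """Hypothetical helper: (low, high) wrap-padding per axis for each mode."""
    if mode == 'full':
        return [(k - 1, k - 1) for k in kernel_shape]
    if mode == 'same':
        # Centered like scipy's 'same': offset (K - 1) // 2 into 'full'.
        return [(k // 2, (k - 1) // 2) for k in kernel_shape]
    if mode == 'valid':
        # 'valid' never reads outside the image, so the boundary is irrelevant.
        return [(0, 0) for _ in kernel_shape]
    raise ValueError(f"unknown mode: {mode!r}")


def convolve2d_wrap(x, k, mode='full'):
    """Hypothetical helper: emulate boundary='wrap' by wrap-padding, then 'valid'."""
    xp = jnp.pad(x, wrap_pad_widths(mode, k.shape), mode='wrap')
    return convolve2d(xp, k, mode='valid')


x = random.normal(random.PRNGKey(0), (9, 11))
k = random.normal(random.PRNGKey(1), (4, 3))  # one even and one odd kernel size
for mode in ('full', 'same', 'valid'):
    ref = convolve2d_scipy(np.asarray(x), np.asarray(k), mode=mode, boundary='wrap')
    print(mode, np.allclose(convolve2d_wrap(x, k, mode), ref, atol=1e-4))
```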

@jakevdp
Collaborator

jakevdp commented Aug 2, 2021

Hi @romanodev - don't worry about implementing everything at once if it's blocking you. Even just implementing mode='full' would be useful, and you could adjust the tests to skip unimplemented combinations.
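One way to follow this suggestion in a test matrix is to treat `NotImplementedError` as "skipped" rather than as a failure. The sketch below is a hypothetical standalone check, not the structure of JAX's `scipy_signal_test.py`; the helper name `compare_with_scipy` is made up for illustration. It works whether or not the installed JAX version implements `boundary='wrap'`.

```python
import itertools

import numpy as np
from jax import random
from jax.scipy.signal import convolve2d as convolve2d_jax
from scipy.signal import convolve2d as convolve2d_scipy


def compare_with_scipy(mode, boundary, x, k):
    """Hypothetical check: 'skipped' if JAX does not implement the
    combination, otherwise whether JAX and scipy agree."""
    try:
        out = convolve2d_jax(x, k, mode=mode, boundary=boundary)
    except NotImplementedError:
        return 'skipped'
    ref = convolve2d_scipy(np.asarray(x), np.asarray(k), mode=mode, boundary=boundary)
    return bool(np.allclose(out, ref, atol=1e-4))


x = random.normal(random.PRNGKey(0), (6, 5))
k = random.normal(random.PRNGKey(1), (3, 3))
results = {
    (mode, boundary): compare_with_scipy(mode, boundary, x, k)
    for mode, boundary in itertools.product(('full', 'same', 'valid'), ('fill', 'wrap'))
}
print(results)
```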

@romanodev
Author

@jakevdp, great! I will resume this next week.

@rajasekharporeddy
Contributor

Hi @romanodev

A fix for this issue is included in pull request #21241.

@romanodev
Author

@rajasekharporeddy, awesome!

@jakevdp jakevdp self-assigned this May 16, 2024
4 participants