2D Wavelet Transformation in PyTorch

Oct. 29, 2017

The other day I was asked how to do a wavelet transformation in PyTorch in a way that allows gradients to be computed (that is, gradients of the outputs w.r.t. the inputs, probably not the coefficients). I like PyTorch and I happen to have a certain fancy for wavelets as well, so here we go. This post is also available as a Jupyter Notebook.

We take an image of the Zuse Z4.

(Image Credit: Clemens Pfeiffer, CC-BY 2.5 at Wikimedia)

We will make use of PyTorch (of course) and the excellent PyWavelets (aka pywt) module. The latter includes a large library of wavelet coefficients and also functions to perform wavelet transforms. We only use the former, i.e. the coefficients. So let's import stuff.

import pywt
from matplotlib import pyplot
%matplotlib inline
import numpy
from PIL import Image
import urllib.request
import io
import torch
from torch.autograd import Variable

URL = 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/bc/Zuse-Z4-Totale_deutsches-museum.jpg/315px-Zuse-Z4-Totale_deutsches-museum.jpg'

Let us see what wavelet families are available:

pywt.families()

['haar', 'db', 'sym', 'coif', 'bior', 'rbio', 'dmey', 'gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor']

For this demo we will use the biorthogonal 2.2 wavelets (bior2.2 in pywt). As we will not properly deal with boundaries, they are a compromise between the (almost trivial) Haar wavelet and more elaborate but larger wavelets. When adapting this code to other wavelets, you will need to adjust the padding to the filter length. To use multiple channels (color images) you would want to .view the channels into the batch dimension.
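To make the remark about padding concrete, here is a small sketch of the size arithmetic for a stride-2 convolution. The helper names are made up for this illustration, and it assumes an even filter length and an even image size, as is the case for the six-tap bior2.2 filters from pywt and the 256 x 256 image used in this post.

```python
def conv_output_size(n, filter_len, pad, stride=2):
    """Output length of a strided convolution with `pad` zeros on each side."""
    return (n + 2 * pad - filter_len) // stride + 1


def padding_for_half_size(filter_len):
    """Padding per side so that a stride-2 convolution maps n to n // 2.

    Solves (n + 2 * pad - filter_len) / 2 + 1 == n / 2 for pad.
    Hypothetical helper; only meaningful for even filter lengths.
    """
    if filter_len % 2 != 0:
        raise ValueError("this sketch only covers even filter lengths")
    return (filter_len - 2) // 2


for filter_len in (2, 4, 6, 8):
    pad = padding_for_half_size(filter_len)
    print(filter_len, pad, conv_output_size(256, filter_len, pad))
```

For a filter of length 6 this gives a padding of 2 per side, and for the Haar wavelet (length 2) no padding at all.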

The basic idea is that for each coordinate direction you apply a high pass and a low pass filter with a stride of 2, and one can take PyTorch convolutions for this purpose. With orthogonal wavelets, you would feed the same wavelet into the transposed convolution for decoding; with biorthogonal wavelets, you have a separate set of coefficients for decoding. I adjust the filters so that they are ready for use with PyTorch convolutions, i.e. with correlations of the form $Y_{k,l} = \sum_{i,j} \psi_{ij} X_{i+k,j+l}$.
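The point about correlations can be checked in one dimension. The following sketch uses a made-up asymmetric filter (not a wavelet) and compares PyTorch's conv1d with NumPy's correlate and convolve; it is only meant to illustrate why a filter needs to be reversed if a true convolution is wanted.

```python
import numpy
import torch
import torch.nn.functional as F

x = numpy.array([1., 2., 4., 8., 16., 32.], dtype=numpy.float32)
psi = numpy.array([1., 2., 3.], dtype=numpy.float32)  # made-up asymmetric filter

tx = torch.from_numpy(x)[None, None]      # batch x channels x length
tpsi = torch.from_numpy(psi)[None, None]  # out channels x in channels x length

y = F.conv1d(tx, tpsi)[0, 0].numpy()
y_stride2 = F.conv1d(tx, tpsi, stride=2)[0, 0].numpy()

corr = numpy.correlate(x, psi, mode='valid')               # sum_i psi_i x_{i+k}
conv = numpy.convolve(x, psi, mode='valid')                # true convolution
conv_flipped = numpy.convolve(x, psi[::-1], mode='valid')  # convolution with the reversed filter

print(y)             # equals corr and conv_flipped
print(conv)          # differs from y for an asymmetric filter
print(y_stride2)     # every second entry of y
```

So conv1d (and likewise conv2d) computes a correlation, and a stride of 2 simply keeps every second output.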

We take a look at the wavelets.

w = pywt.Wavelet('bior2.2')
pyplot.plot(w.dec_hi[::-1], label="dec hi")
pyplot.plot(w.dec_lo[::-1], label="dec lo")
pyplot.plot(w.rec_hi, label="rec hi")
pyplot.plot(w.rec_lo, label="rec lo")
pyplot.title("Bior 2.2 Wavelets")
pyplot.legend()
# the decomposition filters are reversed because conv2d computes a correlation
dec_hi = torch.Tensor(w.dec_hi[::-1])
dec_lo = torch.Tensor(w.dec_lo[::-1])
rec_hi = torch.Tensor(w.rec_hi)
rec_lo = torch.Tensor(w.rec_lo)
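Besides plotting, a few properties of the filters can be inspected numerically. This is a small sketch using attributes of the pywt Wavelet object; the sum checks are the usual sanity checks for a low pass and a high pass filter, not something specific to this post.

```python
import math
import pywt

w = pywt.Wavelet('bior2.2')

# filter lengths for decomposition and reconstruction
print(w.dec_len, w.rec_len)

# bior2.2 is biorthogonal but not orthogonal, hence the separate rec_* filters
print(w.orthogonal, w.biorthogonal)

# the low pass sums to sqrt(2), the high pass sums to 0 (it kills constants)
lo_sum = sum(w.dec_lo)
hi_sum = sum(w.dec_hi)
print(lo_sum, math.sqrt(2), hi_sum)
```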


Let us load the picture and convert it to grayscale:

imgraw = Image.open(io.BytesIO(urllib.request.urlopen(URL).read())).resize((256,256))
img = numpy.array(imgraw).mean(2)/255
img = torch.from_numpy(img).float()
pyplot.imshow(img, cmap=pyplot.cm.gray)


We define the tensor product filter banks, i.e. we multiply filters for the two coordinates.

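# four tensor-product filters each: low/low first, then the three with a high pass
filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)

inv_filters = torch.stack([rec_lo.unsqueeze(0)*rec_lo.unsqueeze(1),
                           rec_lo.unsqueeze(0)*rec_hi.unsqueeze(1),
                           rec_hi.unsqueeze(0)*rec_lo.unsqueeze(1),
                           rec_hi.unsqueeze(0)*rec_hi.unsqueeze(1)], dim=0)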
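To see what multiplying the filters for the two coordinates buys us, the following sketch checks that a single conv2d with such a product kernel gives the same result as filtering along the width and then along the height. The two filters are made up for the illustration and are not wavelet filters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

a = torch.tensor([1., 2., 3.])   # made-up filter, applied along the width
b = torch.tensor([1., 0., -1.])  # made-up filter, applied along the height

# same construction as in the filter bank: kernel[i, j] = b[i] * a[j]
kernel = a.unsqueeze(0) * b.unsqueeze(1)

x = torch.randn(1, 1, 8, 8)

# one 2d correlation with the product kernel ...
y_2d = F.conv2d(x, kernel[None, None])

# ... equals a 1d correlation along each row followed by one along each column
y_rows = F.conv2d(x, a.view(1, 1, 1, 3))
y_sep = F.conv2d(y_rows, b.view(1, 1, 3, 1))

max_diff = (y_2d - y_sep).abs().max().item()
print(y_2d.shape, max_diff)
```

In the product kernel, the factor that went through unsqueeze(1) acts along the height and the one that went through unsqueeze(0) acts along the width.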

We can now define the wavelet transform and its inverse using PyTorch's conv2d and conv_transpose2d. For the recursion, we only continue to process the top left component with two low passes. This is different from taking the tensor product of the full 1d wavelet basis in that we do not further refine one low-pass coordinate when the other coordinate has taken the high-pass.

This seems to be the convention used in e.g. JPEG 2000 compression. I seem to remember that there are also stability reasons to keep the aspect ratio bounded for multi-resolution analysis, but I do not have a good reference to point to.

On the component with two low passes, one then applies another transform up to a desired level. We rearrange the four filter outputs into an image of the same size as the input. To allow for perfect reconstruction, we would need to deal with the boundaries - either by adapting the wavelets or by padding more - but we do not do this.

As PyTorch's nn module does this by default, we consider batch x channels x height x width.
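For color images, the earlier remark about moving the channels into the batch dimension could look roughly like the following sketch. The helper apply_per_channel and the stand-in function halve are made up for this illustration; any function that maps a batch x 1 x height x width tensor to a tensor of the same layout could take the place of halve.

```python
import torch
import torch.nn.functional as F


def apply_per_channel(fn, x):
    """Apply a single-channel function to every channel of x independently."""
    b, c, h, w = x.shape
    # fold the channels into the batch dimension: (b, c, h, w) -> (b * c, 1, h, w)
    y = fn(x.reshape(b * c, 1, h, w))
    # unfold again, allowing fn to change the spatial size
    return y.reshape(b, c, y.size(2), y.size(3))


def halve(t):
    """Stand-in for a single-channel transform (2x2 average pooling)."""
    return F.avg_pool2d(t, 2)


rgb = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
out = apply_per_channel(halve, rgb)
print(out.shape)
```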

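def wt(vimg, levels=1):
    h = vimg.size(2)
    w = vimg.size(3)
    # pad by 2 on each side so that the stride-2 convolution halves the size
    padded = torch.nn.functional.pad(vimg,(2,2,2,2))
    # one output channel per tensor-product filter, low/low first
    res = torch.nn.functional.conv2d(padded, Variable(filters[:,None]),stride=2)
    if levels>1:
        # recurse on the low/low component only
        res[:,:1] = wt(res[:,:1],levels-1)
    # arrange the four components as the four quadrants of an h x w image
    res = res.view(-1,2,h//2,w//2).transpose(1,2).contiguous().view(-1,1,h,w)
    return res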
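The view/transpose line at the end of wt is a bit dense. As a small illustration, the following sketch applies the same rearrangement to a toy tensor in which channel k is filled with the value k, so one can read off which component ends up in which quadrant.

```python
import torch

h, w = 4, 4

# stand-in for the conv2d output: 4 channels of size h/2 x w/2, channel k filled with k
res = torch.zeros(1, 4, h // 2, w // 2)
for k in range(4):
    res[:, k] = k

# same rearrangement as in wt
tiled = res.view(-1, 2, h // 2, w // 2).transpose(1, 2).contiguous().view(-1, 1, h, w)
print(tiled[0, 0])
# tensor([[0., 0., 1., 1.],
#         [0., 0., 1., 1.],
#         [2., 2., 3., 3.],
#         [2., 2., 3., 3.]])
```

Channel 0 (two low passes) lands in the top left quadrant, which is the part that the recursion continues to process.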

Similarly, we do the reconstruction (Inverse Wavelet Transform) using the conv_transpose2d function. We drop the excess coefficients.

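def iwt(vres, levels=1):
    h = vres.size(2)
    w = vres.size(3)
    # undo the quadrant layout: back to one channel per component
    res = vres.view(-1,h//2,2,w//2).transpose(1,2).contiguous().view(-1,4,h//2,w//2).clone()
    if levels>1:
        # first reconstruct the low/low component from the coarser levels
        res[:,:1] = iwt(res[:,:1], levels=levels-1)
    res = torch.nn.functional.conv_transpose2d(res, Variable(inv_filters[:,None]),stride=2)
    # drop the excess coefficients that correspond to the padding in wt
    res = res[:,:,2:-2,2:-2]
    return res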
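To see the conv2d / conv_transpose2d pairing in isolation, here is a self-contained sketch of a single level with the Haar wavelet. I hard-code the Haar filters here instead of bior2.2 because they are orthogonal and have length 2: the same filters serve for decoding, no padding is needed and there are no boundary effects, so the reconstruction is exact up to rounding. The sketch also backpropagates through the transform, which was the original motivation.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Haar filters (hard-coded for this sketch, not taken from pywt)
s = 1 / math.sqrt(2)
lo = torch.tensor([s, s])
hi = torch.tensor([s, -s])

# tensor-product filter bank, shape (4, 1, 2, 2), low/low first
haar = torch.stack([lo.unsqueeze(0) * lo.unsqueeze(1),
                    lo.unsqueeze(0) * hi.unsqueeze(1),
                    hi.unsqueeze(0) * lo.unsqueeze(1),
                    hi.unsqueeze(0) * hi.unsqueeze(1)], dim=0)[:, None]

x = torch.randn(1, 1, 8, 8, requires_grad=True)

# decomposition: four components at half the resolution
coeffs = F.conv2d(x, haar, stride=2)

# reconstruction: for an orthogonal wavelet the transposed convolution inverts it
rec = F.conv_transpose2d(coeffs, haar, stride=2)
rec_error = (rec - x).abs().max().item()

# gradients flow through the transform; as the Haar transform preserves the
# sum of squares, the gradient of the coefficient energy w.r.t. x is 2 * x
energy = coeffs.pow(2).sum()
energy.backward()
grad_error = (x.grad - 2 * x.detach()).abs().max().item()

print(coeffs.shape, rec_error, grad_error)
```

With the biorthogonal filters of this post, the same structure applies, except that decoding uses the separate rec_* filters and the boundaries introduce errors.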

We can do this on our image. First the decomposition:

vimg = Variable(img[None,None])
res = wt(vimg,4)
pyplot.imshow(res.data[0,0].numpy(), cmap=pyplot.cm.gray)


And then the reconstruction.

rec = iwt(res, levels=4)
pyplot.imshow(rec.data[0,0].numpy(), cmap=pyplot.cm.gray)


We can see where the reconstruction errors are:

pyplot.imshow((rec-vimg).data[0,0].numpy(), cmap=pyplot.cm.gray)
pyplot.colorbar()


This concludes our little tour on how to use wavelets with PyTorch.

I recommend also looking at scattering networks for an application of a fixed MRA (multi-resolution analysis) in deep learning.