Computed Parameters in PyTorch - a hack

Sept. 17, 2020

Sometimes, we want our neural network's parameters to have constraints, e.g. to be positive. PyTorch notoriously doesn't provide any infrastructure for this. We present a gross hack to deliver a neat interface.

If you are anything like me, you like PyTorch and you enjoy an occasional clever hack. So here we go.

The other day, we joked online about which breaking changes might be worth forking PyTorch for. For all the ease with which we can define networks and keep their parameters in nn.Modules, there isn't an easy way to create constrained parameters that magically enforce their constraints.

While many networks (obviously) work without them, constrained parameters are not entirely niche, either. Relatively simple constraints, such as values being positive, often occur when parametrizing distributions, e.g. for Mixture Density Networks (as used in Graves's famous Handwriting Generation RNN) or for Gaussian Process modelling. But there are also more elaborate uses, e.g. spectral normalization to impose a Lipschitz constraint on convolutions and linear layers.

PyTorch supports spectral norm constraints, but the mechanism it uses seems very elaborate for what should be a very simple thing. We get to this below.

For your convenience, I put up my Jupyter Notebook.

Acknowledgement: This post is dedicated to my one true fan on GitHub sponsors.

Tensor Machine

import torch
import numpy
import inspect  # this should raise the "we'll do gross things with Python internals" flag

Spectral normalization and the PyTorch implementation

A while ago, people found that when training GANs, it was useful to constrain the regularity of the discriminator or, more precisely, to impose a Lipschitz constraint. This led to WGAN (which constrained via weight clipping), WGAN-GP (constrained via a penalty on the "sampled" Lipschitz constant), which I discuss in an old blog post, and eventually Spectral Normalization (T. Miyato et al.: Spectral Normalization for Generative Adversarial Networks, ICLR 2018).

Spectral normalization bounds the Lipschitz constant of linear or convolution operators when the domain and the image are equipped with the $l^2$ norm. The operator norm in this setting is the spectral norm, the largest singular value of the operator. Direct methods to compute the singular value decomposition are rather expensive, so Miyato et al. use the fact that you can get the largest singular value by power iteration: starting from some (hopefully not pathological) $u$ of the right dimension, we iterate $v = \frac{A^T u}{\|A^T u\|}$, $u = \frac{A v}{\|A v\|}$, so that $u$ and $v$ approach the first left and right singular vectors and $u^T A v \rightarrow \sigma(A)$. To bound our weight's spectral norm, we then divide the weight by $\sigma(A)$.
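
To make the iteration concrete, here is a small sketch in plain PyTorch (not part of the original code) that runs the two alternating updates on a random matrix and compares the estimate with the exact largest singular value from torch.linalg.svdvals. The matrix size, the float64 dtype and the iteration count are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
A = torch.randn(6, 4, dtype=torch.float64)

# Start from a random unit vector u with one entry per row of A.
u = torch.nn.functional.normalize(torch.randn(6, dtype=torch.float64), dim=0)
for _ in range(1000):
    # The two normalized power-iteration updates from the text.
    v = torch.nn.functional.normalize(A.t() @ u, dim=0)
    u = torch.nn.functional.normalize(A @ v, dim=0)

sigma_estimate = torch.dot(u, A @ v)        # u^T A v
sigma_exact = torch.linalg.svdvals(A)[0]    # largest singular value
print(sigma_estimate.item(), sigma_exact.item())

# Dividing by the estimate gives a matrix with spectral norm (close to) 1.
A_normalized = A / sigma_estimate
print(torch.linalg.matrix_norm(A_normalized, ord=2).item())
```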

The key observation of Miyato et al. is that we only need to do as little as one iteration per training step to keep the weights in check. And this is relatively straightforward to implement as a PyTorch module that takes no inputs and returns a spectrally normalized weight (I took the computation from PyTorch):

# Based on the original implementation from PyTorch
# So portions copyright by the PyTorch contributors, in particular Simon Wang worked on it a lot.
# Errors probably are my doing.

class SpectralNormWeight(torch.nn.Module):
    def __init__(self, shape, dim=0, eps=1e-12, n_power_iterations=1):
        super().__init__()  # required before registering parameters and buffers
        self.n_power_iterations = n_power_iterations
        self.eps = eps
        self.dim = dim
        self.shape = tuple(shape)
        # Move dimension `dim` to the front and flatten the rest into an (h, w) matrix.
        self.permuted_shape = (self.shape[dim],) + self.shape[:dim] + self.shape[dim+1:]
        h = self.shape[dim]
        w = int(numpy.prod(self.permuted_shape[1:]))  # int(): numpy.prod(()) is the float 1.0
        self.weight_mat = torch.nn.Parameter(torch.randn(h, w))
        # Running estimates of the first left and right singular vectors.
        self.register_buffer('u', torch.nn.functional.normalize(torch.randn(h), dim=0, eps=self.eps))
        self.register_buffer('v', torch.nn.functional.normalize(torch.randn(w), dim=0, eps=self.eps))

    def forward(self):
        u = self.u
        v = self.v
        if self.training:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`,
                    # writing them into the buffers in place (out=...).
                    v = torch.nn.functional.normalize(torch.mv(self.weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = torch.nn.functional.normalize(torch.mv(self.weight_mat, v), dim=0, eps=self.eps, out=u)
                # Clone after the loop so that the next in-place buffer update
                # does not modify tensors autograd saved for the backward of sigma.
                u = u.clone(memory_format=torch.contiguous_format)
                v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(self.weight_mat, v))
        weight = (self.weight_mat / sigma).view(self.permuted_shape)
        if self.dim != 0:
            # Undo the permutation: move the leading dimension back to position `dim`.
            weight = weight.movedim(0, self.dim)
        return weight
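
The claim that a single power iteration per training step suffices rests on u and v persisting between steps (here, as buffers) while the weight changes only a little from step to step. A small simulation sketch (my own illustration, not from Miyato et al.; the spectrum 3, 1, 0.5 and the drift direction D are arbitrary assumptions standing in for optimizer updates) keeps u across steps, does exactly one iteration per step, and compares the estimate with the exact value:

```python
import torch

torch.manual_seed(0)
# A weight with well-separated singular values (3, 1, 0.5) ...
W0 = torch.diag(torch.tensor([3.0, 1.0, 0.5], dtype=torch.float64))
# ... drifting slowly along a random direction D (hypothetical stand-in for training).
D = 0.01 * torch.randn(3, 3, dtype=torch.float64)

# Persistent estimate of the left singular vector, kept across steps like a buffer.
u = torch.nn.functional.normalize(torch.randn(3, dtype=torch.float64), dim=0)
errors = []
for step in range(200):
    W = W0 + step * D / 200
    # Exactly one power iteration per "training step", warm-started from the previous u.
    v = torch.nn.functional.normalize(W.t() @ u, dim=0)
    u = torch.nn.functional.normalize(W @ v, dim=0)
    sigma_estimate = torch.dot(u, W @ v)
    sigma_exact = torch.linalg.svdvals(W)[0]
    errors.append((abs(sigma_estimate - sigma_exact) / sigma_exact).item())

print(errors[0], max(errors[20:]))
```

After a few warm-up steps, the relative error stays tiny even though each step only does one iteration.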

But how do we apply this to our weight?

We can, of course, now do this:

w = SpectralNormWeight((out_feat, in_feat))
res = torch.nn.functional.linear(inp, w(), bias)  # call the module to get the weight tensor
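
The same "module that returns a weight" pattern also covers the simpler constraints from the introduction, such as positivity. As a hypothetical example (the name PositiveWeight and the softplus reparametrization are illustrative choices, not from the post or from PyTorch), a module that always returns a positive weight is used with torch.nn.functional.linear in exactly the same way:

```python
import torch

class PositiveWeight(torch.nn.Module):
    """Hypothetical example: an unconstrained raw parameter mapped through softplus."""
    def __init__(self, shape):
        super().__init__()
        self.raw = torch.nn.Parameter(torch.randn(shape))

    def forward(self):
        # softplus(x) = log(1 + exp(x)) is positive for every real x
        return torch.nn.functional.softplus(self.raw)

out_feat, in_feat = 4, 3
positive_weight = PositiveWeight((out_feat, in_feat))
bias = torch.zeros(out_feat)
inp = torch.randn(20, in_feat)

# Call the module to get the weight tensor, then use it like any other weight.
res = torch.nn.functional.linear(inp, positive_weight(), bias)
print(res.shape)  # torch.Size([20, 4])

# Gradients flow back to the unconstrained parameter.
res.sum().backward()
print(positive_weight.raw.grad.shape)  # torch.Size([4, 3])
```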

But wouldn't it be nice to be able to just do this:

l = torch.nn.Linear(3, 4)
l.weight = ComputedParameter(SpectralNormWeight(l.weight.shape))

Well, it never is that easy in real life. A Tensor is a Tensor and a Module is a Module. The PyTorch developers used hooks to implement spectral norm in a way that is convenient for the user. But using hooks that way can be brittle, because you need to deal with discreetly replacing the weight, making sure that backpropagation works, saving and loading models, multi-GPU, ... All in all, PyTorch's spectral norm implementation runs to 300 lines at the time of writing, almost 10x what we have above. Also, it isn't easily extended if we want to use a different constraint.
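
For comparison, this is what the hook-based version looks like from the outside. A short inspection sketch using torch.nn.utils.spectral_norm (the exact internals may differ between PyTorch versions): the trainable tensor moves to weight_orig, the power-iteration vectors become buffers, and weight turns into a plain tensor attribute that a forward pre-hook recomputes.

```python
import torch

layer = torch.nn.utils.spectral_norm(torch.nn.Linear(3, 4))

# The trainable tensor now lives in "weight_orig" ...
param_names = sorted(name for name, _ in layer.named_parameters())
# ... and the power-iteration vectors u and v are stored as buffers.
buffer_names = sorted(name for name, _ in layer.named_buffers())
print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_u', 'weight_v']

# "weight" is no longer a Parameter; the pre-hook recomputes it on every call.
out = layer(torch.randn(2, 3))
print(isinstance(layer.weight, torch.nn.Parameter), tuple(layer.weight.shape))
```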

Now the complexity involved here has not gone unnoticed, and so there is a 2-year-old issue discussing what an abstraction could look like. Yours truly once explored solutions with some code, and more recently Mario Lezcano worked on an implementation, available in a draft PR clocking in at 350 lines. Mario does great work there and it looks like a great improvement: it works with Modules as the things providing parametrizations, so it can be customized, and it removes the fiddling with hooks. It does retain some of the boring properties of the hook-based spectral norm implementation, notably the idea that the inputs to the computation should be the "original weight", and a buffering scheme.

Couldn't we skip all that?

A better way?

Why can't we just assign our module above as the weight, then? Because we want to take existing modules such as torch.nn.Conv2d and torch.nn.Linear and replace their parameters. Their forward would need to do weight = self.weight() instead of weight = self.weight, as Modules obviously aren't Tensors.

But maybe if our Modules could be Tensors, too?

And here is where our clever hack comes in. PyTorch wants to enable people to extend it, and one of the things people wanted was a way of building on Tensors without quite subclassing them (subclassing works as you expect and is done e.g. by torch.nn.Parameter, but subclasses are more restricted in a sense), while still having regular PyTorch operations return the new class (whereas usually they return Tensors). Incidentally, one of the use cases is to have constrained tensors, e.g. skew-symmetric ones.

(I should point out here that things like FloatTensor aren't a subclass of Tensor; they're a gross hack to make isinstance work. Don't use them! In fact, you should have stopped when PyTorch 0.4 came out.)

The interesting thing is that these Tensor-likes aren't subclasses of Tensor. Rather, in true Python fashion, they have a special __torch_function__ method that magically converts the inputs into Tensors, calls the PyTorch function, and then post-processes the (Tensor) results into whatever it wants. But if they need not be Tensor subclasses, our class can just as well be a subclass of torch.nn.Module and define the special method. Bingo!
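
Here is the mechanism in isolation, as a minimal sketch relying only on the documented __torch_function__ protocol: a plain Python class (LazyTensor is a hypothetical name) that is not a Tensor subclass, yet torch functions accept it because it defines __torch_function__ as a classmethod.

```python
import torch

class LazyTensor:
    """Hypothetical minimal Tensor-like: not a Tensor subclass, computes its value on demand."""
    def __init__(self, fn):
        self.fn = fn  # zero-argument callable returning a Tensor

    def tensor(self):
        return self.fn()

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Replace every LazyTensor argument by its computed Tensor, then call the real function.
        args = tuple(a.tensor() if isinstance(a, LazyTensor) else a for a in args)
        kwargs = {k: (v.tensor() if isinstance(v, LazyTensor) else v) for k, v in kwargs.items()}
        return func(*args, **kwargs)

base = torch.randn(4, 3)
lazy = LazyTensor(lambda: base * 2)

print(torch.add(lazy, 1.0))                  # dispatched through __torch_function__
x = torch.randn(5, 3)
out = torch.nn.functional.linear(x, lazy)    # works as a weight, too
print(out.shape)                             # torch.Size([5, 4])
```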

We also implement caching. The hard part about caching isn't keeping the result around; it is figuring out when we need to recompute. We recompute when

- the parameters used in our module are updated (e.g. by our optimizer), because the new result will be different; handily, PyTorch keeps a version counter on every tensor that in-place modifications increment,
- the cached tensor has been back-propagated through, because the next back-propagation through it would fail (ideally, we'd only do this if retain_graph hadn't been used, and there might be corner cases when the module does funny things that cause the hook not to work, but hey...).

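class ComputedParameter(torch.nn.Module):
    def __init__(self, m):
        super().__init__()  # required before assigning the submodule
        self.m = m
        self.needs_update = True
        self.cache = None
        self.param_versions = None  # should we also treat buffers?

    def require_update(self, *args):  # dummy args for use as hook
        self.needs_update = True

    def check_param_versions(self):
        # Recompute if nothing is cached yet or a parameter was modified in place since.
        if self.param_versions is None:
            self.needs_update = True
            return
        for p, v in zip(self.parameters(), self.param_versions):
            if p._version != v:
                self.needs_update = True

    def tensor(self):
        self.check_param_versions()
        if self.needs_update:
            self.cache = self.m()
            if self.cache.requires_grad:
                # Invalidate the cache once backward has gone through it.
                self.cache.register_hook(self.require_update)
            self.param_versions = [p._version for p in self.parameters()]
            self.needs_update = False
        return self.cache

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Replace ComputedParameter arguments by their (cached) Tensor values.
        args = tuple(a.tensor() if isinstance(a, ComputedParameter) else a for a in args)
        kwargs = {k: (v.tensor() if isinstance(v, ComputedParameter) else v) for k, v in kwargs.items()}
        return func(*args, **kwargs)

    def __hash__(self):
        return super().__hash__()

    def __eq__(self, other):
        if isinstance(other, torch.nn.Module):
            return super().__eq__(other)  # identity comparison, as for any Module
        return torch.eq(self, other)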
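
The two invalidation signals that check_param_versions and the hook rely on can be observed in isolation. A small sketch (note that _version is an internal, underscore-prefixed attribute, so this relies on an implementation detail rather than documented API):

```python
import torch

p = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.SGD([p], lr=0.1)

# Stand-in for the cached result of self.m(): any tensor computed from p.
cached = p * 2
backprop_seen = []
# Like require_update: called with the gradient when backward reaches `cached`.
cached.register_hook(lambda grad: backprop_seen.append(True))

version_before = p._version
cached.sum().backward()   # fires the hook
opt.step()                # updates p in place, which bumps p._version
version_after = p._version
print(backprop_seen, version_before, version_after)
```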

Unsurprisingly, as we don't subclass Tensor, we don't have all the Tensor methods. Typing them up would be a lot of work, but happily Python lets us patch them in programmatically. (And here I cut corners by ignoring properties, documentation, type annotations etc., but this is just to show off a gross hack, remember?)

# this is overly simple: it should also take care of signatures, docstrings, class methods and properties
for name, member in inspect.getmembers(torch.Tensor):
    if not hasattr(ComputedParameter, name):
        if inspect.ismethoddescriptor(member):
            def get_proxy(name):
                def new_fn(self, *args, **kwargs):
                    return getattr(self.tensor(), name)(*args, **kwargs)
                return new_fn
            setattr(ComputedParameter, name, get_proxy(name))
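
To see the patching loop work in isolation, here is a sketch with a hypothetical Wrapper class standing in for ComputedParameter. It also illustrates why the loop goes through the get_proxy factory: a closure defined directly in the loop body would look up name only when called, and so every proxy would see the loop's final value.

```python
import inspect
import torch

class Wrapper:
    """Hypothetical stand-in for ComputedParameter: holds a Tensor and exposes tensor()."""
    def __init__(self, t):
        self._t = t

    def tensor(self):
        return self._t

def get_proxy(name):
    # Each call creates a new scope, so every proxy keeps its own `name`.
    def new_fn(self, *args, **kwargs):
        return getattr(self.tensor(), name)(*args, **kwargs)
    return new_fn

for name, member in inspect.getmembers(torch.Tensor):
    if not hasattr(Wrapper, name) and inspect.ismethoddescriptor(member):
        setattr(Wrapper, name, get_proxy(name))

w = Wrapper(torch.ones(2, 3))
print(w.sum(), w.t().shape)  # tensor(6.) torch.Size([3, 2])

# Without a factory, all closures share the comprehension's final `name`:
late = [lambda: name for name in ["sum", "mean", "max"]]
print([f() for f in late])  # ['max', 'max', 'max']
```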

We also want to be able to replace parameters with our neat new ComputedParameters, so we monkey-patch Module's __setattr__ routine.

def replace_setattr():
    # make old_setattr local..
    old_setattr = torch.nn.Module.__setattr__
    def new_setattr(self, n, v):
        oldval = getattr(self, n, 1)
        # Module.__setattr__ refuses a non-Parameter value for a registered
        # parameter name, so drop the old parameter (or None placeholder) first.
        if isinstance(v, ComputedParameter) and (oldval is None or isinstance(oldval, torch.nn.Parameter)):
            delattr(self, n)
        old_setattr(self, n, v)
    torch.nn.Module.__setattr__ = new_setattr

replace_setattr()
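
Why is the delattr needed? In stock PyTorch (without the patch above), Module.__setattr__ rejects any non-Parameter value assigned to a name that is registered as a parameter. A small sketch, using torch.nn.Identity merely as an arbitrary stand-in module:

```python
import torch

l = torch.nn.Linear(3, 4)
raised = False
try:
    # "weight" is registered in l._parameters, so a Module is rejected here.
    l.weight = torch.nn.Identity()
except TypeError as e:
    raised = True
    print("TypeError:", e)

del l.weight                    # unregisters the parameter
l.weight = torch.nn.Identity()  # now stored as a submodule instead
print(sorted(n for n, _ in l.named_parameters()))  # ['bias']
print(sorted(n for n, _ in l.named_children()))    # ['weight']
```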



Now we can use computed parameters in the most elegant way:

l = torch.nn.Linear(3, 4)
l.weight = ComputedParameter(SpectralNormWeight(l.weight.shape))

Our computed parameters show up in the module structure:

print(l)
    Linear(
      in_features=3, out_features=4, bias=True
      (weight): ComputedParameter(
        (m): SpectralNormWeight()
      )
    )

Let's try this on an admittedly silly and trivial example, to fit a spectrally normalized target:

target = torch.randn(3, 4)
target /= torch.svd(target).S[0]
opt = torch.optim.SGD(l.parameters(), 1e-1)
for i in range(1000):
    inp = torch.randn(20, 3)
    t = inp @ target
    p = l(inp)
    loss = torch.nn.functional.mse_loss(p, t)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if i % 100 == 0:
        print(i, loss.item())

It works:

l.weight - target.t()
    tensor([[ 0.0062, -0.0094,  0.0040],
            [-0.0025,  0.0056, -0.0054],
            [ 0.0026,  0.0012, -0.0014],
            [ 0.0192, -0.0099,  0.0039]], grad_fn=<SubBackward0>)
l.bias
    Parameter containing:
    tensor([ 6.9609e-05, -1.2957e-04, -1.0454e-04, -2.2778e-04],
           requires_grad=True)


I hope you enjoyed our little hack for computed parameters.

We looked at constrained parameters, with spectral normalization as an example. We learned how to make our own Tensor-ish class using __torch_function__ and used that (and a bit of exploiting the opportunities to rewire almost anything that Python offers) to make using parameters computed by Modules easy.

Remember that this is a fun hack and for educational purposes only. And of course, by the time we fill in the gaps, we'll be at 300 lines of code, too, and will have meddled a lot with Python internals to make our class a good proxy for Tensor.

If you want to take your PyTorch skills to the next level - check out my workshop offerings.

I appreciate your feedback at tv@lernapparat.de.