Visualizing PyTorch model structure

June 16, 2020

Did you ever wish for a concise picture of your PyTorch model's structure and find it too hard to get?

Recently, I did some work that involved looking at model structure in some detail. (More on this very soon!) For my write-up, I wanted to get a diagram of some model structures. Even though it is a relatively common model, searching for a diagram didn't turn up anything in the shape I was looking for.

So how can we get the model structure for PyTorch models? The first stop is probably the neat string representation that PyTorch provides for nn.Modules. Without us doing anything, it also covers our custom models pretty well. It is, however, not without shortcomings.

Let's look at TorchVision's ResNet18 basic block as an example.

import torchvision

m = torchvision.models.resnet18()
m.layer1[0]

BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

So we have two convs and two batch norms. But how are things connected? Is there one ReLU?

Looking at the forward method (you can get it using Python's inspect module, with print(inspect.getsource(m.layer1[0].forward)), or with ?? in IPython), we see some important details not in the summary:

def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out

So we missed the entire residual bit. Also, there are two ReLU applications, both through the same module. (Arguably, it is wrong to re-use stateless modules like this. It will haunt you when you do things like quantization, because the module then becomes stateful due to the quantization parameters, and it mixes things too much. If you want stateless, use the functional interface.)
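To make the mismatch between summary and forward concrete, here is a minimal sketch on a toy block. TinyResidualBlock is a hypothetical stand-in, not TorchVision's BasicBlock; it only copies the pattern of re-using one ReLU module. A forward hook counts how often that module actually runs.

```python
import torch
from torch import nn


class TinyResidualBlock(nn.Module):
    """Hypothetical toy block that re-uses one ReLU module, as BasicBlock does."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.relu(self.conv1(x))  # first use of the ReLU module
        out = self.conv2(out)
        out = out + identity            # residual connection, invisible in repr
        return self.relu(out)           # second use of the same module


block = TinyResidualBlock()

# Count how often the ReLU module runs during one forward pass.
relu_calls = []
handle = block.relu.register_forward_hook(
    lambda module, inputs, output: relu_calls.append(1)
)
block(torch.randn(1, 4, 8, 8))
handle.remove()

summary = repr(block)
print(summary)
print("ReLU entries in the summary:", summary.count("ReLU("))
print("ReLU calls in forward:", len(relu_calls))
```

The summary lists the module once, while the hook fires twice, and the addition does not show up in the summary at all.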

So I looked around a bit and, not finding anything that did what I wanted, wrote a small function, make_graph, that produces the graph below. It was made with make_graph(getattr(traced_model.layer1, "0")). I should say that I assume that all interesting things happen in submodules and only show operation nodes when they are "junctions", i.e. have several inputs with distinct predecessors in the graph.

ResNet18 basic block
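The "junction" rule can be illustrated in plain Python. This is only one reading of the rule, not the actual make_graph code: the Node class, the visible_graph function and the toy node list are all made up for this sketch. Module calls are always shown, and an operation is shown only if its inputs trace back to more than one distinct visible node.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    kind: str                      # "input", "module" or "op"
    inputs: list = field(default_factory=list)


def visible_graph(nodes):
    """Return (visible node names, edges) for a topologically ordered node list."""
    source = {}                    # node name -> visible nodes that feed it
    visible, edges = [], []
    for node in nodes:
        preds = set()
        for inp in node.inputs:
            preds |= source[inp]
        is_junction = node.kind == "op" and len(preds) > 1
        if node.kind != "op" or is_junction:
            visible.append(node.name)
            edges.extend((p, node.name) for p in sorted(preds))
            source[node.name] = {node.name}
        else:
            source[node.name] = preds   # hidden op: edges pass straight through
    return visible, edges


# Toy graph shaped like a residual block, followed by a flatten and a linear layer.
nodes = [
    Node("x", "input"),
    Node("conv1", "module", ["x"]),
    Node("bn1", "module", ["conv1"]),
    Node("relu_1", "module", ["bn1"]),
    Node("conv2", "module", ["relu_1"]),
    Node("bn2", "module", ["conv2"]),
    Node("add", "op", ["bn2", "x"]),      # two distinct predecessors: a junction
    Node("relu_2", "module", ["add"]),
    Node("flatten", "op", ["relu_2"]),    # single predecessor: hidden
    Node("fc", "module", ["flatten"]),
]

visible, edges = visible_graph(nodes)
print(visible)
print(edges)
```

Here the addition stays in the picture because it joins the residual path and the convolution path, while flatten disappears and its edge goes directly from relu_2 to fc.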

A high-level representation looks like this (produced with make_graph(traced_model, classes_to_visit={'Sequential'})):

ResNet18 high level view

But we can also have the full thing, with make_graph(traced_model):

ResNet18 in full detail
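All the make_graph calls here take a traced model. The post does not show how traced_model was obtained; presumably it comes from torch.jit.trace. The following sketch traces a small toy model (an assumption, chosen to avoid downloading anything) and lists the operations the tracer recorded, which is the kind of information a graph drawing can be built from.

```python
import torch
from torch import nn

# Toy stand-in for a real model such as torchvision.models.resnet18().
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
example_input = torch.randn(1, 4)

# Tracing runs the model once and records the operations that were executed.
traced_model = torch.jit.trace(toy_model, example_input)

# The top-level graph contains calls into the submodules ...
print(traced_model.graph)

# ... while the inlined graph exposes the individual operations.
op_kinds = [node.kind() for node in traced_model.inlined_graph.nodes()]
print(op_kinds)
```

The exact operation names in the inlined graph depend on the PyTorch version, so a tool built on top of this should not rely on a fixed list.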

More advanced vision models

We can also do this with TorchVision's FCN ResNet50 for semantic segmentation. The graph can be made with make_graph(traced_model, classes_to_visit={'IntermediateLayerGetter', 'FCNHead'}), but tracing needs strict=False.

FCN ResNet50

Detection models are also fun, here FasterRCNN. This needed a trick to make tracing work: I wrapped it in a model that takes apart the output dictionary, and tracing also needed check_trace=False. The graphs were then made with make_graph(traced_model, classes_to_visit={}) and make_graph(traced_model, classes_to_visit={'RegionProposalNetwork', 'RoIHeads'}).

The high-level view is:

FasterRCNN high-level

And here is a more detailed view:

FasterRCNN with more detail
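The two tracing tricks mentioned for these models (strict=False for a model that returns a dictionary, and a wrapper that takes the dictionary apart) can be sketched on a toy model. DictOutputModel and TupleOutputWrapper are hypothetical names invented for this example; the real FCN and FasterRCNN models are far more involved, and the wrapper used for the graphs above may look different.

```python
import torch
from torch import nn


class DictOutputModel(nn.Module):
    """Hypothetical stand-in for a model whose forward returns a dict."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        features = self.backbone(x)
        return {"features": features, "out": self.head(features)}


class TupleOutputWrapper(nn.Module):
    """Takes the output dict apart so the tracer only sees a tuple of tensors."""

    def __init__(self, model, output_keys):
        super().__init__()
        self.model = model
        self.output_keys = list(output_keys)

    def forward(self, x):
        result = self.model(x)
        return tuple(result[key] for key in self.output_keys)


model = DictOutputModel().eval()
example = torch.randn(3, 4)

# Option 1: let the tracer record the dict output by turning off strict mode.
traced_dict = torch.jit.trace(model, example, strict=False)

# Option 2: wrap the model; check_trace=False skips the tracer's sanity re-run.
wrapped = TupleOutputWrapper(model, ["features", "out"])
traced_tuple = torch.jit.trace(wrapped, example, check_trace=False)

print(traced_tuple(example)[1].shape)
```

Skipping the trace check is harmless for a toy like this, but for a real model it also silences the warning that would tell you the trace differs between runs.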


Finally, we aren't restricted to vision models. Taking BERT from HuggingFace's great transformers library, we can make an overview with make_graph(traced_model, classes_to_visit={'BertEncoder'}):

BERT model

and zoom into a BertLayer with make_graph(getattr(traced_model.encoder.layer, "0"), classes_to_visit={'BertAttention', 'BertSelfAttention'}). (I should send a PR to PyTorch to enable indexing there...)



I'm having way too much fun with this. Some points where one could reconsider or extend the design:

  • I don't handle loops and other control flow yet. This works as long as I restrict myself to traced modules (and it'll break if I have scripted modules called from the regular ones).
  • Instead of just looking at classes, what gets detailed might also be given by the submodule name.
  • The JIT doesn't seem to add nice names to the inputs of submodules during tracing. We'd love to have these, but it can be nontrivial to get the name (e.g. if you pass a tuple of tensors to a submodule, the tracer will only see the elements, not the tuple).
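The second point, selecting what to expand by submodule name as well as by class, could look roughly like this. It is a sketch of a possible extension, not part of the published code: should_expand and its names_to_visit parameter are hypothetical, and the (name, class) pairs are written out by hand in the form that named_modules() reports them for ResNet18.

```python
from fnmatch import fnmatchcase


def should_expand(qualified_name, class_name, classes_to_visit=(), names_to_visit=()):
    """Expand a submodule if its class is listed or its name matches a pattern."""
    if class_name in classes_to_visit:
        return True
    return any(fnmatchcase(qualified_name, pattern) for pattern in names_to_visit)


# A few (qualified name, class name) pairs, written by hand for this example.
submodules = [
    ("layer1", "Sequential"),
    ("layer1.0", "BasicBlock"),
    ("layer1.0.conv1", "Conv2d"),
    ("layer2", "Sequential"),
    ("layer2.0", "BasicBlock"),
    ("layer2.0.downsample", "Sequential"),
    ("fc", "Linear"),
]

expanded = [
    name
    for name, class_name in submodules
    if should_expand(
        name,
        class_name,
        classes_to_visit={"Sequential"},
        names_to_visit={"layer1.0"},
    )
]
print(expanded)
```

With shell-style patterns, a name such as "layer*.0" would select the first block of every stage without expanding all BasicBlocks.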

The code for the visualizations

The code was on the wish list of my single GitHub sponsor when I asked a while ago what code he would like to see published, so with all its limitations it is now available.

PyTorch training and consulting

Do you want some help getting models to do awesome things, or to generally give your PyTorch and Deep Learning skills a boost? I offer consulting as well as in-house and public workshops at beginner, intermediate and expert PyTorch levels. If you are near Munich (say, in Europe) and need PyTorch training, I would love to hear from you! I also do bespoke development.

I hope this blog post is useful to you. I appreciate and read every mail you send to tv@lernapparat.de.