Lernapparat

Learning Machine Learning

Debugging CUDA device-side assert in PyTorch

June 15, 2018

The beautiful thing about PyTorch's immediate execution model is that you can actually debug your programs. Sometimes, however, the asynchronous nature of CUDA execution makes that hard. Here is a little trick to debug such programs.

When you run a PyTorch program using CUDA operations, the program usually doesn't wait until the computation finishes but keeps throwing instructions at the GPU until it actually needs a result (e.g. when you call .item() or .cpu(), or print a tensor).
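To make this queueing concrete, here is a toy model in plain Python. `ToyStream`, `faulty_loss` and `run` are hypothetical names invented for this sketch; they only mimic the behaviour described above and are not how PyTorch or CUDA work internally. Operations are queued by `launch()` and executed only when `synchronize()` is called, which plays the role of `.item()` or printing. A `blocking` switch runs each operation as soon as it is launched instead.

```python
from collections import deque


class ToyStream:
    """Toy stand-in for a CUDA stream (hypothetical, not a PyTorch API)."""

    def __init__(self, blocking=False):
        self.blocking = blocking  # True: run every op right when it is launched
        self.pending = deque()

    def launch(self, op):
        self.pending.append(op)
        if self.blocking:
            self.synchronize()  # run right away, so errors surface here

    def synchronize(self):
        while self.pending:
            op = self.pending.popleft()
            try:
                op()
            except ValueError as exc:
                self.pending.clear()
                # Raised from wherever synchronize() was called, not from launch().
                raise RuntimeError("device-side assert triggered") from exc


def faulty_loss(labels=(0, 1, 2, 3), num_classes=3):
    # Stands in for the loss kernel's bounds check on the class indices.
    if any(not 0 <= t < num_classes for t in labels):
        raise ValueError("target out of bounds")


def run(blocking):
    """Return the steps that completed before the error was raised."""
    stream, log = ToyStream(blocking), []
    try:
        stream.launch(faulty_loss)
        log.append("launched loss")
        stream.launch(lambda: None)  # stands in for average = loss / 4
        log.append("launched division")
        stream.synchronize()  # stands in for print(average.item())
        log.append("printed result")
    except RuntimeError:
        log.append("error")
    return log


print(run(blocking=False))  # ['launched loss', 'launched division', 'error']
print(run(blocking=True))   # ['error']
```

In the queued mode the program happily "launches" the division after the broken loss, and the error only shows up at the synchronization point; in the blocking mode it shows up at the faulty launch itself.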

While this behaviour is key to the blazing performance of PyTorch programs, there is a downside: when a CUDA operation fails, your program has long since moved on to other things. The usual symptom is a very nondescript error at a more or less random place somewhere after the instruction that triggered it. It typically looks like this:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-3d8a992c81ab> in <module>()
      1 loss = torch.nn.functional.cross_entropy(activations, labels)
      2 average = loss/4
----> 3 print(average.item())

RuntimeError: cuda runtime error (59) : device-side assert triggered at /home/tv/pytorch/pytorch/aten/src/THC/generic/THCStorage.cpp:36

Well, that is hard to understand; I'm sure that printing my results is a legitimate course of action. All the device-side assert tells me is that something went wrong somewhere earlier.

Here is the faulty program causing this output:

import torch
device = torch.device('cuda:0')
activations = torch.randn(4,3, device=device) # usually we get our activations in a more refined way...
labels = torch.arange(4, device=device)
loss = torch.nn.functional.cross_entropy(activations, labels)
average = loss/4
print(average.item())

One debugging option is to move everything to the CPU, where errors are reported at the operation that causes them. But often we use libraries or have complex setups where that isn't an option. So what now? If only we could get a good traceback, we should find the problem in no time.

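Here is how to get a good traceback: launch the program with the environment variable CUDA_LAUNCH_BLOCKING set to 1. This disables asynchronous kernel launches, so each CUDA operation finishes before the next one starts and an error is reported at the operation that caused it. As you can see, though, I like to use Jupyter for a lot of my work, and there setting an environment variable at launch is not as easy as one would like. But this can be solved, too: at the very top of your program, before you import anything (and in particular PyTorch), insert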

import os
# Set this before importing torch (and before any CUDA work), e.g. in the first cell.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
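For plain scripts run outside Jupyter, the first option, launching the program with the variable set, can also be scripted from Python. This is a minimal sketch; `run_with_blocking_launches` is a hypothetical helper name, and it only copies the current environment, adds CUDA_LAUNCH_BLOCKING=1 and starts the script in a child process, so the flag is in place before the child ever imports PyTorch.

```python
import os
import subprocess
import sys


def run_with_blocking_launches(args, **kwargs):
    """Run `python <args>` in a child process with CUDA_LAUNCH_BLOCKING=1.

    Hypothetical helper: the parent's own environment is left unchanged;
    extra keyword arguments are passed through to subprocess.run.
    """
    env = dict(os.environ, CUDA_LAUNCH_BLOCKING="1")
    return subprocess.run([sys.executable, *args], env=env, **kwargs)


# Quick check that the child sees the flag (no GPU needed for this).
result = run_with_blocking_launches(
    ["-c", "import os; print(os.environ['CUDA_LAUNCH_BLOCKING'])"],
    capture_output=True,
    text=True,
)
print(result.stdout.strip())  # prints 1
```

For a real run you would pass the script path and its arguments instead, e.g. `run_with_blocking_launches(["train.py"])`, where `train.py` is a placeholder name.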

With CUDA_LAUNCH_BLOCKING set to 1, we get a much better traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-3d8a992c81ab> in <module>()
----> 1 loss = torch.nn.functional.cross_entropy(activations, labels)
      2 average = loss/4
      3 print(average.item())

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce)
   1472         >>> loss.backward()
   1473     """
-> 1474     return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
   1475 
   1476 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce)
   1362                          .format(input.size(0), target.size(0)))
   1363     if dim == 2:
-> 1364         return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)
   1365     elif dim == 4:
   1366         return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)

RuntimeError: cuda runtime error (59) : device-side assert triggered at /home/tv/pytorch/pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:116

So apparently, the loss does not like what we pass it. In fact, our activations have shape batch x 3, so we only allow for three categories (0, 1, 2), but the labels run to 3!
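Once the traceback points at the loss, a quick check of the targets confirms the diagnosis. The sketch below is plain Python; `find_invalid_targets` is a hypothetical helper. It assumes the targets are class indices, as in the program above, and uses `-100`, the default `ignore_index` of `torch.nn.functional.cross_entropy`, for entries that should be skipped.

```python
def find_invalid_targets(targets, num_classes, ignore_index=-100):
    """Return (position, value) pairs of class indices outside
    0..num_classes-1, skipping entries equal to ignore_index."""
    return [
        (pos, t)
        for pos, t in enumerate(targets)
        if t != ignore_index and not 0 <= t < num_classes
    ]


# The labels from the faulty program, torch.arange(4), as a plain list;
# with a real tensor you would pass labels.tolist() instead.
labels = [0, 1, 2, 3]
num_classes = 3  # activations.size(1) in the program above
print(find_invalid_targets(labels, num_classes))  # [(3, 3)]
```

Calling it on `labels.tolist()` copies the labels to the CPU and therefore forces a synchronization, so this is a debugging aid rather than something to leave in a hot training loop.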

The best part is that this also works for nontrivial examples. Now if only we could recover the non-GPU bits of our calculation instead of needing a complete restart...

