PyTorch is a great project and I have only met very helpful people when contributing to it. However, the code base can be quite intimidating. Here we look at fixing a simple bug in detail and see that it is a less daunting task than it might seem at first.
You have probably seen the awesome notes on PyTorch internals by my fellow PyTorch developer Edward Z. Yang, the great slides by Christian Perone, or even an earlier blog post of yours truly. If not, I highly recommend browsing through at least one of them.
The first step is to build your own PyTorch. The source code comes with a contributor's tutorial, which covers building and also gives a brief overview of the different components of the PyTorch repository.
Choosing a bug
We have checked out and built PyTorch. What now? Let us go fix a bug! But which one? PyTorch keeps a list of small issues which may be a good start.
For this example I chose a segmentation fault in `pack_padded_sequence`.[1] One super-nice thing about this issue is that the submitter provided a small script to reproduce it. That gives us a head start, as it makes the bug easy to trigger. Sometimes, getting such a reproducing script is the hardest and most time-consuming step in fixing a bug. It is a good idea to post a short comment on the issue saying that you are working on it, so that no work is duplicated. Every small bug likes a friend to attend to it, so it is good to spread the love to many of them.
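Before reaching for a debugger, it can help to confirm that a reproducing script really crashes the interpreter rather than just raising a Python exception. The following is a minimal sketch, not part of PyTorch: the helper name `crashes_with_segfault` is hypothetical, and reading the signal from the return code assumes a POSIX system such as Linux. It runs the script in a fresh interpreter and checks whether that process was killed by SIGSEGV.

```python
import signal
import subprocess
import sys


def crashes_with_segfault(script_source: str, timeout: float = 60.0) -> bool:
    """Run script_source in a fresh Python interpreter and report whether
    that interpreter was killed by SIGSEGV (hypothetical helper, POSIX only)."""
    result = subprocess.run(
        [sys.executable, "-c", script_source],
        capture_output=True,  # keep the repro's output out of our terminal
        timeout=timeout,
    )
    # On POSIX, subprocess reports "terminated by signal N" as returncode -N.
    return result.returncode == -signal.SIGSEGV


if __name__ == "__main__":
    # A script that finishes normally is not a crash ...
    print(crashes_with_segfault("print('hello')"))  # False
    # ... and neither is an ordinary Python exception ...
    print(crashes_with_segfault("raise ValueError('boom')"))  # False
    # ... but a process killed by SIGSEGV is.
    print(crashes_with_segfault(
        "import os, signal; os.kill(os.getpid(), signal.SIGSEGV)"))  # True
```

For a real repro you would pass the script's text (for example via `pathlib.Path(...).read_text()`) and start the helper with `PYTHONPATH` pointing at your build, since the child process inherits the environment.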
All theory is gray, my friend
To make things less abstract, I recorded a screencast of me fixing the bug. It is in real time, except that I cut out an unintended recompile.
When debugging a crash, it is always handy to fire up a debugger to get a stack trace.
I usually save the reproducing script somewhere (`~/pytorch/scripts/` for me). Then, from the PyTorch checkout containing my build, I run `PYTHONPATH=./build/lib.SOMETHING/ gdb -ex run --args python3 ~/pytorch/scripts/my_repro.py`. This has `gdb` load everything after `--args` as the program to debug and then run it. For a segfault, gdb automatically stops when the segfault happens, and you can type `bt` to get a backtrace. If you want to debug a C++ exception, `-ex 'catch throw'` before `-ex run` is handy, and there also is `-ex 'break somefile.cpp:<lineno>'` to set a breakpoint at a specific line. Because PyTorch's extension modules are loaded dynamically, gdb will tell you that it does not know the file yet and ask whether to make the breakpoint pending; answer yes.
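If you launch gdb this way often, a small wrapper saves typing and avoids quoting mistakes. The sketch below only assembles the command line described above; the helper names `gdb_argv` and `gdb_env` are hypothetical. It makes one point explicit: every gdb command, such as `catch throw` or `break somefile.cpp:123`, is a single argument following its own `-ex`, which is why the shell version needs quotes.

```python
import os
import shlex
import sys


def gdb_argv(script, catch_throw=False, breakpoints=(), python=None):
    """Build the argv for running `python script` under gdb (hypothetical helper)."""
    argv = ["gdb"]
    if catch_throw:
        argv += ["-ex", "catch throw"]  # stop whenever a C++ exception is thrown
    for location in breakpoints:        # e.g. "somefile.cpp:123"
        argv += ["-ex", f"break {location}"]
    argv += ["-ex", "run"]              # start the program right away
    # Everything after --args is the program gdb runs, plus its arguments.
    argv += ["--args", python or sys.executable, script]
    return argv


def gdb_env(build_lib):
    """Copy of the environment with PYTHONPATH pointing at the local build."""
    env = dict(os.environ)
    env["PYTHONPATH"] = build_lib
    return env


if __name__ == "__main__":
    argv = gdb_argv("my_repro.py", catch_throw=True, python="python3")
    # Print the equivalent shell command, quoted the way a shell needs it.
    print(shlex.join(argv))
    # gdb -ex 'catch throw' -ex run --args python3 my_repro.py
    # To launch it: subprocess.run(argv, env=gdb_env("./build/lib.SOMETHING/"))
```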
The backtrace tells us that the error happens in `_pack_padded_sequence`. The reason is that we have a `TORCH_CHECK(lengths[batch_size - 1] > 0,...)`, but here `batch_size` is `0`, so this reads `lengths[-1]`, out of bounds: a classic C++ bug...
Fixing the bug
We can fix this by checking whether the tensor is empty (as in, has 0 elements) before assuming that it has at least one element.
Most tensor methods have the same names in C++ as in Python, so we can add a `TORCH_CHECK(input.numel(), ...)` in front of the failing check. Indeed, the file I linked above is the version after the fix, and you can find the new check two lines above the old one. A one-line bugfix.
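To see why the order of the checks matters, here is a plain-Python model of the logic. It is not PyTorch code: `check_pack_inputs` is a hypothetical stand-in for the C++ function, the error messages are illustrative, and a Python list plays the role of the lengths buffer. In C++, indexing with `batch_size - 1 == -1` reads memory before the buffer; Python's negative indexing raises an `IndexError` on an empty list instead, which makes the unguarded path easy to observe.

```python
def check_pack_inputs(num_elements, lengths, guard_empty=True):
    """Plain-Python model of the validation order (hypothetical, not PyTorch).

    num_elements stands in for input.numel(); lengths holds one sequence
    length per batch entry, sorted in decreasing order, so the last entry
    is the shortest sequence.
    """
    if guard_empty and num_elements == 0:
        # The new check: reject empty input before indexing into lengths.
        raise RuntimeError("cannot pack an empty input")  # illustrative message
    batch_size = len(lengths)
    # The pre-existing check on the shortest sequence. With batch_size == 0
    # the index is -1: an out-of-bounds read in C++, an IndexError here.
    if lengths[batch_size - 1] <= 0:
        raise RuntimeError("all lengths must be > 0")  # illustrative message


if __name__ == "__main__":
    check_pack_inputs(9, [3, 2, 1])  # a valid 3x3 batch passes silently
    try:
        check_pack_inputs(0, [], guard_empty=False)
    except IndexError as exc:
        print("unguarded:", exc)  # unguarded: list index out of range
    try:
        check_pack_inputs(0, [])
    except RuntimeError as exc:
        print("guarded:", exc)  # guarded: cannot pack an empty input
```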
But we're not done: we should always test our code. To this end, we look for a good place to add our new test.
Looking into the `test/` directory, we find a number of files with tests. The most prominent ones are `test_torch.py` for functions in ATen, `test_cuda.py` for GPU functions in ATen, `test_autograd.py` for Autograd, `test_jit*.py` for JIT-related tests, and `test_nn.py` for NN-related tests, but there are others, too. We changed a function exposed to Python via `torch.nn.utils.rnn`, so `test_nn.py` is where we want to add a test.
There already is a test for `pack_padded_sequence`, unsurprisingly named `test_pack_padded_sequence`.[2] It is a long function. At the very end there is an interesting bit marked `# test error message` with a small `with self.assertRaisesRegex(...):` block. We copy this block, pass an empty tensor instead, and adapt the expected error message. (You can see my result on GitHub.) Of course, your test will look different. The most common test pattern is perhaps using `self.assertEqual` to compare with some expectation, a NumPy calculation, or similar. You can check that the test segfaults without our fix and passes with it.
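The two test patterns mentioned above look like this in plain `unittest`. Since PyTorch itself is not needed to show the patterns, the function under test here is a hypothetical pure-Python model of the `batch_sizes` a packed sequence stores (entry `t` counts the sequences still running at time step `t`); in `test_nn.py` you would call `pack_padded_sequence` itself, and the error message would be the one from your `TORCH_CHECK`.

```python
import unittest


def batch_sizes_from_lengths(lengths):
    """Hypothetical model of a packed sequence's batch_sizes.

    `lengths` must be sorted in decreasing order; entry t of the result is
    the number of sequences that are longer than t.
    """
    if len(lengths) == 0:
        raise RuntimeError("cannot pack an empty batch")  # illustrative message
    if lengths[-1] <= 0:
        raise RuntimeError("all lengths must be greater than 0")
    return [sum(1 for n in lengths if n > t) for t in range(lengths[0])]


class TestBatchSizes(unittest.TestCase):
    def test_values(self):
        # Pattern 1: compare against a hand-computed expectation.
        self.assertEqual(batch_sizes_from_lengths([4, 2, 2]), [3, 3, 1, 1])

    def test_error_message(self):
        # Pattern 2: bad input must raise, with a message matching the regex.
        with self.assertRaisesRegex(RuntimeError, "empty"):
            batch_sizes_from_lengths([])
```

To run it as a script, add `unittest.main()` under an `if __name__ == "__main__":` guard.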
Before you submit
There are a few things you should do before you commit your patch:
- Run the test suite to check that the tests pass (`PYTHONPATH=./build/libTHERIGHTTHING python3 test/run_test.py`; don't forget the correct `PYTHONPATH`, or you'll test some system-installed PyTorch).
- Run `python3 -m flake8 test/test_nn.py` (or on whatever Python modules you have touched) to make sure you haven't messed up the formatting.
- For C++ the formatting rules are less strict, but clang-format is considered good style (and some files are fully formatted; you want to keep them that way). You can use `git-clang-format` for your changes, but do not include formatting-only changes.
- Read the fine print in the contributor's guide.
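The `PYTHONPATH` pitfall in the first item is easy to check mechanically. Here is a sketch with hypothetical helper names: `module_origin` asks a fresh interpreter where it would import a module from under a given `PYTHONPATH` (with a correct setup, `torch` should resolve into your build directory), and `presubmit_commands` lists the two Python-side checks as argument lists, assuming you run them from the root of the checkout.

```python
import os
import subprocess
import sys


def env_with_pythonpath(build_lib):
    """Copy of the current environment with PYTHONPATH set to build_lib."""
    env = dict(os.environ)
    env["PYTHONPATH"] = build_lib
    return env


def module_origin(module, build_lib):
    """File a fresh interpreter would import `module` from with
    PYTHONPATH=build_lib, or None if the module cannot be found."""
    code = (
        "import importlib.util\n"
        f"spec = importlib.util.find_spec({module!r})\n"
        "print(spec.origin if spec and spec.origin else '')\n"
    )
    out = subprocess.run(
        [sys.executable, "-c", code],
        env=env_with_pythonpath(build_lib),
        capture_output=True,
        text=True,
        check=True,
    ).stdout.strip()
    return out or None


def presubmit_commands(touched_python_files):
    """Argument lists for the test suite and flake8 (run from the checkout)."""
    commands = [[sys.executable, "test/run_test.py"]]
    if touched_python_files:
        commands.append([sys.executable, "-m", "flake8", *touched_python_files])
    return commands


# Example (placeholder path from the text; adjust to your build directory):
#   build_lib = "./build/libTHERIGHTTHING"
#   print(module_origin("torch", build_lib))  # should point into ./build/...
#   for argv in presubmit_commands(["test/test_nn.py"]):
#       subprocess.run(argv, env=env_with_pythonpath(build_lib), check=True)
```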
After that, you are good! Push a git branch and open a PR (referencing the right issue number). Thank you for helping out!
After you submit the PR, continuous integration (CI) will kick in to test it. Do watch what is happening there and look at the failure logs. Sometimes tests will be flaky, i.e. have spurious failures, but the tests are there to save us from introducing embarrassing bugs in our PRs. To me the most interesting times to check are after 5 minutes, half an hour, and maybe three hours or so.
So this bug was fixed in less than an hour of wall-clock time (the rebuild after `make clean` took quite a while). Do not despair if it takes longer. The first PR is always the hardest, and I picked a particularly easy one to record the video with; I've spent weeks on some of my more intricate PRs. It should be fun!
If you are close to Munich and want to learn more about how to use PyTorch, do check out my workshop offering.
I hope you enjoyed this little trip into working on PyTorch. I would love to hear from you if you have feedback on this blog post and screencast (it's my first), and also if you have topics you'd like to see covered here. Mail me at email@example.com.