An efficient implementation of the Sinkhorn algorithm for the GPU
So I wrote quite a bit about PyTorch itself; today, we are doing some cool things with PyTorch again. Unsurprisingly to regular readers, I use the Wasserstein distance as an example.
The Wasserstein distance has seen new applications in machine learning and deep learning. It commonly replaces the Kullback-Leibler divergence (also often dubbed cross-entropy loss in the deep learning context). In contrast to the latter, Wasserstein distances not only consider the values of the probability distribution or density at any given point, but also incorporate spatial information about these differences in terms of the underlying metric. Intuitively, the distance is smaller if probability mass moved to a nearby point or region and larger if it moved far away.
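To make the intuition concrete, here is a small sketch (not part of the original post) for the simplest special case: two histograms with equal total mass on equally spaced points of a line. There, the Wasserstein-1 distance equals the L1 distance between the cumulative distribution functions. The function name `wasserstein1_1d` is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Wasserstein-1 distance between two histograms a and b (equal total mass)
// on the points 0, spacing, 2 * spacing, ... of the real line.
// In 1D, W1 is the L1 distance between the cumulative distribution functions.
double wasserstein1_1d(const std::vector<double>& a,
                       const std::vector<double>& b,
                       double spacing = 1.0) {
  assert(a.size() == b.size());
  double cdf_a = 0.0, cdf_b = 0.0, dist = 0.0;
  for (std::size_t k = 0; k < a.size(); ++k) {
    cdf_a += a[k];
    cdf_b += b[k];
    // the CDF difference is constant between grid point k and k + 1
    dist += std::abs(cdf_a - cdf_b) * spacing;
  }
  return dist;
}
```

Moving a unit of mass by one bin gives a distance of 1 and moving it by three bins gives 3, whereas the KL divergence between these histograms with disjoint support is infinite in both cases.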
There are two predominant variants of Wasserstein distance approximations used in machine learning:

Stochastically optimised online estimates of the Wasserstein distance. This is the concept underpinning many of the GAN applications using (a heuristic approximation of) the Wasserstein distance as a discriminator. This line of work started with the Wasserstein GAN as an improvement over the KL-based DCGAN and continued with better ways to estimate the Wasserstein distance in WGAN-GP and SNGAN.

Direct computation of the Wasserstein distance as a replacement for the cross-entropy loss in minibatch training. This is commonly done using the entropy-regularised Wasserstein distance and the Sinkhorn iterations (Cuturi). In the context of deep learning, this has been proposed by Frogner et al., but there is also earlier work in image retrieval using the (non-regularised) Wasserstein distance, see e.g. Y. Rubner et al. A comprehensive treatment is given in Peyré and Cuturi's book; R. Flamary's Python Optimal Transport library provides implementations of many algorithms in this area.
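For readers who have not seen them, the Sinkhorn iterations can be sketched in a few lines of sequential C++. This is a CPU illustration with our own hypothetical names, not code from the post or from POT. It uses the same convention as the Sinkhorn step later in the post, i.e. the Gibbs kernel is $K_{ij} = \exp(-c_{ij}/\lambda)$, so a larger $\lambda$ means more smoothing.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Plain (probability-domain) Sinkhorn iterations for the entropy-regularised
// transport problem between histograms mu (size n) and nu (size m) with cost
// matrix c (n x m). Returns the plan P = diag(u) K diag(v), K_ij = exp(-c_ij / lambda).
Matrix sinkhorn_plan(const std::vector<double>& mu, const std::vector<double>& nu,
                     const Matrix& c, double lambda, int iterations) {
  const std::size_t n = mu.size(), m = nu.size();
  Matrix k(n, std::vector<double>(m));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < m; ++j) k[i][j] = std::exp(-c[i][j] / lambda);

  std::vector<double> u(n, 1.0), v(m, 1.0);
  for (int it = 0; it < iterations; ++it) {
    for (std::size_t i = 0; i < n; ++i) {  // u = mu / (K v): fixes the row sums
      double s = 0.0;
      for (std::size_t j = 0; j < m; ++j) s += k[i][j] * v[j];
      u[i] = mu[i] / s;
    }
    for (std::size_t j = 0; j < m; ++j) {  // v = nu / (K^T u): fixes the column sums
      double s = 0.0;
      for (std::size_t i = 0; i < n; ++i) s += k[i][j] * u[i];
      v[j] = nu[j] / s;
    }
  }
  Matrix plan(n, std::vector<double>(m));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < m; ++j) plan[i][j] = u[i] * k[i][j] * v[j];
  return plan;
}
```

For small $\lambda$, $\exp(-c_{ij}/\lambda)$ underflows to zero and the divisions break down. This is the numerical stability problem that motivates working with $\log u$ and $\log v$ instead.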
Two years ago, I wrote a function for the latter, but then blogged more about the former. Today, we revisit the latter use of the Wasserstein distance. One of the challenges is the numerical stability of the Sinkhorn iteration and carrying that over to minibatch computations efficiently. While the ingredients appear to be readily available, they do not seem to have been put together in the recent implementations we looked at, so I went ahead and put all the maths in one place (if only to be sure all sign and $\lambda$ or $1/\lambda$ conventions are the same) and then implemented a kernel loosely based on the BatchNorm improvements I contributed to PyTorch.
My write-up, "Implementation of batched Sinkhorn iterations for entropy-regularized Wasserstein loss", is on arXiv (1907.01729), and I put the kernel itself on GitHub.
I particularly like that the implementation is fast enough to support interpolation in the barycenter variant.
The GPU kernel
At the core of the algorithm is the Sinkhorn step
$$ \log v_{bj} := \log \nu_{bj} - \operatorname{logsumexp}_{i} \left( -\frac{1}{\lambda} c_{ij} + \log u_{bi} \right) $$
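As a point of reference, the step can be written as a plain sequential loop over $b$ and $j$. The following CPU sketch (hypothetical name `sinkstep_reference`, `double` only) uses the classic two-pass stabilisation, computing the maximum first and then the sum. It follows the kernel's convention of returning $-\infty$ when $\log \nu_{bj} = -\infty$ or when all terms are $-\infty$.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// CPU reference for the batched log-domain Sinkhorn step
//   log_v[b][j] = log_nu[b][j] - logsumexp_i(-dist[i][j] / lambda + log_u[b][i])
// Shapes: dist is I x J, log_u is B x I, log_nu is B x J; the result is B x J.
Matrix sinkstep_reference(const Matrix& dist, const Matrix& log_nu,
                          const Matrix& log_u, double lambda) {
  const double neg_inf = -std::numeric_limits<double>::infinity();
  const std::size_t n_b = log_u.size(), n_i = dist.size(), n_j = dist[0].size();
  Matrix log_v(n_b, std::vector<double>(n_j));
  for (std::size_t b = 0; b < n_b; ++b) {
    for (std::size_t j = 0; j < n_j; ++j) {
      double max = neg_inf;  // pass 1: the stabilising maximum
      for (std::size_t i = 0; i < n_i; ++i)
        max = std::fmax(max, -dist[i][j] / lambda + log_u[b][i]);
      if (max == neg_inf || log_nu[b][j] == neg_inf) {
        log_v[b][j] = neg_inf;  // same convention as the kernel
        continue;
      }
      double sumexp = 0.0;  // pass 2: every exp argument is <= 0
      for (std::size_t i = 0; i < n_i; ++i)
        sumexp += std::exp(-dist[i][j] / lambda + log_u[b][i] - max);
      log_v[b][j] = log_nu[b][j] - std::log(sumexp) - max;
    }
  }
  return log_v;
}
```

With $c = 1000$ and $\lambda = 1$, a direct evaluation of $\exp(-c/\lambda)$ underflows to zero in double precision, while the log-domain step returns $1000$ exactly.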
This has two key properties that shape our implementation:

The overall reduction structure is akin to a matrix multiplication, i.e. we need memory accesses to $c_{ij}$ and $\log u_{bi}$ to compute the result $\log v_{bj}$, with the additional input $\log \nu$ following the same access pattern as the result. We parallelize in the independent dimensions ($b$ and $j$), split the reduction over $i$ amongst multiple threads and then combine their intermediate results. We have not employed tiling, which is commonly used to speed up the memory accesses for matrix multiplication.

In our implementation, the stabilisation of the `logsumexp` calculation is carried out in an online fashion, i.e. the stabilising maximum and the reduction result are computed in a single pass, similar to the Welford algorithm for the variance.
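The online stabilisation can be sketched on the CPU as a small state object (hypothetical name `LseState`). It keeps the running maximum and the sum of $\exp(x - \max)$, and rescales the sum whenever a new maximum appears, so that each value is read only once. The branch structure is our own and differs slightly from the kernel's.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Running state of a single-pass ("online") logsumexp: the largest value seen
// so far and the sum of exp(value - max) over all values seen so far.
struct LseState {
  double max = -std::numeric_limits<double>::infinity();
  double sumexp = 0.0;

  void add(double value) {
    if (value == -std::numeric_limits<double>::infinity()) {
      return;  // exp(-inf) == 0 contributes nothing
    }
    if (value <= max) {
      sumexp += std::exp(value - max);  // same reference point, just add
    } else {
      // new maximum: rescale the old sum to the new reference point;
      // for the empty state this is 0 * exp(-inf) == 0
      sumexp = sumexp * std::exp(max - value) + 1.0;
      max = value;
    }
  }

  double result() const {
    if (max == -std::numeric_limits<double>::infinity()) return max;  // empty
    return max + std::log(sumexp);
  }
};
```

Values around 1000 would overflow a naive `exp` in double precision, but the state only ever exponentiates non-positive differences to the running maximum.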
Most of it is straightforward, but maybe this is a good chance to elaborate on the reduction technique. It is the same technique used by many reductions; I started from the optimized BatchNorm implementation I contributed to PyTorch (which is where I learned this pattern in detail, because I added the Welford bits while rewriting it in C++).
Recall that in CUDA and ROCm, threads are organized into warps. Threads within a warp can communicate very efficiently using "warp shuffle" instructions such as `__shfl_xor_sync` (which are more or less register operations), while threads in different warps of the same thread block have to communicate via shared memory. The warp size is 32 for NVIDIA and 64 for AMD GPUs.
We start out by letting (up to) $\textrm{warpsize}^2$ threads reduce their part of the $i$ axis. This is just a `for` loop. Then each warp spends $\log_2 \textrm{warpsize}$ steps reducing all information in the warp to the first thread of the warp (the warp leader, i.e. the thread whose id is divisible by the warp size). Each warp leader stores its result in shared memory. Now we need to synchronise the threads.
Then one warp (say, the first) reads the results and then reduces again using shuffle operations. The first thread (thread id $0$) then can write the result to memory.
When launching such kernels with fewer than $\textrm{warpsize}^2$ threads, care must be taken to prune the reduction so that it does not access uninitialized values.
This simple trick seems to be usable for many reductions; it is a recurring pattern in GPU programming. It isn't even close to the tricks of high-performance matrix multiplications, which use several levels of tiling, but it certainly helps you write decent GPU kernels.
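The data flow of this two-level reduction can be simulated sequentially. The sketch below uses hypothetical names and a plain integer sum instead of the logsumexp state, and it does not model timing or the `__syncthreads` barriers. It mirrors the steps: a strided per-thread loop, an XOR butterfly inside each warp, warp leaders writing to a shared array, and a second butterfly in the first warp, with lanes that have no warp result padded by the identity element.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// kWarpSize plays the role of WARP_SIZE; num_threads plays the role of
// blockDim.x and must be a multiple of kWarpSize and at most kWarpSize^2.
constexpr int kWarpSize = 32;

long long block_reduce_sum(const std::vector<long long>& data, int num_threads) {
  assert(num_threads > 0 && num_threads % kWarpSize == 0 &&
         num_threads <= kWarpSize * kWarpSize);
  // 1. each "thread" reduces a strided slice of the input (the for loop)
  std::vector<long long> reg(static_cast<std::size_t>(num_threads), 0LL);
  for (int tid = 0; tid < num_threads; ++tid)
    for (std::size_t i = tid; i < data.size(); i += num_threads) reg[tid] += data[i];

  // butterfly within one warp: in each step, lane combines with lane ^ offset,
  // reading the values from before the step (like __shfl_xor_sync)
  auto warp_reduce = [](std::vector<long long>& r, int warp_begin) {
    for (int offset = 1; offset < kWarpSize; offset <<= 1) {
      std::vector<long long> before(r.begin() + warp_begin,
                                    r.begin() + warp_begin + kWarpSize);
      for (int lane = 0; lane < kWarpSize; ++lane)
        r[warp_begin + lane] = before[lane] + before[lane ^ offset];
    }
  };
  // 2. reduce within each warp
  const int num_warps = num_threads / kWarpSize;
  for (int w = 0; w < num_warps; ++w) warp_reduce(reg, w * kWarpSize);

  // 3. warp leaders write to "shared memory"; the first warp reads it back,
  //    padding lanes without a warp result with the identity (0 for a sum)
  std::vector<long long> shared(static_cast<std::size_t>(kWarpSize), 0LL);
  for (int tid = 0; tid < num_threads; tid += kWarpSize) shared[tid / kWarpSize] = reg[tid];
  std::vector<long long> first_warp(static_cast<std::size_t>(kWarpSize));
  for (int lane = 0; lane < kWarpSize; ++lane)
    first_warp[lane] = lane < num_warps ? shared[lane] : 0LL;

  // 4. second round within the first warp; thread 0 holds the result
  warp_reduce(first_warp, 0);
  return first_warp[0];
}
```

The padding with the identity corresponds to the pruning mentioned above; for the (max, sumexp) pair of the Sinkhorn kernel, the identity is $(-\infty, 0)$.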
Let's look at the kernel in detail:
The `PackedTensorAccessor` gives the CUDA kernel a pointer to the data along with stride and size information (as values, not as references as the `TensorAccessor` would). You can get them with `tensor.packed_accessor<...>()`. The drawback is that you need to know the number of dimensions at compile time.
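To illustrate the idea behind `PackedTensorAccessor`, here is a deliberately simplified, hypothetical stand-in (`PackedAccessorSketch`). It is not PyTorch's class, which also supports chained `[]` indexing and has pointer-trait and index-type parameters. It stores the data pointer together with copies of the sizes and strides. Because everything is stored by value, the struct can be copied as a kernel argument; because the arrays have length `N`, the number of dimensions has to be a template parameter.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified analogue of a packed accessor: the data pointer plus
// sizes and strides stored by value. N (the number of dimensions) is a
// compile-time constant because it determines the array lengths.
template <typename T, int N>
struct PackedAccessorSketch {
  T* data;
  std::int64_t sizes[N];
  std::int64_t strides[N];  // in elements, not bytes

  // element access for a fully specified multi-index
  T& at(const std::int64_t (&idx)[N]) const {
    std::int64_t offset = 0;
    for (int d = 0; d < N; ++d) offset += idx[d] * strides[d];
    return data[offset];
  }
};
```

A transposed view only swaps sizes and strides without copying data, which is one reason strided accessors are convenient for kernels.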
```cpp
template <typename scalar_t, typename index_t>
__global__ void sinkstep_kernel(
    // compute log v_bj = log nu_bj - logsumexp_i (-1/lambda dist_ij + log u_bi)
    // for this compute maxdiff_bj = max_i(-1/lambda dist_ij + log u_bi)
    // i = reduction dim, using threadIdx.x
    PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_v,
    const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> dist,
    const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_nu,
    const PackedTensorAccessor<scalar_t, 2, RestrictPtrTraits, index_t> log_u,
    const scalar_t lambda) {
  using accscalar_t = scalar_t;
```
We declare a shared memory variable. We want two values per warp for the reduction. We also set up indices and the thread id for the reduction and immediately exit if we're not in the range of our tensor.
```cpp
  __shared__ accscalar_t shared_mem[2 * WARP_SIZE];
  index_t b = blockIdx.y;
  index_t j = blockIdx.x * blockDim.z + threadIdx.z;
  int tid = threadIdx.x;
  if (b >= log_u.size(0) || j >= log_v.size(1)) {
    return;
  }
```
The next step is to reduce within each thread. At the heart there is just a `for` loop whose step size is the number of threads working on the reduction. I chose the online `logsumexp`, although, as Natalia Gimelshein and Christian Sarofeen pointed out, having more `exp` calls than needed in the hot loop is costly enough that it might be more efficient to just compute the maximum in a separate pass. A nice write-up of the online sum-exp reduction used in `logsumexp` is in M. Milakov & N. Gimelshein: Online normalizer calculation for softmax.
```cpp
  // reduce within thread
  accscalar_t max = -std::numeric_limits<accscalar_t>::infinity();
  accscalar_t sumexp = 0;
  if (log_nu[b][j] == -std::numeric_limits<accscalar_t>::infinity()) {
    if (tid == 0) {
      log_v[b][j] = -std::numeric_limits<accscalar_t>::infinity();
    }
    return;
  }
  for (index_t i = threadIdx.x; i < log_u.size(1); i += blockDim.x) {
    accscalar_t oldmax = max;
    accscalar_t value = -dist[i][j] / lambda + log_u[b][i];
    max = max > value ? max : value;
    if (oldmax == -std::numeric_limits<accscalar_t>::infinity()) {
      // sumexp used to be 0, so the new max is value and we can set 1 here,
      // because we will come back here again
      sumexp = 1;
    } else {
      sumexp *= exp(oldmax - max);
      sumexp += exp(value - max);
      // if oldmax was not -infinity, max is not either...
    }
  }
```
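Now that each thread has done its work, we reduce within the warp to get to one value per warp. The shuffle instructions are fast. In the $i$-th reduction step (counting from $0$), each thread combines its value with that of the thread whose index differs in bit $i$, so that afterwards it holds the value reduced over an aligned group of $2^{i+1}$ threads. After $\log_2$ `WARP_SIZE` steps, that group is the whole warp, so thread $0$ (in fact, every thread in the warp) has all of them.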
```cpp
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    accscalar_t o_max = __shfl_xor_sync(/*mask=*/0xffffffff, max, 1 << i, /*width=*/WARP_SIZE);
    accscalar_t o_sumexp = __shfl_xor_sync(/*mask=*/0xffffffff, sumexp, 1 << i, /*width=*/WARP_SIZE);
    if (o_max > max) {  // we're less concerned about divergence here
      sumexp *= exp(max - o_max);
      sumexp += o_sumexp;
      max = o_max;
    } else if (max != -std::numeric_limits<accscalar_t>::infinity()) {
      sumexp += o_sumexp * exp(o_max - max);
    }
  }
```
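The combination rule in this loop and the XOR pattern can be checked on the CPU. The sketch below (hypothetical names `Partial`, `combine` and `warp_allreduce`, `double` only) applies the same rescaling logic to 32 simulated lanes. Each lane reads its partner's value from before the step, as `__shfl_xor_sync` does, and every lane ends up with the full logsumexp.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <limits>

constexpr int kWarpSize = 32;

// Partial logsumexp result: max and sum of exp(x - max).
struct Partial {
  double max;
  double sumexp;
};

// Combine two partial results with the same branches as the shuffle loop:
// the sum that is relative to the smaller maximum gets rescaled.
Partial combine(Partial self, Partial other) {
  if (other.max > self.max) {
    self.sumexp = self.sumexp * std::exp(self.max - other.max) + other.sumexp;
    self.max = other.max;
  } else if (self.max != -std::numeric_limits<double>::infinity()) {
    self.sumexp += other.sumexp * std::exp(other.max - self.max);
  }
  return self;
}

// Simulated XOR butterfly: in each step every lane combines with lane ^ offset,
// all lanes reading the values from before the step.
std::array<Partial, kWarpSize> warp_allreduce(std::array<Partial, kWarpSize> lanes) {
  for (int offset = 1; offset < kWarpSize; offset <<= 1) {
    const std::array<Partial, kWarpSize> before = lanes;
    for (int lane = 0; lane < kWarpSize; ++lane)
      lanes[lane] = combine(before[lane], before[lane ^ offset]);
  }
  return lanes;
}
```

Since the sum belonging to the smaller maximum is always the one rescaled, the code only ever takes `exp` of non-positive numbers.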
Now every warp leader stores its result in shared memory. We use `__syncthreads` before and after so that all threads see the same shared memory. After that, the first warp reads the (at most `WARP_SIZE`) results from shared memory.
```cpp
  __syncthreads();
  // this writes each warp's accumulation into shared memory
  // there are at most WARP_SIZE items left because
  // there are at most WARP_SIZE**2 threads at the beginning
  if (tid % WARP_SIZE == 0) {
    shared_mem[tid / WARP_SIZE * 2] = max;
    shared_mem[tid / WARP_SIZE * 2 + 1] = sumexp;
  }
  __syncthreads();
  if (tid < WARP_SIZE) {
    max = (tid < blockDim.x / WARP_SIZE ? shared_mem[2 * tid]
                                        : -std::numeric_limits<accscalar_t>::infinity());
    sumexp = (tid < blockDim.x / WARP_SIZE ? shared_mem[2 * tid + 1] : 0);
  }
```
With that we have a second round of reductions. We are only interested in the first warp. At the end the very first thread in the block has the reduced value.
```cpp
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    accscalar_t o_max = __shfl_xor_sync(/*mask=*/0xffffffff, max, 1 << i, /*width=*/WARP_SIZE);
    accscalar_t o_sumexp = __shfl_xor_sync(/*mask=*/0xffffffff, sumexp, 1 << i, /*width=*/WARP_SIZE);
    if (o_max > max) {  // we're less concerned about divergence here
      sumexp *= exp(max - o_max);
      sumexp += o_sumexp;
      max = o_max;
    } else if (max != -std::numeric_limits<accscalar_t>::infinity()) {
      sumexp += o_sumexp * exp(o_max - max);
    }
  }
```
The first thread does the postprocessing and writes the result to memory.
```cpp
  if (tid == 0) {
    log_v[b][j] = (max > -std::numeric_limits<accscalar_t>::infinity()
                       ? log_nu[b][j] - log(sumexp) - max
                       : -std::numeric_limits<accscalar_t>::infinity());
  }
}
```
And that's the kernel! The full code for the Sinkhorn step is in the notebook.
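The post shows only the update of $\log v$. The update of $\log u$ is the same step with the roles of $\mu, u$ and $\nu, v$ swapped; this is the standard Sinkhorn scheme, and the notebook's own driver code may be organised differently. Here is a sequential sketch for a single pair of histograms, with hypothetical names, that alternates both steps and returns the transport plan $P_{ij} = \exp(\log u_i - c_{ij}/\lambda + \log v_j)$.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Stable two-pass logsumexp of a vector.
double logsumexp(const std::vector<double>& x) {
  double m = -std::numeric_limits<double>::infinity();
  for (double xi : x) m = std::max(m, xi);
  if (m == -std::numeric_limits<double>::infinity()) return m;
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Log-domain Sinkhorn for one pair of histograms (mu, nu) with cost c:
// alternately applies the log v step and the analogous log u step.
Matrix sinkhorn_log_plan(const std::vector<double>& mu, const std::vector<double>& nu,
                         const Matrix& c, double lambda, int iterations) {
  const std::size_t n = mu.size(), m = nu.size();
  std::vector<double> log_u(n, 0.0), log_v(m, 0.0), terms;
  for (int it = 0; it < iterations; ++it) {
    for (std::size_t j = 0; j < m; ++j) {  // log v_j = log nu_j - lse_i(-c_ij/lambda + log u_i)
      terms.assign(n, 0.0);
      for (std::size_t i = 0; i < n; ++i) terms[i] = -c[i][j] / lambda + log_u[i];
      log_v[j] = std::log(nu[j]) - logsumexp(terms);
    }
    for (std::size_t i = 0; i < n; ++i) {  // same step with the roles of u and v swapped
      terms.assign(m, 0.0);
      for (std::size_t j = 0; j < m; ++j) terms[j] = -c[i][j] / lambda + log_v[j];
      log_u[i] = std::log(mu[i]) - logsumexp(terms);
    }
  }
  Matrix plan(n, std::vector<double>(m));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < m; ++j)
      plan[i][j] = std::exp(log_u[i] - c[i][j] / lambda + log_v[j]);
  return plan;
}
```

With $\lambda = 0.01$ and costs of 10 and 20, every entry of $\exp(-c/\lambda)$ underflows to zero, so the probability-domain iteration would divide by zero, while the log-domain version still converges to a plan with the correct marginals.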
Conclusion
We have a fast, batched, stable Sinkhorn iteration. As argued in the article, the speedup would appear to enable using many more iterations than with the existing implementation and so allows us to use (almost) converged distances and the analytical gradient.
Here on the blog, we used the opportunity to look at a common reduction technique.
I hope you will find exciting new uses for the Wasserstein distance and for the reduction technique. When you do, please let me know!
PyTorch training
Do you want to learn to write awesome GPU kernels or generally give your PyTorch and Deep Learning skills a boost? I offer in-house and public workshops for beginner, intermediate and PyTorch expert levels. If you are near Munich (say, in Europe) and need PyTorch training, I would love to hear from you! I also do bespoke development.
I hope this blog post is useful to you. I appreciate and read every mail you send to tv@lernapparat.de.