TorchDrift and Partial MMD Drift Detection

July 6, 2021

So I have not blogged about TorchDrift yet, even though I have done a lot of writing and talking about it since we released it in March.

Introducing TorchDrift

Reginald Cleveland Coxe, Drifting


The key idea behind TorchDrift is to give you tools to check whether the data your model sees is still compatible with what you used to train and test it on. Check out our poster from the PyTorch Ecosystem Day or the first PyTorch Community Voices podcast (YouTube link) for more details.

TorchDrift came to life when we looked into how to accomplish this and found that there was no PyTorch library providing the necessary tooling. (Alibi-Detect is a library that does drift detection on TensorFlow. It added PyTorch support later.)

TorchDrift is a joint project of my company, MathInf GmbH, with the great colleagues at Orobix srl (if you are into PyTorch, you probably know Luca Antiga, the CTO, as the co-author of our book). It originated with an internal project for Orobix in the context of their invariant.ai product, but we decided to provide the library as open source.

Detecting drift

The basic technique was relatively easy to select: Following S. Rabanser et al.: Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, NeurIPS 2019, we use two-sample testing. Because it is important, let us look at what we do for this in detail:

  • We consider two distributions, the reference and the test distribution. At validation (or early production) time we draw some samples $X_1, ..., X_N$ from the reference distribution.
  • In production, we then collect samples $Y_1, ..., Y_M$ which we consider to be drawn from the test distribution.
  • To assess whether the model has drifted, we conduct a statistical test with the null hypothesis, that the reference distribution and the test distribution are the same. As we are using $N$ and $M$ samples rather than fixing one distribution to some analytically known one, this is a two-sample test.

There are some details to this, and an important one is how testing is done. Basically, we take a distance measure between the empirical distributions as the test statistic. A very natural choice is to take one that is an estimator for a distance between the distributions. A choice that seems to work well is the Maximum Mean Discrepancy (MMD) distance. The mechanics of statistical testing then require that we compute the distribution of this quantity under the null hypothesis. If we have no idea how to do this, we can always use the bootstrap distribution from permutations: For the two-sample test, this means that we combine the $N + M$ samples (because the null hypothesis says they're from the same distribution) and then repeatedly partition them randomly into sets of $N$ and $M$ samples and compute the distance between the two. This gives us an estimate for the distribution function $F$ of the distances under the null hypothesis. Now if our two samples (the real ones) have distance $d(X,Y)$, this corresponds to a $p$-value of $1 - F(d(X,Y))$.
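To make the mechanics concrete, here is a minimal NumPy sketch of such a permutation test. It is an illustration and not TorchDrift's implementation: the Gaussian kernel, the fixed bandwidth of 1, the number of permutations and all function names are assumptions made for this example.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel matrix k(a_i, b_j) for row-wise samples a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd2(x, y, bandwidth=1.0):
    """Biased estimate of the squared MMD between the samples x and y."""
    return (gaussian_kernel(x, x, bandwidth).mean()
            + gaussian_kernel(y, y, bandwidth).mean()
            - 2.0 * gaussian_kernel(x, y, bandwidth).mean())

def permutation_p_value(x, y, n_perm=200, seed=0):
    """Estimate the p-value by randomly re-splitting the pooled sample."""
    rng = np.random.default_rng(seed)
    observed = mmd2(x, y)
    pooled = np.concatenate([x, y])
    n = len(x)
    null_stats = np.empty(n_perm)
    for b in range(n_perm):
        perm = rng.permutation(len(pooled))
        null_stats[b] = mmd2(pooled[perm[:n]], pooled[perm[n:]])
    # Empirical version of 1 - F(d(X, Y)): share of null distances >= observed.
    return float((null_stats >= observed).mean())

rng = np.random.default_rng(42)
reference = rng.normal(0.0, 1.0, size=(100, 2))   # N = 100 reference samples
same = rng.normal(0.0, 1.0, size=(50, 2))         # M = 50, same distribution
shifted = rng.normal(3.0, 1.0, size=(50, 2))      # M = 50, clearly drifted

p_same = permutation_p_value(reference, same)
p_shifted = permutation_p_value(reference, shifted)
```

With only 200 permutations the smallest nonzero $p$-value this can report is $1/200$, which is one reason to approximate the null distribution parametrically when very small $p$-values matter.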

So far so good. But it turns out that in practice, the $p$-values we get are often extremely small, say $10^{-10}$ (so it is considered extremely unlikely that the reference and test distributions are the same), and it will set off the drift alarm quite frequently. Not good!

What do people do in these cases? The most common thing - and we do this, too - is to calibrate the threshold, i.e. forget the $p$-values and set the alarm threshold to some value that we expect or observe to be exceeded rarely enough in normal operations.

This then works reasonably well sometimes, but now we are considering things normal that our statistical model considers very unusual! We should ask ourselves what went wrong.

What is wrong with vanilla two-sample testing

At MathInf and Orobix, we looked at this very hard and eventually we came to the conclusion that a (frequently large) part of the problem is in the underlying assumption: If we ask whether reference and test distribution are the same and the reference samples are OK, we are more or less asking whether the sample we have in the test distribution is somehow representative of the reference distribution. Quite often this is a highly unrealistic expectation in practice:

  • There may be fluctuations in the environment that mean that the reference dataset is richer than we expect the test dataset to be. For example, outdoor lighting conditions vary over time through day and night, with weather conditions, or the seasons of the year. If our model is trained sufficiently to cover all of these, it will operate in normal conditions even though the test samples, drawn from a much shorter time interval than the reference, do not show this variety. Now, we could make the variation explicit and then test against conditional references, but this would cause significant additional effort.
  • Inputs provided by human users, e.g. search queries, are likely to have spikes for user interests in a given time range rather than uniformly querying the entire reference (e.g. the contents of a knowledge database). This, too, should count as normal operation.

One way to mitigate these effects could be to enlarge the data sample (and the time to collect it), but this may render timely drift detection infeasible.

One thing to note here is that outlier detection, a technique where we ask whether a single sample could reasonably have originated from the reference distribution, would conclude that the inputs in the above situations are not outliers. However, we still want the model monitoring to flag a collapse of the input distribution (e.g. a frozen camera image from a visual inspection system) as problematic, and outlier detection, unaware of the distribution of the inputs beyond a single given sample, cannot identify this. In this sense, we need a bridge between drift detection and outlier detection: We want the multi-sample approach on the test side like drift detection. At the same time we aim to remove the requirement to match the full reference distribution, a requirement that outlier detection does not have.

A toy example

As an example, consider a binary classifier. In an idealized setting, it might have a latent space feature density like this (darker = higher probability density):

Now one thing that might happen is that we get a very imbalanced batch (of course, there are contexts where we would want to detect such a class imbalance as drift from a more balanced expectation; but, crucially, not always!):

Now is this test batch representative of our reference distribution? Certainly not! But has it drifted, i.e. does the model operate out of spec now? It depends!

But now if we calibrate our drift detector to accept this by raising the detection threshold (above the 0.36), we miss configurations where drift has clearly occurred, such as the following:

So this is what we want to solve!

How to improve drift detection

So it is natural to ask whether we can check if the test distribution is representative for part of the reference distribution. It turns out that we can, and here is how we did it:

Our first step was to look at the Wasserstein distance, also known as the Earth Mover's Distance. (The most ardent followers of this blog will know that I have a soft spot for the Wasserstein distance.) In a nutshell and specializing to the discrete case, given some cost function $C(X_i, Y_j)$, it tries to match points $X_i$ and $Y_j$ (allowing fractional matches) such that the functional $W(X,Y) =\sum_{i, j} P(X_i, Y_j) C(X_i, Y_j)$ is minimal. Here $P(X_i, Y_j) \geq 0$ gives the mass that is matched. (To get the $p$-Wasserstein distance, one would choose $C(X_i,Y_j) = d(X_i,Y_j)^p$ and consider $W(X,Y)^{1/p}$, but for the purpose of hypothesis testing, the exponent does not matter until you approximate the distribution of the test statistic.)

The sum is then (a power of) the Wasserstein distance. The relation to the empirical distributions behind the $X_i$ and $Y_j$ is that if each point has weight $1/N$ and $1/M$, respectively, we ask that $\sum_j P(X_i, Y_j) = 1/N$ and $\sum_i P(X_i, Y_j) = 1/M$ (in addition to $P \geq 0$).

How does this help? It lets us do attribution, i.e. we can break down the overall distance into contributions of individual points. Now if we only want to match part of the reference distribution, we can just invent some unfavourable test point that is equally far away from all the reference points. Then the optimal transport plan will map the real test points to nearby reference points and the remaining mass maps to the imaginary distant point. But now we can just leave out that part of the mapping when computing the distance to get something that doesn't depend on the distant point. (I was all proud of this trick, but of course, L. Caffarelli and R. McCann knew it a decade ago: Free boundaries in optimal transport and Monge-Ampère obstacle problems, Annals of Mathematics (171), 2010.)

If we had mass $1-\alpha$ at the distant point, our partial matching now satisfies $\sum_i P(X_i, Y_j) = \alpha/M$ and $0 \leq \sum_j P(X_i, Y_j) \leq 1/N$ and we might rescale the cost functional by $1/\alpha$ to define $$ W_\alpha(X,Y) = \frac{1}{\alpha} \sum_{i, j} P(X_i, Y_j) C(X_i, Y_j). $$ (The rescaling means that if we mix in distant masses $1-\alpha$ on both sides and match only the "original part" $\alpha$, the $W_\alpha$ distance recovers the value of $W$ on the original distributions.)
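Collecting the pieces, and using nothing beyond the constraints stated so far, the partial matching can be written as a linear program in the coupling $P$:

$$ \begin{aligned} W_\alpha(X, Y) = \min_{P} \quad & \frac{1}{\alpha} \sum_{i=1}^{N} \sum_{j=1}^{M} P(X_i, Y_j)\, C(X_i, Y_j) \\ \text{subject to} \quad & P(X_i, Y_j) \geq 0, \\ & \sum_{i=1}^{N} P(X_i, Y_j) = \frac{\alpha}{M} \quad (j = 1, \dots, M), \\ & \sum_{j=1}^{M} P(X_i, Y_j) \leq \frac{1}{N} \quad (i = 1, \dots, N). \end{aligned} $$

The total matched mass is $\alpha$. For $\alpha = 1$ the inequality constraints can only be met with equality, so we are back at the full matching.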

So this helps to not detect drift when there is a controlled narrowing. It turns out, however, that drift detectors using the Wasserstein distance as the test statistic have trouble detecting drift reliably (at least in our experiment), even in the vanilla "full match" case. So what to do?

Revisiting the MMD distance

The maximum mean discrepancy (MMD) distance (A. Gretton et al.: A Kernel Two-Sample Test; JMLR 13(25):723−773, 2012), which powers what has become our bread and butter drift detector, appears to have much better drift detection performance in the full match case. So it is natural to ask whether we can apply a similar trick as for the Wasserstein distance.

It turns out the trick is a bit different. For empirical distributions, MMD is computed using the matrix of kernel evaluations at pairs of points, i.e.

$$ \begin{aligned} MMD^2(X, Y) &= \frac{1}{N^2} \sum_{i} \sum_{j} k(x_i, x_j) + \frac{1}{M^2} \sum_{i} \sum_{j} k(y_i, y_j) \\ &\qquad - \frac{2}{N M} \sum_{i} \sum_{j} k(x_i, y_j). \end{aligned} $$

(There are several versions of this estimate. This one, considered as an estimate of the squared distance $|\mu_X - \mu_Y|^2$ between the distributions from which $x_i$ and $y_j$ are drawn, is biased; for an unbiased estimator one would want to remove the diagonals in the first two terms.)

One thing we see here: In contrast to the coupling in the Wasserstein case, all points from $X$ and all points from $Y$ interact in the same way. This means that adding a point and then removing it from the summation like we did above does not help us here.

But if we introduce two vectors $w = (1/N)_{i=1,...,N}$ and $v = (1/M)_{j=1,...,M}$ and introduce the kernel matrices $K^X = k(x_i, x_j)$, $K^Y = k(y_i, y_j)$ and $K^{XY} = k(x_i, y_j)$, we can rewrite this in matrix notation as

$$ MMD^2(X, Y) = w^T K^X w + v^T K^Y v - 2 w^T K^{XY} v. $$
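As a quick sanity check, the following NumPy sketch evaluates both the double-sum form and the matrix form, with a uniform weight vector $w$ on the $x_i$ and a uniform weight vector $v$ on the $y_j$, and lets us compare the two. The Gaussian kernel, the bandwidth and the sample sizes are assumptions chosen just for the illustration.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel matrix k(a_i, b_j) for row-wise samples a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))   # N = 6 reference points
y = rng.normal(size=(4, 2))   # M = 4 test points
n, m = len(x), len(y)

k_x = gaussian_kernel(x, x)    # K^X,  shape (N, N)
k_y = gaussian_kernel(y, y)    # K^Y,  shape (M, M)
k_xy = gaussian_kernel(x, y)   # K^XY, shape (N, M)

# Double-sum form of the biased estimator.
mmd2_sums = k_x.sum() / n**2 + k_y.sum() / m**2 - 2.0 * k_xy.sum() / (n * m)

# Matrix form with uniform weight vectors w (on x) and v (on y).
w = np.full(n, 1.0 / n)
v = np.full(m, 1.0 / m)
mmd2_matrix = w @ k_x @ w + v @ k_y @ v - 2.0 * w @ k_xy @ v
```

Both expressions agree up to floating point rounding. The point of the matrix form is that $w$ is now an explicit object that we are free to change.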

But now $w$ is a weight vector representing a uniform distribution on the reference samples. A partial matching would deviate from this uniformity by allowing some weights to be $0$ and the others to grow larger. We can use the Wasserstein coupling above to get a replacement weight incorporating the idea of matching a fraction $\alpha$ by taking the marginal of the coupling (scaled by $\frac{1}{\alpha}$ to absorb the normalization factor we had above) $$ w^{twostage}_i := \frac{1}{\alpha} \sum_j P(x_i, y_j). $$

We call this the two-stage weight (and define the $MMD^2_{\alpha, twostage}$ with it) because we first use the Wasserstein distance and then the MMD distance. It turns out that this is a very good test statistic for our drift detection.

But we can expand on the idea of computing the MMD distance on a partial set of points by optimizing over the weight $w$ instead. The natural choice for the set $\mathcal M$ of admissible weight vectors $w$ is the set we identified as possible weights in our look at the partial Wasserstein distance: $$\mathcal M = \{w \in \mathbb{R}^N \mid 0 \leq w_i \leq \frac{1}{\alpha N}, \sum_i w_i = 1 \}.$$ We thus define the partial MMD distance as the minimum $$ MMD^2_{\alpha} = \min_{w \in \mathcal M} w^T K^X w + v^T K^Y v - 2 w^T K^{XY} v. $$

This is a quadratic programming problem with equality and inequality constraints. As such, it is standard, but not "easy" to solve: There exist libraries like quadprog that compute the solution for us, but this takes quite long (for our application).

The simplicity of the problem means that we can also implement an ad-hoc active-set optimization scheme (the algorithm we implemented for TorchDrift is described in the report). Our implementation cheats a bit, though, because we do not perfectly project the solution back into the admissible set, potentially allowing some $w_i$ to be larger than $\frac{1}{\alpha N}$ when we scale $w$ to enforce summing to $1$.
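To get a feeling for the optimization problem, here is a small NumPy sketch that approximates the minimum with the Frank-Wolfe method. To be clear, this is neither the quadprog route nor the active-set scheme from the report; it is a swapped-in textbook method that happens to fit the admissible set $\mathcal M$ nicely, because minimizing a linear function over $\mathcal M$ just means putting the maximal weight $\frac{1}{\alpha N}$ on the cheapest points. The kernel, the bandwidth, the iteration count and all names are assumptions for this illustration.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel matrix k(a_i, b_j) for row-wise samples a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

def capped_simplex_vertex(grad, cap):
    """Minimize grad @ s over the set {0 <= s_i <= cap, sum(s) = 1}."""
    s = np.zeros_like(grad)
    remaining = 1.0
    for i in np.argsort(grad):           # cheapest coordinates first
        s[i] = min(cap, remaining)
        remaining -= s[i]
        if remaining <= 0.0:
            break
    return s

def partial_mmd2(x, y, alpha, n_iter=500):
    """Frank-Wolfe sketch for min_w w'K^X w + v'K^Y v - 2 w'K^XY v."""
    n, m = len(x), len(y)
    k_x, k_y, k_xy = gaussian_kernel(x, x), gaussian_kernel(y, y), gaussian_kernel(x, y)
    v = np.full(m, 1.0 / m)
    b = k_xy @ v
    w = np.full(n, 1.0 / n)              # uniform start, admissible for alpha <= 1
    for _ in range(n_iter):
        grad = 2.0 * (k_x @ w - b)
        d = capped_simplex_vertex(grad, 1.0 / (alpha * n)) - w
        curvature = d @ k_x @ d
        if curvature <= 1e-15:           # no usable direction left
            break
        # Exact line search for the quadratic objective, clipped to [0, 1].
        step = np.clip(-(grad @ d) / (2.0 * curvature), 0.0, 1.0)
        w = w + step * d
    return w @ k_x @ w + v @ k_y @ v - 2.0 * w @ b, w

rng = np.random.default_rng(0)
reference = np.concatenate([rng.normal(-2.0, 0.5, size=(40, 2)),
                            rng.normal(2.0, 0.5, size=(40, 2))])
test = rng.normal(2.0, 0.5, size=(20, 2))    # covers only one of the two clusters

full, _ = partial_mmd2(reference, test, alpha=1.0)      # vanilla MMD^2
partial, w = partial_mmd2(reference, test, alpha=0.5)   # match half the reference
```

For $\alpha = 1$ the admissible set contains only the uniform weight, so `full` is the vanilla estimate. For $\alpha = 0.5$ the weight can concentrate on the cluster that the test batch actually covers, and `partial` comes out much smaller.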

With this definition of the partial MMD distance, our two-stage weight $w^{twostage}$ is admissible and so we get the upper bound $MMD^2_{\alpha, twostage} \geq MMD^2_{\alpha}$. But is it a good approximation? Our empirical experiments suggest no: It seemed that $MMD^2_{\alpha, twostage}$ often was an order of magnitude larger. However, as our interest is the drift detector we get from it, that seemed to work rather well.

For TorchDrift, this means that while we implement both the quadratic programming solution and the somewhat faster active-set approximation, the two-stage drift detector is much cheaper computationally. Thus, until we have a faster implementation of the specific QP problem, the two-stage drift detector is our first stop when we find that we have permissible fluctuation in our deployed input and feature distribution.

Back to our example

We can see how this works in our toy example. If we match the full reference (I use the Wasserstein coupling and return the largest matches), half the probability mass needs to go to the right hand side, giving a very large distance.

On the other hand, if we only match 15% of the distribution (the size relation of the reference and test data), we get a rather clean match and small distance.

The drifted data is also able to take advantage of the partial matching, but the increased distance remains very visible:

The MMD distances shown in the plot titles have been computed with the two-stage method discussed above.

Bootstrapping improvements

There is another new feature of TorchDrift that incorporates what we think is mathematical progress. When computing $p$-values, there is a subtle loss of information when using the two-sample testing including the $p$-value as a black box.

The two-sample test null hypothesis is, of course, that $x_i$, $i=1,...,N$ and $y_j$, $j=1,..., M$ are sampled from the same distribution $P$. Then the bootstrapping pools $x_i$ and $y_j$ and computes the test statistic for random splits of the joint set into subsets of cardinality $N$ and $M$ to simulate drawing from the distribution $P$. This is all well and good, but in drift detection, the stochastic model is that the distribution $P_X$ of the $x_i$ is fixed and the null hypothesis is that the $y_j$ are drawn from $P_X$. This makes the pooling step dubious, as it will "pollute" the distribution of the test statistic for non-drifted data with whatever we get as the test set. We can improve the bootstrapping by taking $N+M$ samples from the reference distribution $P_X$ during the fitting of the drift detector.
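In code, the change is small: the random splits are taken from reference data only and can be computed once, at fitting time. The sketch below is a NumPy illustration under assumptions of my choosing (Gaussian kernel, fixed bandwidth, 200 splits, hypothetical function names, and using the first $N$ reference samples for the comparison at detection time); it is not the TorchDrift code.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel matrix k(a_i, b_j) for row-wise samples a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd2(x, y, bandwidth=1.0):
    """Biased estimate of the squared MMD between the samples x and y."""
    return (gaussian_kernel(x, x, bandwidth).mean()
            + gaussian_kernel(y, y, bandwidth).mean()
            - 2.0 * gaussian_kernel(x, y, bandwidth).mean())

def reference_bootstrap(reference, n_test, n_bootstrap=200, seed=0):
    """Null distribution of MMD^2 estimated from N + M reference samples only."""
    rng = np.random.default_rng(seed)
    n_ref = len(reference) - n_test
    stats = np.empty(n_bootstrap)
    for b in range(n_bootstrap):
        perm = rng.permutation(len(reference))
        # Split the reference data into parts of size N and M.
        stats[b] = mmd2(reference[perm[:n_ref]], reference[perm[n_ref:]])
    return stats

rng = np.random.default_rng(1)
n_ref, n_test = 100, 50
reference = rng.normal(0.0, 1.0, size=(n_ref + n_test, 2))

# Fitting time: no test data is involved.
null_stats = reference_bootstrap(reference, n_test)

# Detection time: only one test statistic has to be computed.
drifted = rng.normal(3.0, 1.0, size=(n_test, 2))
observed = mmd2(reference[:n_ref], drifted)
p_value = float((null_stats >= observed).mean())
```

The null statistics no longer depend on the test batch, so a drifted batch cannot inflate the distribution it is being compared against.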

Not only is this mathematically more sound, but we may also gain from it computationally: We can now fit a suitable parametric distribution to the bootstrap sample during fitting of the drift detector. (Gretton suggested a gamma distribution approximation. We found that this works even better when incorporating a shift, so we determine the shift from the minimum value we observe and use moment fitting to obtain a shifted gamma distribution.) This saves us from having to do the bootstrap sampling during the actual drift detection as we can compute the test statistic and plug it into (one minus) the fitted distribution function to get the $p$-value.
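The moment fitting itself takes only a few lines. The following NumPy sketch shows one way to do it, assuming the shift is taken as the smallest observed value as described above; the function name is hypothetical, and the synthetic gamma sample merely stands in for the bootstrap statistics.

```python
import numpy as np

def fit_shifted_gamma(samples):
    """Moment-fit a gamma distribution shifted to the smallest observed value."""
    shift = samples.min()
    centered = samples - shift
    mean, var = centered.mean(), centered.var()
    shape = mean ** 2 / var   # gamma: mean = shape * scale
    scale = var / mean        #        var  = shape * scale**2
    return shift, shape, scale

# Synthetic stand-in for the bootstrap sample of test statistics.
rng = np.random.default_rng(0)
true_shift, true_shape, true_scale = 0.5, 3.0, 2.0
samples = true_shift + rng.gamma(true_shape, true_scale, size=20000)

shift, shape, scale = fit_shifted_gamma(samples)
```

By construction the fitted distribution reproduces the mean and variance of the sample. If SciPy is available, the $p$-value for an observed statistic `d` could then be obtained as `scipy.stats.gamma.sf(d, shape, loc=shift, scale=scale)`.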

TorchDrift does this if you provide the n_test argument to the call of the fit method.

Is this the end of calibration?

So we started by discussing why calibration is a very unsatisfactory answer to dealing with overly large detection rates when using $p$-value thresholds. Will we not need calibration when deploying our new drift detectors?

While we think that this improved methodology is taking drift detection a good step forward, it is very likely that there are still gaps. Until we can further refine our methodology, there will be the need for some big calibration hammer to overcome the mismatch between model and reality. But in many instances, we can do this much more judiciously with the improvements described above.

Try it in TorchDrift

You can try the new Partial MMD methods in TorchDrift today, by checking out the git repository. A release will follow soon!

A more mathematical writeup is in our report Partial Wasserstein and Maximum Mean Discrepancy distances for bridging the gap between outlier detection and drift detection.

Consulting and commercial support

If you need help with your drift detection or the process for deploying your models in general, do check out our commercial offerings at MathInf GmbH and Orobix srl.

I hope you enjoyed this little overview of brand new things in TorchDrift. I welcome your feedback at tv@lernapparat.de.