DRMacIver's Notebook

Speeding up conditional sampling with divide and conquer

I got myself into a bit of a mess trying to explain why a particular result was true the other day, so this is my write-up of the proof.

Suppose you’re sampling from some language model (large or otherwise) and you want to apply a constraint \(h\). That is, you’ve got some random variable \(S\), and you want to sample a random variable \(T\) such that \(P(T = t) \propto P(S = t) h(t)\).

One way to achieve this is rejection sampling: You just repeatedly sample from IID copies of \(S\) until you get one that satisfies \(h\). (I’m playing a bit fast and loose with random variables here and just assuming you can take IID copies of any of them; really our “random variables” of interest are randomized programs.)
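As a concrete sketch of that (in Python, with `sample_s` and `h` as stand-in names for the base sampler and the constraint, not anything from a particular library):

```python
def rejection_sample(sample_s, h):
    """Draw IID samples from S until one satisfies the constraint h."""
    while True:
        s = sample_s()
        if h(s):
            return s
```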

The problem with this is that it’s potentially very slow. It might require a lot of samples from \(S\), especially if you make bad initial choices of characters in the sequence.

Here’s a way of improving the performance, if for each character \(c\) we can calculate \(\tau(c) = P(h(S) | S \text{ starts with } c)\).

We do this by reweighting the probability of each start character \(c\) by \(\tau(c)\), so we choose \(c\) with probability proportional to \(P(S \text{ starts with } c) \tau(c)\). If \(c\) is the special EOS token we stop; otherwise we recursively apply this method to the rest of the string, reweighting the next character in a similar way (with newly calculated \(\tau\) values for the longer prefix).
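Here is a minimal sketch of what that loop might look like in Python. The names are all hypothetical stand-ins: `next_token_probs(prefix)` is assumed to return the base model’s distribution over the next character given the prefix, and `tau(prefix, c)` the probability that the constraint holds given the string starts with that prefix followed by `c`.

```python
import random

def constrained_sample(next_token_probs, tau, eos):
    """Sample a string left to right, reweighting each candidate next
    character by its conditional probability of satisfying the constraint."""
    prefix = []
    while True:
        probs = next_token_probs(prefix)  # dict: character -> P(next char | prefix)
        chars = list(probs)
        weights = [probs[c] * tau(prefix, c) for c in chars]
        if sum(weights) == 0:
            raise ValueError("constraint cannot be satisfied from this prefix")
        c = random.choices(chars, weights=weights)[0]
        if c == eos:
            return "".join(prefix)
        prefix.append(c)
```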

The advantage of this method is that we never have to backtrack or repeat ourselves: Our reweighting of the characters automatically takes the constraint into account without having to ever start over again.

The disadvantage is that we have to get these \(\tau\) weights from somewhere, which may or may not be possible in general. I’m interested in some special cases where it’s easy because the condition has a nice simple structure that just comes from deleting prefixes, but we won’t go into that here.

The reason this works is essentially a light variant of the Probabilistic divide and conquer (PDC) method, which rests on the following lemma:

Suppose we have discrete random variables \(A, B\) on \(U, V\), with joint law \(q\) and some constraint \(h: U \times V \to \{0, 1\}\) with \(Z = E(h(A, B)) > 0\), and we want to sample from the conditional distribution \(A, B | h(A, B)\). That is, we want random variables \(X, Y\) such that \(P(X = x, Y = y) = \frac{q(x, y) h(x, y)}{Z}\).

Construct \(X, Y\) as follows:

  1. Sample \(X\) such that \(P(X = x) = P(A = x | h(A, B))\).
  2. Then sample \(Y\) such that \(P(Y = y | X = x) = P(B = y | h(A, B), A = x)\).

If you can do this, then \(P(X = x, Y = y) = \frac{q(x, y) h(x, y)}{Z}\) as desired.
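In code the construction is almost nothing (Python again, with `sample_x_given_h` and `sample_y_given_x_and_h` as hypothetical names for samplers from the two conditional distributions above):

```python
def pdc_sample(sample_x_given_h, sample_y_given_x_and_h):
    """Two-step construction: draw X from the law of A conditioned on
    h(A, B), then draw Y from the law of B conditioned on h(A, B) and A = x."""
    x = sample_x_given_h()
    y = sample_y_given_x_and_h(x)
    return x, y
```

The entire content of the lemma is that this gives the right joint distribution.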

In the original PDC paper they describe this as a simple application of Bayes’ formula, which maybe it is, but I got myself into a muddle trying to prove it that way. In the end I found it easier to go back to the straightforward definition of conditional probability. First, \(P(X = x, Y = y) = P(X = x) P(Y = y | X = x)\).

We now calculate these two quantities:

\[\begin{align*} P(X = x) &= P(A = x | h(A, B)) \\ &= \frac{P(A = x, h(A, B))}{P(h(A, B))} \\ &= \frac{P(A = x, h(A, B))}{Z} \\ \end{align*}\]

Then:

\[\begin{align*} P(Y = y | X = x) &= P(B = y | A = x, h(A, B)) \\ &= \frac{P(B = y, A = x, h(A, B))}{P(A = x, h(A, B))} \\ &= \frac{q(x, y) h(x, y)}{P(A = x, h(A, B))} \\ \end{align*}\]

So then multiplying these together we get:

\[\begin{align*} P(X = x, Y = y) & = P(X = x) P(Y = y | X = x) \\ & = \frac{P(A = x, h(A, B))}{Z} \frac{q(x, y) h(x, y)}{P(A = x, h(A, B))} \\ & = \frac{q(x, y) h(x, y)}{Z} \\ \end{align*}\]

As desired.
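If you’d rather convince yourself numerically than algebraically, here’s a small brute-force check of the lemma on a toy example of my own invention (two fair dice conditioned on their sum being at least 9), comparing the two-step construction against the target conditional distribution:

```python
from fractions import Fraction
from itertools import product

U = V = range(1, 7)
q = {(a, b): Fraction(1, 36) for a, b in product(U, V)}  # two independent fair dice

def h(a, b):
    return a + b >= 9  # the constraint

Z = sum(p for (a, b), p in q.items() if h(a, b))

# The target conditional distribution: q(x, y) h(x, y) / Z.
target = {(a, b): p / Z for (a, b), p in q.items() if h(a, b)}

# The two-step construction: P(X = x) = P(A = x | h(A, B)), then
# P(Y = y | X = x) = P(B = y | h(A, B), A = x).
two_step = {}
for a in U:
    p_a_and_h = sum(q[(a, b)] for b in V if h(a, b))
    if p_a_and_h == 0:
        continue  # this value of a is impossible under the constraint
    p_x = p_a_and_h / Z
    for b in V:
        if h(a, b):
            two_step[(a, b)] = p_x * (q[(a, b)] / p_a_and_h)

assert two_step == target
```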

This result is practically trivial, and it’s not obvious a priori that sampling this way is any easier than the original constrained sampling problem. The observation of PDC is that sometimes it is, and that when it is, the speed improvement can be huge.