DRMacIver's Notebook

Separating Sampling and Removal

This is a weird data structure that, as far as I know, I invented (it’s plausibly a reinvention, but it might also just be too niche a set of requirements for anyone else to have bothered with), and I quite like it. It uses a combination of two neat tricks to support all of the following operations in amortized O(1): initialising from an existing sequence, sampling a random element, and deleting a value.

It does assume that the values in the collection are hashable, but honestly I’ve only ever wanted to use it with integers.

It builds on the following trick:

class LazySequenceCopy(object):
    """A "copy" of a sequence that works by inserting a mask in front
    of the underlying sequence, so that you can mutate it without changing
    the underlying sequence. Effectively behaves as if you could do list(x)
    in O(1) time. The full list API is not supported yet but there's no reason
    in principle it couldn't be."""

    def __init__(self, values):
        self.__values = values
        self.__len = len(values)
        self.__mask = None

    def __len__(self):
        return self.__len

    def pop(self):
        if len(self) == 0:
            raise IndexError("Cannot pop from empty list")
        result = self[-1]
        self.__len -= 1
        if self.__mask is not None:
            self.__mask.pop(self.__len, None)
        return result

    def __getitem__(self, i):
        i = self.__check_index(i)
        default = self.__values[i]
        if self.__mask is None:
            return default
        # Prefer the masked value if this index has been overwritten.
        return self.__mask.get(i, default)

    def __setitem__(self, i, v):
        i = self.__check_index(i)
        if self.__mask is None:
            self.__mask = {}
        self.__mask[i] = v

    def __check_index(self, i):
        n = len(self)
        if i < -n or i >= n:
            raise IndexError("Index %d out of range [0, %d)" % (i, n))
        if i < 0:
            i += n
        assert 0 <= i < n
        return i

This allows us to make a copy of the initial sequence without actually copying it.
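To make the masking behaviour concrete, here is a minimal sketch using a cut-down stand-in for the class above. The name `MaskedCopy` is hypothetical, and it only handles non-negative indices; the point is just that writes land in the mask and the original list is never touched.

```python
class MaskedCopy:
    """Hypothetical cut-down stand-in for LazySequenceCopy."""

    def __init__(self, values):
        self.values = values        # shared with the caller, never written to
        self.mask = {}              # index -> overriding value
        self.length = len(values)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        if not 0 <= i < self.length:
            raise IndexError(i)
        return self.mask.get(i, self.values[i])

    def __setitem__(self, i, v):
        if not 0 <= i < self.length:
            raise IndexError(i)
        self.mask[i] = v

    def pop(self):
        result = self[self.length - 1]
        self.length -= 1
        self.mask.pop(self.length, None)  # forget any override for the dropped slot
        return result


original = [10, 20, 30, 40]
view = MaskedCopy(original)
view[0], view[3] = view[3], view[0]   # swap through the mask
popped = view.pop()                   # removes the (masked) last element
remaining = [view[i] for i in range(len(view))]
```

After this runs, `original` still holds its four values in their original order, while `view` sees the swapped and shortened sequence.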

Now if we want to sample without replacement, it’s easy: we can just do a lazy Fisher-Yates shuffle:

def sample_without_replacement(random, sequence):
    i = random.randrange(0, len(sequence))
    j = len(sequence) - 1
    sequence[i], sequence[j] = sequence[j], sequence[i]
    return sequence.pop()

We pick a random element, swap it to the end, and then pop the last element. This lets us randomly sample and then remove in O(1) because we don’t care about preserving the order of the sequence.
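As a quick illustration, the sketch below drains a sequence with that function. It assumes a plain `list` copy as the working sequence (which costs O(n) to create, unlike `LazySequenceCopy`) and a seeded `random.Random` purely so the run is repeatable; the function is repeated so the snippet runs on its own.

```python
import random


def sample_without_replacement(rnd, sequence):
    # Pick a random slot, swap it with the last slot, then pop the last slot.
    i = rnd.randrange(0, len(sequence))
    j = len(sequence) - 1
    sequence[i], sequence[j] = sequence[j], sequence[i]
    return sequence.pop()


rnd = random.Random(0)
source = list(range(10))
working = list(source)  # plain copy standing in for LazySequenceCopy
drawn = [sample_without_replacement(rnd, working) for _ in range(len(source))]
```

Every element of `source` appears exactly once in `drawn`, in a random order, and `working` ends up empty.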

The tricky bit is what to do when you want to sample and then decide later whether to remove the value or not. I run into this requirement when, for example, randomly but exhaustively exploring a state space: sometimes we discover that a node is fully explored and we want to remove it, but we can’t know that at the time of sampling.

The idea is to do deferred deletion. Rather than actually removing a value at the time of deletion, we simply record that we wanted to delete it, and the next time we sample we remove any deleted elements we encounter and draw again.

This works as follows:

from collections import Counter

class RemovableSampler(object):
    def __init__(self, values):
        self.__values = LazySequenceCopy(values)
        self.__deletions = Counter()

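    def sample(self, random):
        while True:
            i = random.randrange(0, len(self.__values))
            v = self.__values[i]
            if self.__deletions[v] > 0:
                # v was deleted earlier: swap it to the end, drop it,
                # mark one pending deletion as done, and draw again.
                j = len(self.__values) - 1
                self.__values[i], self.__values[j] = (
                    self.__values[j], self.__values[i])
                self.__values.pop()
                self.__deletions[v] -= 1
            else:
                return v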

    def delete(self, value):
        self.__deletions[value] += 1

That is, we reuse the procedure from the lazy Fisher-Yates shuffle, but we only remove a value when we’re not going to return it because it has previously been deleted.

We can see that this is amortized O(1) by using a “debt” model. Each call to delete increases the total debt by one, and each iteration of the loop in sample either returns or repays one debt, so the total work done is never more than O(deletes + samples).
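The debt argument can be checked empirically with a small instrumented sketch. `CountingSampler` is a hypothetical variant of the sampler that counts loop iterations and, to keep the snippet short, uses a plain list copy in place of `LazySequenceCopy`; the choice to delete even values is arbitrary.

```python
import random
from collections import Counter


class CountingSampler:
    """Hypothetical instrumented sampler; a plain list replaces LazySequenceCopy."""

    def __init__(self, values):
        self.values = list(values)
        self.deletions = Counter()
        self.iterations = 0  # total loop iterations across all sample() calls

    def sample(self, rnd):
        while True:
            self.iterations += 1
            i = rnd.randrange(0, len(self.values))
            v = self.values[i]
            if self.deletions[v] > 0:
                # Repay one unit of debt: swap to the end and drop it.
                self.values[i], self.values[-1] = self.values[-1], self.values[i]
                self.values.pop()
                self.deletions[v] -= 1
            else:
                return v

    def delete(self, value):
        self.deletions[value] += 1


rnd = random.Random(0)
sampler = CountingSampler(range(100))
samples = deletes = 0
for _ in range(50):
    v = sampler.sample(rnd)
    samples += 1
    if v % 2 == 0:  # arbitrary rule: even values turn out to be "fully explored"
        sampler.delete(v)
        deletes += 1
```

Whatever the random draws are, `sampler.iterations` cannot exceed `samples + deletes`, since every iteration either returns a sample or uses up one recorded deletion.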