DRMacIver's Notebook
Separating Sampling and Removal
This is a weird data structure that, as far as I know, I invented (it’s plausibly a reinvention, but it might also just be too niche a set of requirements for anyone else to have bothered with), and I quite like it. It uses a combination of two neat tricks to support all of the following operations in O(1) (amortized, in the case of sampling):
- Initialize with N elements.
- Sample randomly from one of the remaining elements or raise an error if the collection is empty.
- Remove a single copy of a value from the collection.
It does assume that the values in the collection are hashable, but honestly I’ve only ever wanted to use it with integers.
It builds on the following trick:
    class LazySequenceCopy(object):
        """A "copy" of a sequence that works by inserting a mask in front
        of the underlying sequence, so that you can mutate it without changing
        the underlying sequence. Effectively behaves as if you could do list(x)
        in O(1) time. The full list API is not supported yet but there's no reason
        in principle it couldn't be."""

        def __init__(self, values):
            self.__values = values
            self.__len = len(values)
            self.__mask = None

        def __len__(self):
            return self.__len

        def pop(self):
            if len(self) == 0:
                raise IndexError("Cannot pop from empty list")
            result = self[-1]
            self.__len -= 1
            if self.__mask is not None:
                self.__mask.pop(self.__len, None)
            return result

        def __getitem__(self, i):
            i = self.__check_index(i)
            default = self.__values[i]
            if self.__mask is None:
                return default
            else:
                return self.__mask.get(i, default)

        def __setitem__(self, i, v):
            i = self.__check_index(i)
            if self.__mask is None:
                self.__mask = {}
            self.__mask[i] = v

        def __check_index(self, i):
            n = len(self)
            if i < -n or i >= n:
                raise IndexError("Index %d out of range [0, %d)" % (i, n))
            if i < 0:
                i += n
            assert 0 <= i < n
            return i
This allows us to make a copy of the initial sequence without actually copying it.
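To make the masking idea concrete, here is a minimal sketch. It assumes a cut-down stand-in called `MaskedView` (a hypothetical name, not the class above) that only supports reading and writing non-negative indices, which is enough to show that writes land in the mask and never touch the underlying sequence.

```python
class MaskedView:
    """Hypothetical, cut-down stand-in for LazySequenceCopy: reads fall
    through to the underlying sequence unless a write has masked them."""

    def __init__(self, values):
        self._values = values  # shared with the caller, never mutated
        self._mask = {}        # index -> overriding value

    def __getitem__(self, i):
        return self._mask.get(i, self._values[i])

    def __setitem__(self, i, v):
        self._mask[i] = v


original = [10, 20, 30]
view = MaskedView(original)  # O(1): nothing is copied
view[0] = 99

print(view[0], view[1])  # 99 20
print(original)          # [10, 20, 30]
```

The cost of a write is one dictionary entry, so the memory used is proportional to the number of positions changed rather than to the length of the sequence.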
Now if we wanted to do sampling without replacement, it’s easy: we can just do a lazy Fisher-Yates shuffle:
    def sample_without_replacement(random, sequence):
        i = random.randrange(0, len(sequence))
        j = len(sequence) - 1
        sequence[i], sequence[j] = sequence[j], sequence[i]
        return sequence.pop()
We pick a random element, swap it to the end, and then pop the last element. This lets us randomly sample and then remove in O(1) because we don’t care about preserving the order of the sequence.
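As a quick sanity check of the swap-and-pop step, the sketch below draws every element from a small pool. For brevity it assumes a plain `list` in place of a `LazySequenceCopy`, and the seed is only there to make the run repeatable.

```python
import random


def sample_without_replacement(rng, sequence):
    i = rng.randrange(0, len(sequence))
    j = len(sequence) - 1
    # Move the chosen element to the end, then remove it in O(1).
    sequence[i], sequence[j] = sequence[j], sequence[i]
    return sequence.pop()


rng = random.Random(0)
pool = list(range(5))  # assumption: a plain list we are free to mutate
drawn = [sample_without_replacement(rng, pool) for _ in range(5)]

assert sorted(drawn) == [0, 1, 2, 3, 4]  # each element drawn exactly once
assert pool == []                        # and the pool is now exhausted
```

The order of `drawn` depends on the random source, but the set of values does not: every element comes out exactly once.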
The tricky bit is what to do when you want to sample and then decide later whether you want to remove the value or not. For example, I run into this requirement when randomly but exhaustively exploring a state space: sometimes we discover that a node is fully explored and we want to remove it, but we can’t know that at the time of sampling.
The idea is to do deferred deletion. Rather than attempting to actually delete a value when asked to, we simply record that we wanted to delete it, and the next time we sample we remove any such elements we encounter.
This works as follows:
    from collections import Counter

    class RemovableSampler(object):
        def __init__(self, values):
            self.__values = LazySequenceCopy(values)
            # Number of pending (not yet applied) deletions for each value.
            self.__deletions = Counter()

        def sample(self, random):
            while True:
                # randrange raises ValueError once no elements are left.
                i = random.randrange(0, len(self.__values))
                v = self.__values[i]
                if self.__deletions[v] > 0:
                    # v was deleted earlier: swap it to the end, drop it,
                    # and try again.
                    j = len(self.__values) - 1
                    self.__values[i], self.__values[j] = self.__values[j], self.__values[i]
                    self.__values.pop()
                    self.__deletions[v] -= 1
                else:
                    return v

        def delete(self, value):
            self.__deletions[value] += 1
I.e. we reuse the procedure from the lazy Fisher-Yates shuffle, but we only remove the value when we’re not going to return it because it was previously deleted.
We can see that this is amortized O(1) by using a “debt” model. Each call to delete increases the total debt by one, and each iteration of sample either returns or repays one debt, so the total work done is never more than O(deletes + samples).
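The debt argument can be checked empirically. The sketch below assumes a hypothetical `CountingSampler`, a simplified variant of the sampler that works on a plain list copy (so initialization is O(n) here) and counts loop iterations, purely to illustrate the bound.

```python
import random
from collections import Counter


class CountingSampler:
    """Hypothetical variant of RemovableSampler that counts loop
    iterations. It copies the input list, unlike the real thing."""

    def __init__(self, values):
        self.values = list(values)
        self.deletions = Counter()
        self.iterations = 0

    def sample(self, rng):
        while True:
            self.iterations += 1
            i = rng.randrange(0, len(self.values))
            v = self.values[i]
            if self.deletions[v] > 0:
                # Repay one unit of debt: swap to the end and drop it.
                j = len(self.values) - 1
                self.values[i], self.values[j] = self.values[j], self.values[i]
                self.values.pop()
                self.deletions[v] -= 1
            else:
                return v

    def delete(self, value):
        self.deletions[value] += 1


rng = random.Random(0)
sampler = CountingSampler(range(100))
samples = deletes = 0
for _ in range(50):
    v = sampler.sample(rng)
    samples += 1
    sampler.delete(v)  # decide after the fact that v should go
    deletes += 1

# Every iteration either returned a value or repaid one deletion.
assert sampler.iterations <= samples + deletes
```

Any single call to `sample` may loop several times, but the total number of iterations across all calls stays within the number of samples plus the number of deletes.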