DRMacIver's Notebook
Filtered sampling from sorted values with incremental binary search
I had a problem recently that I solved badly. As is traditional, I figured out how to solve it well in the shower this morning. The solution is obvious in retrospect, but I thought I’d share it anyway as it’s an interesting algorithm that I’d not seen before.
def filtered_sample(random, values, lower_bound, upper_bound):
    """Samples a random value `x` from `values`, which must be a sorted list,
    such that lower_bound <= x <= upper_bound, or raises ValueError if there are no such values.
    """
    # Invariant: If 0 <= i < lo then values[i] < lower_bound
    lo = 0
    # Invariant: If hi <= i < len(values) then values[i] > upper_bound
    hi = len(values)
    while lo < hi:
        i = random.randrange(lo, hi)
        x = values[i]
        if x < lower_bound:
            lo = i + 1
        elif x > upper_bound:
            hi = i
        else:
            return x
    raise ValueError(
        f"`values` does not contain any elements between {lower_bound} and {upper_bound}"
    )
This is a nice example of an algorithm that varies smoothly between two slightly more obvious algorithms:
- Rejection sampling (just sample until you get a value in the right range)
- Always do a binary search to find the bounds, and then sample at random within that range.
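For comparison, here is a minimal sketch of those two baselines. The function names, the `rng` parameter (any `random.Random`-like object), and the `max_attempts` cap are assumptions made for illustration; the cap is only there so the rejection loop terminates when nothing is in range.

```python
import bisect
import random


def rejection_sample(rng, values, lower_bound, upper_bound, max_attempts=1000):
    """Baseline 1: draw uniformly from `values` until a draw lands in range.

    `max_attempts` is an assumed safety cap, not part of the description above.
    """
    if not values:
        raise ValueError("`values` is empty")
    for _ in range(max_attempts):
        x = rng.choice(values)
        if lower_bound <= x <= upper_bound:
            return x
    raise ValueError("no in-range value found within max_attempts draws")


def bisect_sample(rng, values, lower_bound, upper_bound):
    """Baseline 2: binary search for both bounds, then sample between them."""
    # First index whose value is >= lower_bound.
    lo = bisect.bisect_left(values, lower_bound)
    # First index whose value is > upper_bound.
    hi = bisect.bisect_right(values, upper_bound)
    if lo >= hi:
        raise ValueError("no values in range")
    return values[rng.randrange(lo, hi)]
```

The rejection version never looks at the ordering of `values`, while the bisect version always pays for two full binary searches even when the very first random draw would have been acceptable.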
Rejection sampling is best when almost all of the values are in range. Binary search is best when few or none are. In particular, rejection sampling cannot easily detect that no values are in range, though there are ways to mitigate that.
My initial implementation did one round of rejection sampling and then the binary search, which is a sort of awkward compromise. The new algorithm essentially does only as much of the binary search as it needs. If a sample gives you a suitable value now, you can stop and return that. If it doesn’t, you’ve effectively chosen a midpoint in your binary search and can update the bounds for the next round of the sample.
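To make the contrast concrete, here is a sketch of what a "one round of rejection sampling, then binary search" hybrid could look like. This is a hypothetical reconstruction, not the original code: the name `one_shot_then_bisect`, the `rng` parameter, and the empty-list check are all assumptions.

```python
import bisect
import random


def one_shot_then_bisect(rng, values, lower_bound, upper_bound):
    """Hypothetical hybrid: try one random draw, then fall back to binary search."""
    if not values:
        raise ValueError("`values` is empty")
    # One round of rejection sampling over the whole list.
    x = values[rng.randrange(len(values))]
    if lower_bound <= x <= upper_bound:
        return x
    # The failed draw is thrown away here: the binary searches below start
    # from scratch rather than reusing what that probe revealed.
    lo = bisect.bisect_left(values, lower_bound)
    hi = bisect.bisect_right(values, upper_bound)
    if lo >= hi:
        raise ValueError("no values in range")
    return values[rng.randrange(lo, hi)]
```

The wasted information in the failed draw is exactly what the incremental version recovers: an out-of-range probe tells you which side of it the valid region must lie on.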
My initial implementation basically only works as well as whichever of the two obvious algorithms is better, while the new algorithm smoothly varies between the two extremes. I believe (but haven’t checked the details) that it runs in O(log(n / (k + 1))) time, where n is the length of the array and k is the number of elements in the array satisfying the condition.
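One rough way to probe that guess empirically is to count how many elements the loop inspects for different values of k. The sketch below is only an illustration under assumed conditions: `filtered_sample_counting` and `mean_probes` are hypothetical helpers, the values are simply `range(n)`, and the in-range elements form one centred block. It checks nothing about the bound beyond showing how the average probe count moves with k.

```python
import random


def filtered_sample_counting(rng, values, lower_bound, upper_bound):
    """Same loop as `filtered_sample`, but returns (value, number_of_probes)."""
    lo, hi, probes = 0, len(values), 0
    while lo < hi:
        i = rng.randrange(lo, hi)
        probes += 1
        x = values[i]
        if x < lower_bound:
            lo = i + 1
        elif x > upper_bound:
            hi = i
        else:
            return x, probes
    raise ValueError("no values in range")


def mean_probes(n, k, trials=2000, seed=0):
    """Average probes when k >= 1 of n sorted values (a centred block) are in range."""
    rng = random.Random(seed)
    values = list(range(n))
    start = (n - k) // 2
    lower, upper = start, start + k - 1
    total = 0
    for _ in range(trials):
        _, probes = filtered_sample_counting(rng, values, lower, upper)
        total += probes
    return total / trials


if __name__ == "__main__":
    n = 1 << 16
    for k in (1, 16, 256, 4096, n):
        print(f"k={k:>6}  mean probes={mean_probes(n, k):.2f}")
```

When every element is in range the first probe always succeeds, so the mean is exactly 1; as k shrinks, the mean should grow slowly rather than linearly, which is the behaviour the conjectured bound describes.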