Binary Search: A Better Way


So you think you know how to write binary search - just like everyone else. I mean, it’s easy, right? Maybe, but you can probably do it better. Let’s start with an example. Say you have a sorted integer array, and you want to find the largest integer no larger than k, and return its index.

So it usually goes like this: let n = size of array, l = 0 and r = n-1, and dive into a while loop. In the loop, you find a midpoint between l and r. And you quit when you have only one number left, that is, when l = r.

int l = 0, r = n - 1;
while (l < r) {
    int mid = (l + r) / 2;
}

Then what do you do? Well, you could compare it with k.

while (l < r) {
    int mid = (l + r) / 2;
    if (v[mid] > k)
        r = mid - 1;
    else
        l = mid;
}
return l;

If v[mid] is larger than k, then all the valid answers can only be between l and mid-1. Otherwise, anywhere from mid to r could be the answer.

Now you’re happy and you try to run it - and it doesn’t work! In fact, not only does it not work, it’s a terrible piece of crap. Here’s why.

  1. Edge cases. What happens if (1) the array is empty, (2) all integers are smaller than k, or (3) all integers are larger than k? In (1) you’ll return 0, which is wrong; in (2) you’ll get an infinite loop whenever the array has at least two elements (explained next); and in (3) you’ll get 0, which is also wrong. The correct results would be (1) -1 for not found, (2) n-1, and (3) -1 for not found. This algorithm gets all edge cases wrong!

  2. Infinite loop. Say we are given k = 2 and the array {0, 1}. l starts at 0 and r starts at 1, so mid would be (0 + 1) / 2 = 0. v[0] = 0 is not larger than 2, so we set l to mid, which is 0. We’re back to the same situation, and this loop will never end!
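To watch the loop get stuck without actually hanging a program, here is a minimal sketch that runs the same naive loop with an iteration cap. The function name `naiveLoopSteps` and the `maxSteps` parameter are made up for this illustration; they are not part of the original code.

```cpp
#include <cassert>
#include <vector>

// Runs the naive loop from the text, but gives up after maxSteps iterations.
// Returns the number of iterations performed, or -1 if the cap was reached,
// which indicates the loop was not making progress.
int naiveLoopSteps(const std::vector<int>& v, int k, int maxSteps) {
    assert(maxSteps >= 0);
    int l = 0, r = static_cast<int>(v.size()) - 1;
    int steps = 0;
    while (l < r) {
        if (steps == maxSteps) return -1;  // still looping: report it as stuck
        ++steps;
        int mid = (l + r) / 2;
        if (v[mid] > k)
            r = mid - 1;
        else
            l = mid;  // when r == l + 1, mid == l, so nothing changes
    }
    return steps;
}
```

Calling `naiveLoopSteps({0, 1}, 2, 1000)` hits the cap and returns -1, while `naiveLoopSteps({0, 1}, -1, 1000)` finishes after a single iteration.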

There are ways to fix the code, of course - you handle the special cases by spraying if statements all over the place, and you can fix the infinite loop by setting mid to (l + r + 1)/2. Like this:

if (n == 0 || v[0] > k)
    return -1;
int l = 0, r = n - 1;
while (l < r) {
    int mid = (l + r + 1) / 2;
    if (v[mid] > k)
        r = mid - 1;
    else
        l = mid;
}
return l;

But it’s ugly! We have to pick between (l + r) and (l + r + 1), which seems totally arbitrary, and there are edge cases to consider. What went wrong?

Everything from the very beginning went wrong. At the point we set l to 0 and r to n-1, we’re already making an implicit statement (a loop invariant). We’re assuming that every time we enter the loop, the answer could be anywhere from l to r. But this isn’t even true! If the answer doesn’t exist, then obviously the answer can’t be between l and r. The other assumption is that r-l always decreases after each iteration (so the loop eventually terminates) - which is also not true, as exemplified in the infinite loop case. Since the loop invariants never held, we get gibberish at the end of the algorithm.

It turns out that if we change our loop invariant, we could be in a much better situation. Here’s the proposal: the transition between the last integer that’s at most k and the first integer that’s larger than k is always between l and r.

That means l must start at -1, and r at n. This is because the transition could happen from -1 to 0 (before the first element) or from n-1 to n (after the last element). The stopping condition would be l + 1 = r, because by then we would know the transition is from l to r, and the answer would be l. Let’s try to code this out:

int l = -1, r = n;
while (l + 1 < r) {
    int mid = (l + r) / 2;
    if (v[mid] > k)
        r = mid;
    else
        l = mid;
}
return l;

Magically, this code handles all edge cases correctly! It gives the right answer for any sorted input. We don’t have any more edge cases because we can express all cases using l and r. If the list is empty or all numbers are greater than k, then our “transition” would occur before the first number, which means l = -1 and r = 0. Even though v[-1] is invalid, it is never read: (l, r) = (-1, 0) is simply a valid statement that there is no element at most k. We also don’t have infinite loops anymore, because whenever l and r differ by at least 2, (l + r) / 2 is strictly between l and r, so r - l shrinks on every iteration.
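As a quick sanity check, here is the same loop wrapped in a function and run against the three edge cases listed earlier. The names `lastAtMost` and `checkLastAtMost` are invented for this sketch, and the midpoint is written as `l + (r - l) / 2`, which gives the same value here but cannot overflow when l + r would.

```cpp
#include <cassert>
#include <vector>

// Index of the last element <= k in a sorted vector, or -1 if there is none.
int lastAtMost(const std::vector<int>& v, int k) {
    int l = -1;                          // everything at index <= l is <= k
    int r = static_cast<int>(v.size());  // everything at index >= r is > k
    while (l + 1 < r) {
        int mid = l + (r - l) / 2;       // strictly between l and r
        if (v[mid] > k)
            r = mid;
        else
            l = mid;
    }
    return l;
}

// Exercises the three edge cases from the text plus an ordinary lookup.
void checkLastAtMost() {
    assert(lastAtMost({}, 5) == -1);         // empty array
    assert(lastAtMost({1, 2, 3}, 5) == 2);   // all elements smaller than k
    assert(lastAtMost({7, 8, 9}, 5) == -1);  // all elements larger than k
    assert(lastAtMost({1, 3, 5, 7}, 4) == 1);
}
```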

As we can see here, instead of searching for an element, we’re really searching for a transition. And by making this change, our code becomes more elegant.
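The standard library takes the same "find the transition" view: `std::upper_bound` returns an iterator to the first element greater than k, which is exactly the right-hand side of the transition. A minimal sketch of the same lookup built on it (the wrapper names `lastAtMostStd` and `checkLastAtMostStd` are just for illustration):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Index of the last element <= k in a sorted vector, or -1 if there is none.
int lastAtMostStd(const std::vector<int>& v, int k) {
    // upper_bound points at the first element > k, or at end() if none is.
    auto firstGreater = std::upper_bound(v.begin(), v.end(), k);
    // The element just before the transition is the answer.
    return static_cast<int>(firstGreater - v.begin()) - 1;
}

// Same edge cases as in the text.
void checkLastAtMostStd() {
    assert(lastAtMostStd({}, 5) == -1);
    assert(lastAtMostStd({1, 2, 3}, 5) == 2);
    assert(lastAtMostStd({7, 8, 9}, 5) == -1);
}
```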

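Let’s try this again and do Guess Number Higher or Lower. The problem: you are given a function int guess(int num) that behaves like int compare(int magic, int num), and you have to find the magic number, which lies between 1 and n. As the code below assumes, guess returns 0 when num is the magic number, -1 when num is too high, and 1 when num is too low.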

int guessNumber(int n) {
    long long l = 0, r = n;
    while (r-l > 1) {
        long long mid = (l+r)/2;
        int res = guess(mid);
        if (!res) return mid;
        else if (res == -1)
            r = mid;
        else
            l = mid;
    }
    return r;
}
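To try this outside of an online judge, you need a stand-in for guess. The sketch below passes the oracle in as a parameter; `makeGuess` is a hypothetical helper written for this example, assuming the return convention the code above relies on (-1 too high, 1 too low, 0 correct).

```cpp
#include <cassert>
#include <functional>

// Hypothetical stand-in for the judge's guess():
// -1 if num is too high, 1 if num is too low, 0 if num is the magic number.
std::function<int(int)> makeGuess(int magic) {
    return [magic](int num) { return num > magic ? -1 : (num < magic ? 1 : 0); };
}

// Same search as above, with the oracle passed in explicitly.
int guessNumber(int n, const std::function<int(int)>& guess) {
    assert(n >= 1);
    long long l = 0, r = n;  // invariant: l < magic <= r
    while (r - l > 1) {
        long long mid = (l + r) / 2;
        int res = guess(static_cast<int>(mid));
        if (res == 0) return static_cast<int>(mid);
        if (res == -1)
            r = mid;  // mid is too high
        else
            l = mid;  // mid is too low
    }
    return static_cast<int>(r);  // only one candidate left in (l, r]
}
```

For example, `guessNumber(10, makeGuess(6))` returns 6.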

Binary search other than array indices

Split Array Largest Sum (hard)

Sometimes, instead of binary searching over the space of indices of an array, we might instead binary search over the space of the answer. See Split Array Largest Sum (hard). Given an array of nonnegative integers, return the least possible threshold such that you can partition the array into m subarrays where the sum of each subarray does not exceed the threshold. The idea here is that directly finding this number is hard, but for a given threshold it is easy to compute the minimum number of subarrays needed to meet the threshold condition. That count never increases as the threshold grows, so there is a single transition from “needs more than m subarrays” to “needs at most m”. The algorithm is then to binary search for the threshold: compute the minimum number of partitions for each candidate, and find the smallest threshold such that the number of partitions is at most m.

bool valid(vector<int>& nums, int m, long long sum) {
    int cnt = 1;
    long long run = 0;
    for (int x : nums) {
        if (run + x > sum) {
            cnt++;
            run = x;
        } else {
            run += x;
        }
    }
    return cnt <= m;
}
 
int splitArray(vector<int>& nums, int m) {
    long long tot = 0;  // sum of all elements: an upper bound on the answer
    int low = nums[0];  // largest element: a lower bound on the answer
    for (int x : nums) {
        tot += x;
        low = max(low, x);
    }
    long long l = low-1, r = tot;
    while (r-l > 1) {
        long long mid = (l+r)/2;
        if (valid(nums, m, mid))
            r = mid;
        else
            l = mid;
    }
    return r;
}

In the valid function, I calculate the number of sections needed, and if it goes over m, that means the current number is too low, hence invalid. We also know that the answer will not be smaller than the largest element in the array, and will not be larger than the sum of the entire array.
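To make the transition concrete, here is a small sketch that exposes the part count directly and then searches for the smallest workable limit. The names `partsNeeded` and `smallestLimit` and the sample array in the note below are made up for illustration; the logic mirrors the code above.

```cpp
#include <cassert>
#include <vector>

// Greedy count of the fewest contiguous parts whose sums all stay <= limit.
// Assumes every element is <= limit.
int partsNeeded(const std::vector<int>& nums, long long limit) {
    int parts = 1;
    long long run = 0;
    for (int x : nums) {
        if (run + x > limit) {  // x does not fit: start a new part with it
            ++parts;
            run = x;
        } else {
            run += x;
        }
    }
    return parts;
}

// Smallest limit for which at most m parts are enough.
long long smallestLimit(const std::vector<int>& nums, int m) {
    assert(!nums.empty() && m >= 1);
    long long maxElem = 0, total = 0;
    for (int x : nums) {
        if (x > maxElem) maxElem = x;
        total += x;
    }
    long long l = maxElem - 1;  // too small: cannot hold the largest element
    long long r = total;        // always enough: one part holds everything
    while (r - l > 1) {
        long long mid = l + (r - l) / 2;
        if (partsNeeded(nums, mid) <= m)
            r = mid;
        else
            l = mid;
    }
    return r;
}
```

With the array {7, 2, 5, 10, 8}, a limit of 17 needs 3 parts and a limit of 18 needs 2, so for m = 2 the transition sits between 17 and 18 and `smallestLimit` returns 18.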

Searching for real

And of course you can binary search over a real number range as well. Imagine in the previous example, the input is instead an array of doubles, and the threshold is also a double. The algorithm is basically the same, except that we need to handle floating point comparisons using an int sign(double x) function.

int sign(double x) {
    double eps = 1e-10;
    return x < -eps ? -1 : x > eps;
}

bool valid(vector<double>& nums, int m, double sum) {
    int cnt = 1;
    double run = 0;
    for (double x : nums) {
        if (sign(run + x - sum) > 0) {
            cnt++;
            run = x;
        } else {
            run += x;
        }
    }
    return (cnt <= m);
}
 
double splitArray(vector<double>& nums, int m) {
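    double tot = 0;         // sum of all elements: an upper bound on the answer
    double low = nums[0];   // largest element: a lower bound on the answer
    for (double x : nums) {
        tot += x;
        low = max(low, x);
    }
    double l = low, r = tot;
    while (sign(r - l) > 0) {   // stop once l and r are within eps of each other
        double mid = (l + r) / 2;
        if (valid(nums, m, mid))
            r = mid;
        else
            l = mid;
    }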
    return r;
}

The sign function treats any value within eps of zero as zero. This is because double arithmetic is not exact, and we need to make sure we are comparing values correctly. Using this function, a < b is rewritten as sign(b-a) > 0, and a >= b is rewritten as sign(a-b) >= 0. The easy way to think about this is to move terms across the inequality signs such that one side becomes zero, then wrap the other side in the sign function. One caveat: eps has to be larger than the spacing between adjacent doubles at the magnitude of the values involved, otherwise r - l may never drop below eps and the loop will not terminate.
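The rewriting rule is easy to wrap in small helpers. The names `lessThan`, `atLeast` and `approxEqual` below are invented for this sketch, and eps = 1e-10 is the same value as in the code above.

```cpp
#include <cassert>

// -1, 0 or 1 depending on whether x is below, within, or above eps of zero.
int sign(double x) {
    const double eps = 1e-10;
    return x < -eps ? -1 : (x > eps ? 1 : 0);
}

// a < b   becomes  sign(b - a) > 0
bool lessThan(double a, double b) { return sign(b - a) > 0; }

// a >= b  becomes  sign(a - b) >= 0
bool atLeast(double a, double b) { return sign(a - b) >= 0; }

// a == b  becomes  sign(a - b) == 0
bool approxEqual(double a, double b) { return sign(a - b) == 0; }

// 0.1 + 0.2 is not exactly 0.3 in binary floating point,
// but the difference is far below eps, so the two compare as equal.
void checkComparisons() {
    assert(approxEqual(0.1 + 0.2, 0.3));
    assert(!lessThan(0.3, 0.1 + 0.2));
    assert(atLeast(0.3, 0.1 + 0.2));
    assert(lessThan(1.0, 2.0));
}
```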

That’s it for now.
