Brief Intro to Segment Tree


Something I just learned: the segment tree, a data structure more advanced and more general than the binary indexed tree. Even though I just learned it and might not be qualified to discuss it yet, I’m pretty excited, so who cares. So here’s an intro to segment trees from a noob’s point of view. Most of the content comes from this Codeforces blog by Al.Cash, but that blog assumes more prior knowledge, and I will attempt to explain it from scratch.

Motivation

Member binary indexed tree? I member. For range sum queries, BIT supports updating an element and querying the sum of any range, both in logarithmic time. Querying the range from l to r is a two-phase process: first find the prefix sums up to l-1 and up to r, then subtract the two. Pretty straightforward.

But what if we want the minimum or maximum of a range instead of the sum? With BIT, we can only query the range from the beginning up to a certain point. That part is trivial, really; see the code below:

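#include <algorithm>
#include <vector>
using namespace std;

// binary indexed tree for taking maximum of range [0, k]
// bit has size n+1 and starts filled with 0, so values are assumed non-negative
void update(int v, int k, vector<int>& bit) {
    for (k++; k < (int)bit.size(); k += k&(-k))
        bit[k] = max(bit[k], v);
}

int query(int k, vector<int>& bit) {
    int ans = 0;
    for (k++; k > 0; k -= k&(-k))
        ans = max(ans, bit[k]);
    return ans;
}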

There are a few limitations with this approach. First, you can only update an element to a bigger value, because maximum has no inverse operation the way addition has subtraction. Second, you can only take the maximum from the beginning up to a certain element, not over an arbitrary range. With summation that’s not a problem, as we can subtract the sums of two prefixes to get the sum of the range between them. But with minimum or maximum, you can’t just subtract, can you?

Therefore, we need a data structure that is like BIT, but supports true range query, instead of prefix range query.

Basic specs

  1. Segment tree is like a binary indexed tree on steroids: all problems solved by BIT can be solved by segment tree.
  2. It takes 2 times the space of the original array (double that of BIT).
  3. You can update any element of the original array in O(log(n)) time (same as BIT).
  4. You can look up any element of the original array in constant time (BIT takes O(log(n))).
  5. The original array can be of any data type (same as BIT).
  6. You can perform a range query of any associative function in O(log(n)) time (BIT can only make prefix range queries).

The last point is important. As long as a function is associative, you can use it with ST. It does not have to be commutative or invertible. A few examples:

  1. Addition. Just like with BIT, given {1, 2, 3, 4, 5}, you can sum {2, 3, 4} in O(log(n)) time.
  2. Maximum. Given {1, 2, 3, 4, 5}, you can take the maximum of {2, 3, 4} in O(log(n)) time.
  3. Matrix multiplication. Given five matrices {A1, A2, A3, A4, A5}, you can get the product A2·A3·A4 with O(log(n)) matrix multiplications.

Associativity: say you have 3 variables a, b and c, and a function f that operates on two parameters. f is associative if f(f(a, b), c) = f(a, f(b, c)) for all a, b and c. Note that commutativity does not follow: f(a, b) does not necessarily equal f(b, a). Associativity is important because for a range {a, b, c, d, e}, you can precompute {a, b} and {c, d, e} separately, then combine the two results, as long as you keep them in order.
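Here is a small sketch of that definition, using string concatenation as f. This is my choice of example (the post itself uses addition, maximum and matrices), and the helper names are hypothetical. Concatenation is associative but not commutative, and `split_fold` mirrors the {a, b} plus {c, d, e} precomputation from the paragraph above.

```cpp
#include <string>

// f(a, b) = a + b on strings: associative, but not commutative.
std::string concat(const std::string& a, const std::string& b) {
    return a + b;
}

// True if f(f(a, b), c) == f(a, f(b, c)) for this particular triple.
bool associative_on(const std::string& a, const std::string& b,
                    const std::string& c) {
    return concat(concat(a, b), c) == concat(a, concat(b, c));
}

// True if f(a, b) == f(b, a) for this particular pair.
bool commutative_on(const std::string& a, const std::string& b) {
    return concat(a, b) == concat(b, a);
}

// Precompute {a, b} and {c, d, e} separately, then combine the two results
// in order; associativity guarantees this equals the left-to-right fold.
std::string split_fold(const std::string& a, const std::string& b,
                       const std::string& c, const std::string& d,
                       const std::string& e) {
    std::string first_part = concat(a, b);
    std::string second_part = concat(concat(c, d), e);
    return concat(first_part, second_part);
}
```

For instance, `commutative_on("a", "b")` is false ("ab" versus "ba"), which is why the order of merging will matter later on.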

Basic idea

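Let’s borrow the sample from BIT: {1, 2, 3, 4, 5, 6, 7, 8}. Here’s a graph of what the segment tree stores:

[Figure: segment tree for {1, 2, 3, 4, 5, 6, 7, 8}]

For comparison, here’s a graph of what the binary indexed tree stores:

[Figure: binary indexed tree for the same array]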

A few observations:

  1. ST is simply BIT with the whole table filled in, without any blanks.
  2. st[8..15] is our original array. In the second half of the ST array, we always store the original array.
  3. st[i] = st[i*2] + st[i*2+1]. Node i/2 (integer division) is node i’s parent.
  4. Updating any element changes the same number of nodes in the tree: one per level, i.e. log2(n)+1 nodes (4 when n = 8).
  5. We can sum up ranges directly. For example {3, 4, 5, 6, 7} = st[5]+st[6]+st[14].
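A quick way to check these observations is to build the tree for the sample array directly. The sketch below is mine (`build_sum_tree` and `update_path` are hypothetical names), hard-codes integer addition, and assumes the 8-element sample. For {1, ..., 8} it should produce st[1..15] = 36, 10, 26, 3, 7, 11, 15, 1, 2, 3, 4, 5, 6, 7, 8, with st[0] unused.

```cpp
#include <vector>

// Builds the sum segment tree for the sample array:
// st[n..2n-1] holds the array and st[i] = st[2i] + st[2i+1].
std::vector<int> build_sum_tree(const std::vector<int>& v) {
    int n = static_cast<int>(v.size());
    std::vector<int> st(2 * n, 0);  // st[0] is unused
    for (int i = 0; i < n; i++)
        st[n + i] = v[i];
    for (int i = n - 1; i > 0; i--)
        st[i] = st[2 * i] + st[2 * i + 1];
    return st;
}

// Indices that change when element k is updated: its leaf, then each parent.
std::vector<int> update_path(int k, int n) {
    std::vector<int> path;
    for (int i = k + n; i > 0; i /= 2)
        path.push_back(i);
    return path;
}
```

For example, st[5] + st[6] + st[14] = 7 + 11 + 7 = 25 = 3 + 4 + 5 + 6 + 7, and updating element 2 touches st[10], st[5], st[2] and st[1]: four nodes, one per level.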

It is easy to see that ST is an extension of BIT and supports direct range queries while BIT doesn’t. The only thing left now is how to actually implement it.

Code

In general, the function does not have to be addition of integers, so I will abstract it to any function that takes 2 structs and returns a struct.

struct data {
    int num;
    data() {
        num = 0;
    }
    data(int n) {
        num = n;
    }
};

data merge(data a, data b) {
    return data(a.num+b.num);
}

For simplicity, I still made it a wrapper around addition of integers. You can make it any associative function on any data, for example the minimum of long longs or the product of matrices.
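As a sketch of swapping in a different function, here is a minimum-of-long-longs variant (the name `MinNode` is mine, not from the post). One detail is easy to miss: `query` starts from default-constructed `ansl` and `ansr`, so the default constructor must produce the identity of `merge`. For addition that identity is 0, which is why `data()` sets `num = 0`; for minimum it has to be something like `LLONG_MAX`. I also avoid the name `data` here because, with `using namespace std;` under C++17, it can clash with `std::data`.

```cpp
#include <algorithm>
#include <climits>

// Minimum of long longs. The default value is the identity of merge,
// so default-constructed accumulators in query() do not affect the result.
struct MinNode {
    long long num;
    MinNode() : num(LLONG_MAX) {}
    explicit MinNode(long long n) : num(n) {}
};

MinNode merge(MinNode a, MinNode b) {
    return MinNode(std::min(a.num, b.num));
}
```

With this pair, the build, update and query functions from this post should work unchanged once their `data` type is replaced by `MinNode`.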

Then, given an array, we need to build the segment tree.

vector<data> build(vector<data>& v) {
    int n = v.size();
    vector<data> st(n*2);
    for (int i = 0; i < v.size(); i++)
        st[i+n] = v[i];
    for (int i = n-1; i > 0; i--)
        st[i] = merge(st[i*2], st[i*2+1]);
    return st;
}

Technically we do not need a build function if we have an update function, just like with BIT: we can update the entries one by one. But that takes O(n log(n)) time, while this function is O(n), so it is not entirely useless.

The build function takes the original array and makes an array of double its size. Then we copy the array into the second half of the tree. Constructing the tree is a bottom-up procedure: each node is computed from its two children (observation 3), and because i goes from n-1 down to 1, both children are always ready before their parent. That’s it.
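To illustrate the build versus repeated-update trade-off, this sketch builds the same tree both ways (int sums only; `build_linear`, `update_sum` and `build_by_updates` are hypothetical names). Starting from an all-zero array only works because 0 is the identity for addition; for another merge you would start from `data()` instead.

```cpp
#include <vector>

// O(n) bottom-up build, as in build() but with plain int sums.
std::vector<int> build_linear(const std::vector<int>& v) {
    int n = static_cast<int>(v.size());
    std::vector<int> st(2 * n, 0);
    for (int i = 0; i < n; i++)
        st[n + i] = v[i];
    for (int i = n - 1; i > 0; i--)
        st[i] = st[2 * i] + st[2 * i + 1];
    return st;
}

// Point update: rewrite the leaf, then recompute each ancestor.
void update_sum(int v, int k, std::vector<int>& st) {
    int n = static_cast<int>(st.size()) / 2;
    st[k + n] = v;
    for (int i = (k + n) / 2; i > 0; i /= 2)
        st[i] = st[2 * i] + st[2 * i + 1];
}

// O(n log n) alternative: start from an all-zero tree and update one by one.
std::vector<int> build_by_updates(const std::vector<int>& v) {
    std::vector<int> st(2 * v.size(), 0);
    for (int k = 0; k < static_cast<int>(v.size()); k++)
        update_sum(v[k], k, st);
    return st;
}
```

Both should return identical arrays; the first does n-1 merges, while the second does one merge per level for every element.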

The update function is even simpler.

void update(data v, int k, vector<data>& st) {
    int n = st.size()/2;
    st[k+n] = v;
    for (int i = (k+n)/2; i > 0; i /= 2)
        st[i] = merge(st[i*2], st[i*2+1]);
}

It is a 4-line function. We first update the entry in the second half of the tree, then walk up to its parent by repeatedly dividing by 2, recomputing each node, until we hit 1, the root.

Now the actual hard part: queries. Given the range [l, r) from l to r-1, query the “sum” (could be product or any arbitrary function) of the range.

data query(int l, int r, vector<data>& st) {
    int n = st.size()/2;
    data ansl, ansr;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l%2 == 1) {
            ansl = merge(ansl, st[l]);
            l++;
        }
        if (r%2 == 1) {
            r--;
            ansr = merge(st[r], ansr);
        }
    }
    return merge(ansl, ansr);
}

I will try my best to explain, but I will not go through everything in fine detail because it is too tedious. Let’s say we want to sum {2, 3, 4, 5, 6}, i.e. l = 1, r = 6 (because v[1] = 2, v[6-1] = 6). First we add n to both l and r, so we have l = 9, r = 14. Referring to the graph above, our goal is to sum st[9], st[5] and st[6]. In the first iteration we pick up st[9], because l is odd. After adding it to the left sum, we add 1 to l to record that this subtree has been handled, so the range left to add is smaller; r = 14 is even, so nothing happens on the right. Then l and r are divided by 2 to go up a level in the tree, becoming 5 and 7. Now both are odd, so we merge st[5] into the left sum and st[6] into the right sum. In the next iteration, l and r are both 3, meaning that the range to sum is empty, and the loop exits. The result is st[9] + st[5] + st[6] = 2 + 7 + 11 = 20.
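To see which nodes the loop actually picks, here is a sketch that runs the same loop on indices only. `query_nodes` is a hypothetical helper, not part of the post's code; it returns the merged node indices in left-to-right order.

```cpp
#include <vector>

// Runs the same loop as query(), but records which tree indices get merged
// instead of merging values. n is the number of leaves.
std::vector<int> query_nodes(int l, int r, int n) {
    std::vector<int> lhs_nodes, rhs_nodes;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l % 2 == 1) lhs_nodes.push_back(l++);
        if (r % 2 == 1) rhs_nodes.push_back(--r);
    }
    // ansr is built as merge(st[r], ansr), so right-side nodes are found from
    // the outside in; reversing them restores left-to-right order.
    lhs_nodes.insert(lhs_nodes.end(), rhs_nodes.rbegin(), rhs_nodes.rend());
    return lhs_nodes;
}
```

With n = 8, `query_nodes(1, 6, 8)` gives {9, 5, 6}, matching the walkthrough, and `query_nodes(2, 7, 8)` gives {5, 6, 14}, matching observation 5.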

To be honest, I don’t 100% understand why the conditions check whether l and r are odd. The idea could be this: if l is even, it is a left child, and its sibling l+1 is usually also in the range, so we would rather add the parent l/2, which covers both; therefore we do nothing when l is even. The only case where l+1 is not in the range is r = l+1, but then r is odd and we add st[r-1], which is st[l]. If l is odd, it is a right child whose parent also covers l-1, which is outside the range, so we have to add st[l] on its own and move l to l+1. Everything should be mirrored for r, so r should be checked to be even; but r is not included in the range [l, r), the actual end point is r-1, so we check whether r is odd. Just like with low bits in BIT, you don’t actually need to understand it to use it; I bet you can’t implement a red-black tree either, but you still use set<> like an algorithm master anyway.

Important point: there are two accumulators, ansl and ansr, which collect the left and right parts respectively. This preserves the order of computation in case the merge function is not commutative: left-side nodes are appended to ansl, right-side nodes are prepended to ansr, and the two are merged at the end. For addition, however, it makes no difference.
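Here is a sketch that makes the difference visible, using string concatenation as a non-commutative merge (my choice of example; the function names are hypothetical). `query_two_sided` follows the post's loop, while `query_one_sided` appends everything to a single accumulator.

```cpp
#include <string>
#include <vector>

// Tree over strings with concatenation as merge; "" is the identity,
// playing the role of data().
std::vector<std::string> build_str(const std::vector<std::string>& v) {
    int n = static_cast<int>(v.size());
    std::vector<std::string> st(2 * n);
    for (int i = 0; i < n; i++) st[n + i] = v[i];
    for (int i = n - 1; i > 0; i--) st[i] = st[2 * i] + st[2 * i + 1];
    return st;
}

// Same loop as query(), with separate left and right accumulators.
std::string query_two_sided(int l, int r, const std::vector<std::string>& st) {
    int n = static_cast<int>(st.size()) / 2;
    std::string ansl, ansr;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l % 2 == 1) ansl += st[l++];
        if (r % 2 == 1) ansr = st[--r] + ansr;  // prepend on the right
    }
    return ansl + ansr;
}

// Broken variant for comparison: one accumulator, always appended to.
std::string query_one_sided(int l, int r, const std::vector<std::string>& st) {
    int n = static_cast<int>(st.size()) / 2;
    std::string ans;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l % 2 == 1) ans += st[l++];
        if (r % 2 == 1) ans += st[--r];
    }
    return ans;
}
```

On the leaves "a" to "h", the two-sided version returns "abcdefg" for [0, 7), whereas the one-sided version returns "gefabcd": the right-side nodes st[14], st[6] and st[2] are visited from the outside in, so they must be prepended.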

Array of arbitrary size

The above is explained with an array size of 8, which was sort of cheating, because you will most likely want an array of arbitrary size. Of course, you can pad the end with zeros (more generally, with the identity element of merge, which is what data() should construct) to make the size a power of two.

int n = v.size();
while (n != (n & -n)) // n & -n is the low bit of n
    n += n & -n;
v.resize(n); // new entries are data(), i.e. the identity

That would do. This works because when n is a power of two, its low bit is itself. However, this is not even needed: the original code, although designed for powers of two, works for any vector size n. There is a short explanation on the original blog, but the full proof is probably too complicated to be worth knowing. We only need to know that it automagically works for any n, and happily copy-paste the code.
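If you would rather check than trust, here is a brute-force test, not a proof (the function name is hypothetical). It builds the tree for a given n with single letters as leaves and concatenation as merge, then compares every range [l, r) against the plain left-to-right answer. Because concatenation is order-sensitive and the letters are distinct for n ≤ 26, any out-of-order merge, or any node holding leaves from outside the range, would show up as a mismatch.

```cpp
#include <string>
#include <vector>

// Returns true if the post's query loop gives the correct, in-order result
// for every range of an n-leaf tree (n need not be a power of two).
bool works_for_size(int n) {
    std::vector<std::string> st(2 * n);
    for (int i = 0; i < n; i++)
        st[n + i] = std::string(1, static_cast<char>('a' + i % 26));
    for (int i = n - 1; i > 0; i--)
        st[i] = st[2 * i] + st[2 * i + 1];
    for (int l0 = 0; l0 <= n; l0++) {
        for (int r0 = l0; r0 <= n; r0++) {
            // Expected answer: concatenate the leaves one by one.
            std::string expected;
            for (int i = l0; i < r0; i++) expected += st[n + i];
            // Same loop as query(), with left and right accumulators.
            std::string ansl, ansr;
            for (int l = l0 + n, r = r0 + n; l < r; l /= 2, r /= 2) {
                if (l % 2 == 1) ansl += st[l++];
                if (r % 2 == 1) ansr = st[--r] + ansr;
            }
            if (ansl + ansr != expected) return false;
        }
    }
    return true;
}
```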

Example: matrix multiplication

Build, update and query are copy pasted, so they are omitted in the example.

struct data {
    int m, n;
    vector<vector<int> > A;
    data() {
        m = 0;
        n = 0;
    }
    data(const vector<vector<int> >& B) {
        A = B;
        m = A.size();
        n = A[0].size();
    }
    void print() {
        for (auto r : A) {
            for (auto c : r)
                cout << c << " ";
            cout << endl;
        }
    }
};
data merge(data a, data b) {
    if (a.m*a.n == 0)
        return b;
    if (b.m*b.n == 0)
        return a;
    int m = a.m, n = b.n, l = a.n;
    vector<vector<int> > C(m, vector<int>(n));
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < l; k++)
                C[i][j] += a.A[i][k]*b.A[k][j];
    return data(C);
}
vector<data> build(vector<data>& v) {...}
void update(data v, int k, vector<data>& st) {...}
data query(int l, int r, vector<data>& st) {...}
int main() {
    vector<data> v;
    v.push_back(data({ {2, 0}, {0, 2} }));
    v.push_back(data({ {1, 1, 4}, {4, 2, 2} }));
    v.push_back(data({ {-2, 0}, {1, 4}, {1, 2} }));
    v.push_back(data({ {0, 1}, {1, 0} }));
    vector<data> st = build(v);
    int l, r;
    while (cin >> l >> r)
        query(l, r, st).print();
    return 0;
}

You can mostly just copy paste the three functions and modify the data and merge definitions to fit your applications. I was motivated to study segment trees because of this problem on Codeforces. I was not able to do this problem during the contest, but neither could tourist, so whatever.

New Year and Old Subsequence

This is a slightly more advanced application of segment tree. The problem: given a string of digits, return the minimum number of digits that need to be removed so that it has “2017” as a subsequence but not “2016”. If 2017 is not a subsequence, print -1. More precisely, a string of up to 200,000 characters is given, followed by up to 200,000 queries of ranges l to r, and for each query we need to answer the minimum number of digits to remove so that the substring [l, r] has a subsequence “2017” but not “2016”.

The algorithm is described here. The gist is that for an interval of the string, all we need to know is: given that we already have a certain prefix of “2017” before the interval, how many digits do we need to erase in this interval to end up with a certain longer (or equal) prefix of “2017” after it, without creating “2016”. For example, if the current digit is 6, then given “201” or “2017”, we have to erase one digit (the 6) to keep a prefix of “201” or “2017” without any “2016”. Otherwise the 6 doesn’t matter. These interval summaries combine associatively, which is exactly what the segment tree needs. Here’s my code.

struct data {
    unsigned int dp[5][5];
    data() {
        for (int j = 0; j < 5; j++)
            for (int i = 0; i <= j; i++)
                dp[j][i] = INT_MAX;
    }
    void clear() {
        for (int i = 0; i < 5; i++)
        dp[i][i] = 0;
    }
};
 
data merge(const data& a, const data& b) {
    data temp;
    for (int j = 0; j < 5; j++)
        for (int i = 0; i <= j; i++)
            for (int k = i; k <= j; k++)
                temp.dp[j][i] = min(temp.dp[j][i], a.dp[k][i]+b.dp[j][k]);
    return temp;
}
 
int main() {
    int n, q;
    string s;
    cin >> n >> q >> s;
    vector<data> st(2*n);
    for (int i = 0; i < n; i++) {
        st[i+n].clear();
        if (s[i] == '2') {
            st[i+n].dp[0][0] = 1;
            st[i+n].dp[1][0] = 0;
        } else if (s[i] == '0') {
            st[i+n].dp[1][1] = 1;
            st[i+n].dp[2][1] = 0;
        } else if (s[i] == '1') {
            st[i+n].dp[2][2] = 1;
            st[i+n].dp[3][2] = 0;
        } else if (s[i] == '7') {
            st[i+n].dp[3][3] = 1;
            st[i+n].dp[4][3] = 0;
        } else if (s[i] == '6') {
            st[i+n].dp[3][3] = st[i+n].dp[4][4] = 1;
        }
    }
    // build segment tree
    for (int i = n-1; i; i--)
        st[i] = merge(st[i<<1], st[i<<1|1]);
    while (q--) {
        int l, r;
        cin >> l >> r;
        // query from l-1 to r
        data ansl, ansr;
        ansl.clear();
        ansr.clear();
        for (l += n-1, r += n; l < r; l >>= 1, r >>= 1) {
            if (l&1) {
                ansl = merge(ansl, st[l]);
                l++;
            }
            if (r&1) {
                r--;
                ansr = merge(st[r], ansr);
            }
        }
        int ans = merge(ansl, ansr).dp[4][0];
        cout << (ans == INT_MAX ? -1 : ans) << endl;
    }
    return 0;
}

If you paid attention to the code, you will see I embedded the build and query functions in the main function. You will also notice I used a C-style integer array instead of C++ vectors. There are also changes in details such as replacing *2 by <<1 (left shift) and +1 by |1 (bitwise or). I hate to do this, but the judge on Codeforces is very demanding, and my normal coding style got TLE.
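To sanity-check the dp on a small input, here is a sketch that uses the same state idea but folds the digits left to right instead of querying a tree. The names `Trans`, `leaf` and `answer` are mine, and `INF` is 10^9 in plain ints rather than `INT_MAX` in unsigned ints. My reading of the dp, hedged: dp[j][i] is the minimum number of erasures in a piece of the string that takes a matched prefix of "2017" of length i to one of length j without allowing "2016". A segment tree would compute the same fold with O(log(n)) merges per query.

```cpp
#include <algorithm>
#include <string>

const int INF = 1000000000;  // INF + INF still fits in an int

struct Trans {
    int dp[5][5];  // dp[j][i]: cost to go from state i to state j (i <= j)
    // Identity for merge: staying in the same state costs 0.
    Trans() {
        for (int j = 0; j < 5; j++)
            for (int i = 0; i < 5; i++)
                dp[j][i] = (i == j) ? 0 : INF;
    }
};

// Min-plus composition: first piece a, then piece b.
Trans merge(const Trans& a, const Trans& b) {
    Trans t;
    for (int j = 0; j < 5; j++)
        for (int i = 0; i <= j; i++) {
            t.dp[j][i] = INF;
            for (int k = i; k <= j; k++)
                t.dp[j][i] = std::min(t.dp[j][i], a.dp[k][i] + b.dp[j][k]);
        }
    return t;
}

// Transition for a single digit, mirroring the leaf setup in main() above.
Trans leaf(char c) {
    Trans t;
    const std::string target = "2017";
    std::string::size_type p = target.find(c);
    if (p != std::string::npos) {
        // Keeping c advances from state p to p+1; staying costs one erase.
        t.dp[p][p] = 1;
        t.dp[p + 1][p] = 0;
    } else if (c == '6') {
        // A 6 after "201" (or "2017") would complete "2016", so it must go.
        t.dp[3][3] = t.dp[4][4] = 1;
    }
    return t;
}

// Answer for s[l-1..r-1] (1-indexed, inclusive), folding left to right.
int answer(const std::string& s, int l, int r) {
    Trans acc;
    for (int i = l - 1; i < r; i++)
        acc = merge(acc, leaf(s[i]));
    return acc.dp[4][0] >= INF ? -1 : acc.dp[4][0];
}
```

For "20166766", this should give 4 for [1, 8] (erase all four 6s), 3 for [1, 7], and -1 for [2, 8], since that range contains no 2.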