Segment Tree
Understanding Segment Trees and their importance in range queries.
What is a Segment Tree?
A segment tree is a data structure for answering range queries efficiently. It’s like having a supercharged array: you can compute the sum, minimum, or maximum over any range of elements in O(log n) time, and you can update individual elements in O(log n) time as well. It’s super handy when you need to process many queries on an array that also changes between them.
If you’ve ever wanted to calculate the sum of an array range or find the minimum element in a range faster than looping over it (which costs O(n) per query), a segment tree is your friend.
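For comparison, here is a minimal sketch of the loop-based approach a segment tree replaces. The helpers `naive_range_sum` and `naive_update` are hypothetical names used only for illustration. Each query walks the whole range, so q queries on an array of length n cost O(q * n) in the worst case; a segment tree brings that down to O(q log n) while keeping updates cheap.

```python
# Naive baseline: a range-sum query loops over every element in [l, r],
# so it costs O(r - l + 1), which is O(n) in the worst case.
# A point update is O(1): just overwrite one slot.

def naive_range_sum(array, l, r):
    """Return array[l] + ... + array[r] (inclusive) by looping."""
    total = 0
    for i in range(l, r + 1):
        total += array[i]
    return total


def naive_update(array, idx, value):
    """Overwrite a single element in place."""
    array[idx] = value


array = [1, 3, 5, 7, 9, 11]
print(naive_range_sum(array, 1, 3))  # 3 + 5 + 7 = 15
naive_update(array, 1, 10)
print(naive_range_sum(array, 1, 3))  # 10 + 5 + 7 = 22
```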
How Does a Segment Tree Work?
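Imagine you take an array and build a binary tree on top of it. Each leaf node represents one element of the array, and each internal node stores a combination (like the sum, min, or max) of its two children, which equals the combination over the contiguous range of elements beneath it. The root covers the whole array, and each node’s range is split in half between its children, so the tree is only about log2(n) levels deep. That is why a query or update only needs to visit O(log n) nodes.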
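To make the “binary tree on top of an array” picture concrete, this sketch lists the range each node covers for a 6-element array. It assumes the same 0-based node numbering as the implementation at the end of this article (the root is node 0, and node i has children 2i + 1 and 2i + 2). The helper `segment_ranges` is hypothetical, written only for illustration.

```python
# List the array segment [start, end] covered by every node of a segment tree.
# Numbering assumption: root is node 0, children of node i are 2*i + 1 and 2*i + 2.

def segment_ranges(start, end, node=0, depth=0, out=None):
    """Collect (node, depth, start, end) for every node of the tree over [start, end]."""
    if out is None:
        out = []
    out.append((node, depth, start, end))
    if start != end:
        # Split the range in half, exactly as build() does
        mid = (start + end) // 2
        segment_ranges(start, mid, 2 * node + 1, depth + 1, out)
        segment_ranges(mid + 1, end, 2 * node + 2, depth + 1, out)
    return out


for node, depth, start, end in segment_ranges(0, 5):
    kind = "leaf" if start == end else "internal"
    print("  " * depth + f"node {node}: [{start}, {end}] ({kind})")
```

For 6 elements this yields 11 nodes (2n - 1 in general) spread over 4 levels, but the highest node index used is 12 rather than 10, because this numbering leaves gaps. Those gaps are why the implementation allocates 4 * n slots instead of 2 * n.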
For example:
- Build the tree with an array of numbers.
- Query for a specific range, like the sum of elements from index 2 to 5.
- Update the value of a single element and automatically reflect that in the tree.
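The three steps above can be tried on a concrete array. This sketch deliberately uses a common non-recursive (“bottom-up”) variant rather than the recursive version shown at the end of the article: the leaves live in the second half of a 2n-slot list, and each parent i stores the sum of slots 2i and 2i + 1. The class name `IterativeSumTree` is hypothetical.

```python
class IterativeSumTree:
    """Bottom-up segment tree for range sums (a common non-recursive variant).

    Leaves live at tree[n:2n]; each internal node i (1 <= i < n)
    stores tree[2*i] + tree[2*i + 1]. Slot 0 is unused.
    """

    def __init__(self, array):
        self.n = len(array)
        self.tree = [0] * (2 * self.n)
        # Step 1: build. Copy the leaves, then fill parents from the bottom up: O(n).
        self.tree[self.n:] = array
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def query(self, l, r):
        """Step 2: sum of array[l..r] inclusive, O(log n)."""
        total = 0
        l += self.n
        r += self.n + 1  # switch to the half-open leaf interval [l, r)
        while l < r:
            if l & 1:  # l is a right child: take it, then step past it
                total += self.tree[l]
                l += 1
            if r & 1:  # r is a right child: its left sibling r - 1 is in range
                r -= 1
                total += self.tree[r]
            l //= 2
            r //= 2
        return total

    def update(self, idx, value):
        """Step 3: set array[idx] = value and recompute its ancestors, O(log n)."""
        i = idx + self.n
        self.tree[i] = value
        while i > 1:
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]


tree = IterativeSumTree([1, 3, 5, 7, 9, 11])
print(tree.query(2, 5))  # 5 + 7 + 9 + 11 = 32
tree.update(3, 0)
print(tree.query(2, 5))  # 5 + 0 + 9 + 11 = 25
```

The bottom-up layout avoids recursion and uses exactly 2n slots; the recursive layout is often easier to extend to other operations.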
Common Operations on a Segment Tree
- Build the tree: Construct the tree from the array in O(n) time.
- Range Query: Perform operations like sum or minimum over a specific range in O(log n) time.
- Update: Change the value of an element and update the tree in O(log n) time.
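To make the O(n) and O(log n) figures concrete, here is an instrumented sketch that counts how many nodes the build touches and how many nodes a range-sum query visits. The functions and the `visits` counter are hypothetical, added only for measurement; the recursion follows the same scheme as the article’s implementation. Build touches every node once (2n - 1 of them), while a query visits at most about four nodes per level, and there are about log2(n) levels.

```python
import math


def build(tree, array, node, start, end):
    """Recursive build that returns how many nodes it created (every node once: O(n))."""
    if start == end:
        tree[node] = array[start]
        return 1
    mid = (start + end) // 2
    count = build(tree, array, 2 * node + 1, start, mid)
    count += build(tree, array, 2 * node + 2, mid + 1, end)
    tree[node] = tree[2 * node + 1] + tree[2 * node + 2]
    return count + 1


def query(tree, node, start, end, l, r, visits):
    """Range-sum query; visits[0] is incremented once per node touched."""
    visits[0] += 1
    if r < start or l > end:
        return 0
    if l <= start and end <= r:
        return tree[node]
    mid = (start + end) // 2
    return (query(tree, 2 * node + 1, start, mid, l, r, visits)
            + query(tree, 2 * node + 2, mid + 1, end, l, r, visits))


n = 1000
array = list(range(n))
tree = [0] * (4 * n)
built = build(tree, array, 0, 0, n - 1)
print("nodes built:", built)  # 2n - 1 = 1999

worst = 0
for l in range(0, n, 37):
    for r in range(l, n, 41):
        visits = [0]
        # Sanity check against a plain sum while measuring the work done
        assert query(tree, 0, 0, n - 1, l, r, visits) == sum(array[l:r + 1])
        worst = max(worst, visits[0])
print("most nodes visited by one query:", worst, "| log2(n) =", round(math.log2(n), 1))
```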
Why Use Segment Trees?
- They handle range queries efficiently, even for large datasets.
- Perfect for problems involving frequent updates and queries.
- They work with any associative operation, not just sums: min, max, gcd, and so on.
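The point about other operations can be sketched by passing the combining function in as a parameter. This is a minimal illustration, assuming a hypothetical `GenericSegmentTree` class, where each operation needs an identity element that out-of-range segments return: 0 for sum, positive infinity for min, negative infinity for max, and 0 for gcd (since gcd(0, x) = x for non-negative x).

```python
import math
import operator
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class GenericSegmentTree(Generic[T]):
    """Recursive segment tree over any associative `combine` with an `identity` element.

    Assumption: combine(identity, x) == combine(x, identity) == x for every x,
    so branches that fall outside the query range can safely return `identity`.
    """

    def __init__(self, array: List[T], combine: Callable[[T, T], T], identity: T):
        self.n = len(array)
        self.combine = combine
        self.identity = identity
        self.tree = [identity] * (4 * max(self.n, 1))
        if self.n:
            self._build(array, 0, 0, self.n - 1)

    def _build(self, array, node, start, end):
        if start == end:
            self.tree[node] = array[start]
            return
        mid = (start + end) // 2
        self._build(array, 2 * node + 1, start, mid)
        self._build(array, 2 * node + 2, mid + 1, end)
        # Children are always combined left-then-right
        self.tree[node] = self.combine(self.tree[2 * node + 1], self.tree[2 * node + 2])

    def query(self, l: int, r: int) -> T:
        """Combine array[l..r] inclusive."""
        return self._query(0, 0, self.n - 1, l, r)

    def _query(self, node, start, end, l, r):
        if r < start or l > end:
            return self.identity  # completely outside the query range
        if l <= start and end <= r:
            return self.tree[node]  # completely inside the query range
        mid = (start + end) // 2
        return self.combine(self._query(2 * node + 1, start, mid, l, r),
                            self._query(2 * node + 2, mid + 1, end, l, r))

    def update(self, idx: int, value: T) -> None:
        """Set array[idx] = value and recombine its ancestors."""
        self._update(0, 0, self.n - 1, idx, value)

    def _update(self, node, start, end, idx, value):
        if start == end:
            self.tree[node] = value
            return
        mid = (start + end) // 2
        if idx <= mid:
            self._update(2 * node + 1, start, mid, idx, value)
        else:
            self._update(2 * node + 2, mid + 1, end, idx, value)
        self.tree[node] = self.combine(self.tree[2 * node + 1], self.tree[2 * node + 2])


data = [12, 18, 6, 30, 9, 24]
sums = GenericSegmentTree(data, operator.add, 0)
mins = GenericSegmentTree(data, min, math.inf)
maxs = GenericSegmentTree(data, max, -math.inf)
gcds = GenericSegmentTree(data, math.gcd, 0)

print(sums.query(1, 4))  # 18 + 6 + 30 + 9 = 63
print(maxs.query(1, 4))  # 30
print(gcds.query(0, 2))  # gcd(12, 18, 6) = 6
mins.update(2, 50)
print(mins.query(1, 4))  # min(18, 50, 30, 9) = 9
```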
If you need blazing-fast range operations and updates, a segment tree is your go-to data structure!
Code Examples
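Let me show you how a segment tree works with an example in Python. I’ll use the range sum query as the operation, but you can swap in another operation like minimum or maximum by changing how two children are combined and what an out-of-range segment returns (0 for sums, infinity for minimum).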
```python
# Python implementation of a Segment Tree for range sum queries
class SegmentTree:
    def __init__(self, array):
        n = len(array)
        self.n = n
        self.tree = [0] * (4 * n)  # 4 * n slots are enough for this node numbering
        if n > 0:  # Guard: an empty array has no leaves, so build() would never terminate
            self.build(array, 0, 0, n - 1)

    # Build the segment tree
    def build(self, array, node, start, end):
        if start == end:
            # Leaf node, store the array value
            self.tree[node] = array[start]
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
            # Recursively build left and right subtrees
            self.build(array, left_child, start, mid)
            self.build(array, right_child, mid + 1, end)
            # Internal node stores the sum of its children
            self.tree[node] = self.tree[left_child] + self.tree[right_child]

    # Range query: sum of array[l..r] (inclusive); `node` covers array[start..end]
    def query(self, node, start, end, l, r):
        if r < start or l > end:
            # Range is completely outside
            return 0
        if l <= start and end <= r:
            # Range is completely inside
            return self.tree[node]
        # Partial overlap, query left and right children
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        left_sum = self.query(left_child, start, mid, l, r)
        right_sum = self.query(right_child, mid + 1, end, l, r)
        return left_sum + right_sum

    # Point update: set array[idx] = value
    def update(self, node, start, end, idx, value):
        if start == end:
            # Update the leaf node
            self.tree[node] = value
        else:
            mid = (start + end) // 2
            left_child = 2 * node + 1
            right_child = 2 * node + 2
            if start <= idx <= mid:
                self.update(left_child, start, mid, idx, value)
            else:
                self.update(right_child, mid + 1, end, idx, value)
            # Update the internal node
            self.tree[node] = self.tree[left_child] + self.tree[right_child]


# Example usage (calls always start at the root: node 0 covering [0, len(array) - 1])
array = [1, 3, 5, 7, 9, 11]
st = SegmentTree(array)
print("Sum of range [1, 3]:", st.query(0, 0, len(array) - 1, 1, 3))  # Output: 15
st.update(0, 0, len(array) - 1, 1, 10)  # Update index 1 to 10
print("Updated sum of range [1, 3]:", st.query(0, 0, len(array) - 1, 1, 3))  # Output: 22
```