Segment Tree, 세그먼트 트리, 백준 2042


Segment Tree


어떤 순서있는 정보가 있을 때, 그 정보의 특정 구간의 정보를 빠르게 탐색할 수 있는 자료구조입니다.

예를 들면 [4, 1, 3, 2]라는 배열이 있습니다.

이때 2 ~ 4 구간에 있는 수의 합을 구하고자 한다고 가정해보겠습니다.

현재는 배열의 길이가 짧으므로 빠르게 1 + 3 + 2인 6이 정답이라고 말할 수 있습니다.

하지만 배열이 수억, 수입억 길이가 되면 O(n)이 걸려 시간이 많이 걸리게 됩니다.

또한 이러한 정보를 물어보는 쿼리가 많다고 하면, 이렇게 정보를 알아내는 방법은 더욱 비효율적이 됩니다.

이때 segment tree라는 자료구조를 사용할 수 있습니다.

다음은 segment tree의 설명입니다.

segment_tree1

처음에 정보를 가지고 segment tree를 만들때는 제일 하단 노드부터 채워갑니다.

즉, tree[4] ~ tree[7] 을 먼저 채운 후, 그 부모노드들을 채워나갑니다.

이때 주의해야할 점은 tree 배열의 크기입니다.

현재 [4, 1, 3, 2]라는 4개의 정보를 저장하기위해 segment트리는 7개의 노드 크기가 필요합니다.

즉, \(2^{n}\)개의 정보를 위한 segment tree의 최소 크기는

\[2^0 + 2^1 + 2^2 + \cdots + 2^n = \sum_{i=0}^n 2^i\]

가 됩니다. 하지만 데이터의 크기가 \(2^n\) 일꺼라는 보장이 없으므로,

데이터의 크기가 k개라 하면, k보다 큰 2의 제곱수에 두배 이상의 값을 segment tree의 크기로 잡는 걸 추천합니다.


예를 들어 데이터의 수가 9개라 하면, 9보다 큰 2의 제곱수는 16입니다.

즉, segment tree의 크기를 최소 32로는 잡아야 세그먼트 트리 관련 문제를 풀 때

배열 때문에 문제가 생기는 일을 피할 수 있을 것입니다.


segment_tree2

정보를 바꿔야 하는 일이 생기면, 가장 밑단의 노드를 바꾼 후,

부모노드를 따라가서 수정해 나갑니다.


segment_tree3

구간의 합을 구할 때는, 루트 노드부터 시작해서 밑으로 차근차근 탐색해 나갑니다.


위의 예제에서는 구간의 합을 예로 들었지만,

구간의 곱, 구간의 최대값, 구간의 최솟값 등등 다양하게 변형할 수 있습니다.




백준 2042

https://www.acmicpc.net/problem/2042

위의 예제에서 설명했던 구간 합 문제입니다.

#include <stdio.h>

const int NMAX = 1e6;
int n, m, k;
long long arr[NMAX + 7], tree[NMAX * 10];

void buildTree(int root, int s, int e) {
	if (s == e) {
		tree[root] = arr[s];
		return;
	}

	int m = (s + e) / 2, left = root * 2, right = left + 1;
	buildTree(left, s, m);
	buildTree(right, m + 1, e);
	tree[root] = tree[left] + tree[right];
}

void update(int root, int s, int e, int idx) {
	if (s == e) {
		tree[root] = arr[idx];
		return;
	}

	int m = (s + e) / 2, left = root * 2, right = left + 1;
	if (idx <= m) update(left, s, m, idx);
	else update(right, m + 1, e, idx);
	tree[root] = tree[left] + tree[right];
}

long long query(int root, int s, int e, int qs, int qe) {
	if (qs <= s && e <= qe) return tree[root];
	if (qe < s || e < qs) return 0;

	int m = (s + e) / 2, left = root * 2, right = left + 1;
	long long leftValue = query(left, s, m, qs, qe);
	long long rightValue = query(right, m + 1, e, qs, qe);
	return leftValue + rightValue;
}

int main() {
	scanf("%d %d %d", &n, &m, &k);
	int i;
	for (i = 1; i <= n; i++) scanf("%lld", arr + i);
	
	buildTree(1, 1, n);

	int a, b, c1, mk = m + k;
	long long c2;
	while (mk--) {
		scanf("%d %d", &a, &b);
		if (a == 1) {
			scanf("%lld", &c2);
			arr[b] = c2;
			update(1, 1, n, b);
		}
		else {
			scanf("%d", &c1);
			printf("%lld\n", query(1, 1, n, b, c1));
		}
	}

	return 0;
}