본문 바로가기
Baekjoon Online Judge/세그먼트 트리

11658번 구간 합 구하기 3

by jh280722 2020. 4. 2.

2차원 세그먼트 트리나 펜윅 트리를 이용한 문제이다. (x1, y1)부터 (x2, y2)까지의 합을 많은 수의 쿼리가 주어졌을 때 결과를 출력해야 한다. 펜윅 트리는 아직 안 배워서 2차원 세그먼트 트리를 이용해서 풀게 되었다. 처음 2차원 세그먼트 트리를 만들 때 어떻게 해야 할지 고민하다가 단순히 4 사분면처럼 쿼드 트리를 만들어서 구현을 해봤다.

 

시간 초과 소스코드

더보기
#include <bits/stdc++.h>
#define all(v) v.begin(), v.end()
#define pb push_back
#define fu(i,a,j) for(int i=a;i<j;i++)
#define fd(i,a,j) for(int i=a;i>=j;i--)
#define SYNC ios::sync_with_stdio(false),cin.tie(NULL),cout.tie(NULL)
#define MOD 998244353
#define MOD2 1000000021
#define INF 1e9
#define N 1025
using namespace std;

typedef long long ll;
typedef long double ld;
typedef double db;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef vector<int> vi;
typedef vector<ll> vll;
typedef vector<pii> vpii;
typedef vector<pll> vpll;

ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll lcm(ll a, ll b) { return a * b / gcd(a, b); };
int n, m, k, t;
int dr[] = { 0,0,1,-1, -1,1,1,-1 };
int dc[] = { -1,1,0,0, 1,1,-1,-1 };

int seg[4 * N][4 * N];
int Map[N][N];
void init(int idx, int idx2, int l, int r, int u, int d) {
	//idx<<1 왼쪽 idx<<1|1 오른쪽
	//idx2<<1 위	  idx<<1|1 아래
	//idx<<1 idx2<<1 좌상, idx<<1 idx2<<1|1 좌하
	//idx<<1 idx2 좌 전체
	//idx<<1|1 idx2<<1 우상, idx<<1|1 idx2<<1|1 우하

	if (l != r) {
		int mid1 = (l + r) >> 1;
		int mid2 = (u + d) >> 1;
		init(idx << 1, idx2 << 1, l, mid1, u, mid2);
		init(idx << 1 | 1, idx2 << 1, mid1 + 1, r, u, mid2);
		init(idx << 1, idx2 << 1 | 1, l, mid1, mid2 + 1, d);
		init(idx << 1 | 1, idx2 << 1 | 1, mid1 + 1, r, mid2 + 1, d);
		seg[idx][idx2] =
			seg[idx << 1][idx2 << 1] + seg[idx << 1][idx2 << 1 | 1] +
			seg[idx << 1 | 1][idx2 << 1] + seg[idx << 1 | 1][idx2 << 1 | 1];
	}
	else
		seg[idx][idx2] = Map[l][u];
}
void update(int idx, int idx2, int l, int r, int u, int d, int x, int y, int val) {
	if (x<l || x>r) return;
	if (y<u || y>d) return;
	if (l == r && u == d) {
		seg[idx][idx2] = val;
		return;
	}
	int mid1 = (l + r) >> 1;
	int mid2 = (u + d) >> 1;
	update(idx << 1, idx2 << 1, l, mid1, u, mid2, x, y, val);
	update(idx << 1 | 1, idx2 << 1, mid1 + 1, r, u, mid2, x, y, val);
	update(idx << 1, idx2 << 1 | 1, l, mid1, mid2 + 1, d, x, y, val);
	update(idx << 1 | 1, idx2 << 1 | 1, mid1 + 1, r, mid2 + 1, d, x, y, val);
	seg[idx][idx2] =
		seg[idx << 1][idx2 << 1] + seg[idx << 1][idx2 << 1 | 1] +
		seg[idx << 1 | 1][idx2 << 1] + seg[idx << 1 | 1][idx2 << 1 | 1];
}

int getval(int idx, int idx2, int l, int r, int u, int d, int x1, int x2, int y1, int y2) {
	if (r < x1 || l > x2) return 0;
	if (d < y1 || u > y2) return 0;
	if (x1 <= l && r <= x2 && y1 <= u && d <= y2)
		return seg[idx][idx2];
	int mid1 = (l + r) >> 1;
	int mid2 = (u + d) >> 1;
	int ans = 0;
	ans += getval(idx << 1, idx2 << 1, l, mid1, u, mid2, x1, x2, y1, y2);
	ans += getval(idx << 1 | 1, idx2 << 1, mid1 + 1, r, u, mid2, x1, x2, y1, y2);
	ans += getval(idx << 1, idx2 << 1 | 1, l, mid1, mid2 + 1, d, x1, x2, y1, y2);
	ans += getval(idx << 1 | 1, idx2 << 1 | 1, mid1 + 1, r, mid2 + 1, d, x1, x2, y1, y2);
	return ans;
}

int main() {
	SYNC;
	cin >> n >> m;
	fu(i, 1, n + 1) {
		fu(j, 1, n + 1) {
			cin >> Map[i][j];
		}
	}
	init(1, 1, 1, n, 1, n);
	fu(i, 0, m) {
		int w;
		cin >> w;
		if (w) {
			int x1, y1, x2, y2;
			cin >> x1 >> y1 >> x2 >> y2;
			cout << getval(1, 1, 1, n, 1, n, x1, x2, y1, y2) << '\n';
		}
		else {
			int x, y, c;
			cin >> x >> y >> c;
			update(1, 1, 1, n, 1, n, x, y, c);
		}
	}
	return 0;
}

 

하지만 쿼드트리를 구현하였더니 시간 초과가 나게 되었다. 2차원 세그먼트 트리도 해본 적이 없었기에 시간 초과가 나는 줄 몰랐고 어떻게 구현해야 할지 막막하였다.

결국 알아낸것은 X축과 Y축을 두 번 구현하는 것이었다. 2차원 배열을 만들고 x의 값의 세그먼트 트리를 y의 세그먼트에 모아놓고 다시 y의 배열에서 합을 구하는 것이다. 

	int i = y1 + h - 1, j = x1 + h - 1;
	seg[i][j] = val;
	while (j > 1) {
		j /= 2;
		seg[i][j] = seg[i][j << 1] + seg[i][j << 1 | 1];
	}

이렇게 연산을 하게되면 seg [i]에 seg [i][j]의 합이 모이게 되고 그걸 다시 i를 기준으로 합치는 것이다.

while (i > 1) {
		j = x1 + h - 1;
		i /= 2;
		seg[i][j] = seg[i << 1][j] + seg[i << 1 | 1][j];
		while (j > 1) {
			j /= 2;
			seg[i][j] = seg[i][j << 1] + seg[i][j << 1 | 1];
		}
	}

두 번을 합치게 되었으므로 getval 함수도 x와 y로 나누어서 y인덱스와 x 인덱스를 두 번 구해야 한다.

 

소스코드

더보기
#include <bits/stdc++.h>
#define all(v) v.begin(), v.end()
#define pb push_back
#define fu(i,a,j) for(int i=a;i<j;i++)
#define fd(i,a,j) for(int i=a;i>=j;i--)
#define SYNC ios::sync_with_stdio(false),cin.tie(NULL),cout.tie(NULL)
#define MOD 998244353
#define MOD2 1000000021
#define INF 1e9
#define N 1025
using namespace std;

typedef long long ll;
typedef long double ld;
typedef double db;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef vector<int> vi;
typedef vector<ll> vll;
typedef vector<pii> vpii;
typedef vector<pll> vpll;

ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll lcm(ll a, ll b) { return a * b / gcd(a, b); };
int n, m, k, t;
int dr[] = { 0,0,1,-1, -1,1,1,-1 };
int dc[] = { -1,1,0,0, 1,1,-1,-1 };

int seg[2 * N][2 * N];
int h = 1;
void update(int x1, int y1, int val) {
	int i = y1 + h - 1, j = x1 + h - 1;
	seg[i][j] = val;
	while (j > 1) {
		j /= 2;
		seg[i][j] = seg[i][j << 1] + seg[i][j << 1 | 1];
	}
	while (i > 1) {
		j = x1 + h - 1;
		i /= 2;
		seg[i][j] = seg[i << 1][j] + seg[i << 1 | 1][j];
		while (j > 1) {
			j /= 2;
			seg[i][j] = seg[i][j << 1] + seg[i][j << 1 | 1];
		}
	}
}

int getx(int y, int idx, int l, int r, int x1, int x2) {
	if (x1 > r || x2 < l) return 0;
	if (x1 <= l && r <= x2) return seg[y][idx];
	int mid = (l + r) >> 1;
	return getx(y, idx << 1, l, mid, x1, x2) +
		getx(y, idx << 1 | 1, mid + 1, r, x1, x2);
}

int gety(int idx, int l, int r, int x1, int x2, int y1, int y2) {
	if (y1 > r || y2 < l) return 0;
	if (y1 <= l && r <= y2) return getx(idx, 1, 1, h, x1, x2);
	int mid = (l + r) >> 1;
	return gety(idx << 1, l, mid, x1, x2, y1, y2) +
		gety(idx << 1 | 1, mid + 1, r, x1, x2, y1, y2);
}
int main() {
	SYNC;
	cin >> n >> m;
	while (h < n) h <<= 1;
	fu(i, 1, n + 1) {
		fu(j, 1, n + 1) {
			int a;
			cin >> a;
			update(i, j, a);
		}
	}
	fu(i, 0, m) {
		int w;
		cin >> w;
		if (w) {
			int x1, y1, x2, y2;
			cin >> x1 >> y1 >> x2 >> y2;
			cout << gety(1, 1, h, x1, x2, y1, y2) << '\n';
		}
		else {
			int x, y, c;
			cin >> x >> y >> c;
			update(x, y, c);
		}
	}
	return 0;
}

댓글