문제
https://www.acmicpc.net/problem/5419
문제 해석
강한 북서풍이 불고 있다. 이 뜻은 동쪽과 남쪽 사이의 모든 방향으로 항해할 수 있다는 뜻이다. 북쪽이나 서쪽으로 항해하는 것은 불가능하다. 작은 섬이 여러 개 있는 바다가 있다. 섬은 좌표 평면의 한 점으로 나타낼 수 있다. y 좌표가 증가하는 방향은 북쪽, x좌표가 증가하는 방향은 동쪽이다. 북서풍을 타고 항해할 수 있는 섬의 쌍의 수를 구하는 프로그램을 작성하시오.
북서풍을 타고 항해할 수 있는 섬의 쌍의 수를 구하는 문제이다.
즉 아래와 같이 화살표가 X가 증가, Y가 감소하는 방향으로만 향할 수 있다.
좌표 쌍을 5개 구할 수 있으므로, 5가 답이 된다.
풀이
스위핑과 세그먼트 트리 기법으로 문제를 풀 수 있다.
스위핑 기법
특정 선이나 공간을 한쪽에서부터 싹 쓸어버리는 식의 문제 해결 기법이다.
세그먼트 트리 알고리즘
대표 문제
2021.12.06 - [Problem Solving/BOJ] - [BOJ 2042] 구간 합 구하기 (Java)
처음에는 어떻게 세그먼트 트리로 접근하는 건지 도저히 감이 안왔다.
세그먼트 트리 알고리즘을 학습한 지 얼마 안 된 상태였고, 세그먼트 트리 초기화를 다~ 해준 후 쿼리하는 문제만 풀어봐서 쿼리와 트리 업데이트를 같이 해볼 생각을 전혀 못했다.
복습 차원에서 최대한 풀이를 자세히 써보려고 한다.
스위핑 기법을 통한 접근
위와 같이 섬이 위치해 있다고 가정하자.
먼저 X값으로 오름차순, X값이 같을 경우 Y값 내림차순으로 정렬해보자.
(1, 2), (1, 1), (2, 3), (3, 2), (3, 1)
정렬된 좌표를 하나씩 체크하면서 가능한 쌍의 수를 구해보자.
1. (1, 2) 체크 : 0개
(1, 2) 와 쌍이 되려면 x<=1 && y >=2 조건을 만족하는 섬이 있어야 한다.
X값 기준으로 오름차순을 했기 때문에 현재 좌표 이후에 나오는 좌표들은 현재 좌표와 쌍이 될 수 없다.
따라서 이후에 나오는 좌표는 비교할 가치가 없다.
즉, 처음 부터 ~ 현재 좌표 전까지의 좌표들을 체크하면서 Y값이 현재 좌표의 Y값과 같거나 큰 좌표만이 현재 좌표와 쌍이 될 수 있다.
2. (1, 1) 체크: 1개
이전 좌표 중 Y값이 2인 좌표가 1개 있다.
(1, 1)은 (1, 2)와 쌍이 될 수 있다.
3. (2, 3) 체크: 0개
이전에 나온 좌표 중에 (2, 3)의 Y값 이상 값을 가지는 좌표가 없다.
X값은 신경쓸 필요가 없다. 어차피 이전에 나온 좌표는 모두 현재 좌표의 X값 보다 같거나 작다.
4. (3, 2) 체크: 2개
5. (3, 1) 체크: 5개
따라서, 가능한 섬의 쌍 개수는 7이 된다.
세그먼트 트리로 어떻게 접근하지?
이제 위 순서대로 섬을 순회하면서 현재 섬의 가능한 쌍의 개수를 query할 것이고,
그 섬을 세그먼트 트리에서 update하고 다음 섬을 체크해나가면서 전체 쌍의 개수를 구해보려 한다.
아래 예제가 어떻게 세그먼트 트리로 구성되는지 코드 구현부와 함께 설명해보려고 한다.
1. 좌표를 받아서 List에 저장한다.
// 좌표 저장
StringTokenizer st = null;
for(int i = 0; i < N; i++){
st = new StringTokenizer(br.readLine());
list.add(new Node(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken())));
}
2. Y값 내림차순으로 좌표를 정렬한 후 좌표를 압축한다.
Why? Y값으로 세그먼트 트리를 구성할 것이다.
문제에서 Y값의 범위가 -10^9 ~ 10^9 이므로, Tree를 구성하기에는 값이 너무 크기 때문에 좌표를 압축한다.
예를 들어 Y좌표가 (1, 500, 10000000, 39) -> (4, 2, 1, 3)
Y값을 내림차순으로 정렬한 이유는 구현할 때 편해서 그렇게 했다.
현재 좌표에서 가능한 쌍을 세그먼트 트리에 query할 때, left: 1, right: 현재 Y값을 넣어주면 된다.
나는 Y값의 가장 작은 값을 1로 정했다.
y값을 내림차순하여 좌표압축했으므로 y값이 작을 수록 북쪽에 있는 섬 |
// y좌표 내림차순으로 정렬
Collections.sort(list, new Comparator<Node>(){
@Override
public int compare(Node n1, Node n2){
return n2.y - n1.y;
}
});
// 좌표 압축
// 북쪽에 있는 섬부터 y를 새로 매김
// 북쪽에 있는 섬이 가장 작은 값을 가지게 됨.
int ny = 1;
int prev = list.get(0).y;
for(Node n : list){
if(n.y == prev){
n.y = ny;
continue;
}
prev = n.y;
n.y = ++ny;
}
3. X값 오름차순, X값이 동일할 경우 Y값 오름차순 정렬
// x좌표 오름차순. x좌표 같을 경우 y좌표 북쪽 부터 정렬
Collections.sort(list, new Comparator<Node>(){
@Override
public int compare(Node n1, Node n2){
int rs = n1.x - n2.x;
if(rs == 0) rs = n1.y - n2.y;
return rs;
}
});
1 | 2 | 3 | 4 | |
x | -10 | -10 | 10 | 10 |
y | 1 | 2 | 1 | 2 |
4. 좌표를 순회하면서 현재 좌표에서 가능한 섬의 쌍을 구하고 (query) 현재 좌표를 Tree에 업데이트 (update)한다.
long ans = 0;
// 세그먼트 트리에 쿼리를 한 후, 해당 값을 업데이트 한다.
int[] tree = new int[N*4];
for(Node node : list){
ans += query(tree,1, ny, 1, node.y, 1);
update(tree, 1, ny, node.y, 1);
}
세그먼트 트리의 범위는 y값을 의미한다.
좌표 압축했을 때 y값에 갱신 했던 가장 작은 값과 가장 마지막 값이 루트 노드의 범위가 된다.
** 각 노드의 구간 합은 해당 구간에 Y값이 포함된 좌표의 개수를 의미한다.
1. (-10, 1) 체크 / 쌍의 개수: 0
query : 현재 Y값과 같거나 작은 Y값을 가지는 좌표의 개수를 리턴한다.
= 현재 좌표보다 북쪽에 있는 좌표의 개수 = 즉 현재 좌표와 쌍이 될 수 있는 개수를 의미한다.
1-1 구간 합을 찾는다. tree[2] 값이 실제로 더해지고 나머지는 범위 밖으로 0이 리턴된다.
public static long query(int[] tree, int start, int end, int left, int right, int node){
if(left > end || right < start) return 0;
if(left <= start && end <= right) return tree[node];
int mid = (start + end)/2;
return query(tree, start, mid, left, right, node*2) + query(tree, mid+1, end, left, right, node*2+1);
}
update: 현재 좌표의 Y값이 포함된 구간합 노드에 +1을 더해준다.
public static int update(int[] tree, int start, int end, int target, int node){
if(target < start || target > end) return tree[node];
if(start == end) {
tree[node] += 1;
return tree[node];
}
int mid = (start + end)/2;
return tree[node] = update(tree, start, mid, target, node*2) + update(tree, mid+1, end, target, node*2+1);
}
2. (-10, 2) 체크 / 쌍의 개수: 1
query | update |
3. (10, 1) 체크 / 쌍의 개수: 1
query | update |
4. (10, 2) 체크 / 쌍의 개수: 3
query | update |
답: 5
소스코드
package segmenttree;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.*;
public class BOJ5419 {
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
int T = Integer.parseInt(br.readLine());
for(int i = 0; i < T; i ++){
bw.write(solution(br, bw) + "\n");
}
bw.flush();
}
public static class Node {
int x;
int y;
public Node(int x, int y){
this.x = x;
this.y = y;
}
}
public static long solution(BufferedReader br, BufferedWriter bw) throws Exception{
// 좌표 개수
int N = Integer.parseInt(br.readLine());
// 좌표를 저장할 리스트
ArrayList<Node> list = new ArrayList<Node>();
// 좌표 저장
StringTokenizer st = null;
for(int i = 0; i < N; i++){
st = new StringTokenizer(br.readLine());
list.add(new Node(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken())));
}
// y좌표 내림차순으로 정렬
Collections.sort(list, new Comparator<Node>(){
@Override
public int compare(Node n1, Node n2){
return n2.y - n1.y;
}
});
// 좌표 압축
// 북쪽에 있는 섬부터 y를 새로 매김
// 북쪽에 있는 섬이 가장 작은 값을 가지게 됨.
int ny = 1;
int prev = list.get(0).y;
for(Node n : list){
if(n.y == prev){
n.y = ny;
continue;
}
prev = n.y;
n.y = ++ny;
}
// x좌표 오름차순. x좌표 같을 경우 y좌표 북쪽 부터 정렬
Collections.sort(list, new Comparator<Node>(){
@Override
public int compare(Node n1, Node n2){
int rs = n1.x - n2.x;
if(rs == 0) rs = n1.y - n2.y;
return rs;
}
});
long ans = 0;
// 세그먼트 트리에 쿼리를 한 후, 해당 값을 업데이트 한다.
int[] tree = new int[N*4];
for(Node node : list){
ans += query(tree,1, ny, 1, node.y, 1);
update(tree, 1, ny, node.y, 1);
}
return ans;
}
public static long query(int[] tree, int start, int end, int left, int right, int node){
if(left > end || right < start) return 0;
if(left <= start && end <= right) return tree[node];
int mid = (start + end)/2;
return query(tree, start, mid, left, right, node*2) + query(tree, mid+1, end, left, right, node*2+1);
}
public static int update(int[] tree, int start, int end, int target, int node){
if(target < start || target > end) return tree[node];
if(start == end) {
tree[node] += 1;
return tree[node];
}
int mid = (start + end)/2;
return tree[node] = update(tree, start, mid, target, node*2) + update(tree, mid+1, end, target, node*2+1);
}
}
'💡Problem Solving > BOJ' 카테고리의 다른 글
[BOJ 14716] 현수막 (Java) (0) | 2021.12.09 |
---|---|
[BOJ 2170] 선 긋기 (Java) (0) | 2021.12.08 |
[BOJ 2357] 최솟값과 최댓값 (Java) (0) | 2021.12.06 |
[BOJ 11505] 구간 곱 구하기 (Java) (0) | 2021.12.06 |
[BOJ 11659] 구간 합 구하기 4 (Java) (0) | 2021.12.06 |