💡Problem Solving/BOJ

[BOJ 2042] 구간 합 구하기 (Java)

gom20 2021. 12. 6. 11:45

문제

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

 

풀이

Segment Tree를 이용하여 풀 수 있다. 아래 블로그를 참고 하여 학습하였다.

https://m.blog.naver.com/ndb796/221282210534

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

 

예제 

5 2 2
1
2
3
4
5
1 3 6
2 2 5
1 5 2
2 3 5

 

1. 입력을 받는다.

        int N = Integer.parseInt(st.nextToken()); // 수의 개수
        int M = Integer.parseInt(st.nextToken()); // 수의 변경이 일어나는 횟수
        int K = Integer.parseInt(st.nextToken()); // 구간의 합을 구하는 횟수

2. 수열을 저장한다.

        arr = new long[N+1];
        for(int i = 1; i <= N; i++){
            arr[i] = Integer.parseInt(br.readLine());
        }

 

편의 상, index 1부터 사용

index 1 2 3 4 5
element 1 2 3 4 5

3. Segment Tree로 구간합을 저장한다.

2-1. Tree 배열 생성

완전 이진 트리의 Node 개수는 leaf Node 개수 * 2 이다. 

따라서 배열 크기는 2^k 중 n과 가까운 큰 값에 *2를 해주면 된다. 

예를 들어 수열의 개수가 7이라면 8*2, 16이라면 16*2, 12라면 16*2... 

이런 로직 추가 없이 좀 더 러프하게 잡아서 진행한다면, 수열의 개수에 *4를 해주면 완전히 커버할 수 있다. 

        tree = new long[N*4];

2-2. 구간합 저장

재귀 함수를 사용한다. 

    public static long set(int start, int end, int node){
        if(start == end) return tree[node] = arr[start];
        int mid = (start + end)/2;
        return tree[node] = set(start, mid, node*2) + set(mid+1, end, node*2+1);
    }

구간합

4. 특정 Index의 수를 변경

세 번째 수를 6으로 변경한다면, 범위에 해당하는 구간합을 갱신해야 한다

index 1 2 3 4 5
element 1 2 3 -> 6 4 5

 

    public static void update(int start, int end, int node, int target, long diff){
        if(target < start || target > end) return;
        tree[node] += diff;
        if(start == end) return;
        int mid = (start+end)/2;
        update(start, mid, node*2, target, diff);
        update(mid+1, end, node*2+1, target, diff);
    }

세 번째 수에서 6으로 변경하는 것이므로, 세 번째 수를 포함한 구간합 Node에 ( 6 - 세번째 수) 씩 더해주면 된다. 

특정 수 변경

5. 특정 구간 합 구하기

위 그래프에서 2~5 구간의 합을 구한다면?

2~5 구간에 포함되는 Tree Node의 값을 모두 더해주면 된다. 

Node의 구간이 2~5보다 클 경우, 자식 노드를 체크한다.

    public static long sum(int start, int end, int left, int right, int node){
        // 범위를 벗어남
        if(right < start || left > end) return 0;
        // tree node의 합이 구간합 범위 안에 있음
        if(left <= start && end <= right ) {
            return tree[node];
        }

        // 필요한 구간 합보다, tree node의 구간합 범위가 더 클 경우, 쪼개서 다시 진행
        int mid = (start + end)/2;
        return sum(start, mid, left, right, node*2) +  sum(mid+1, end, left, right, node*2+1);
    }

특정 구간 합

** 해당 문제는 overflow를 유의해야한다. 구간 합의 경우 long형을 써야 통과할 수 있다. 

소스코드

package segmenttree;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.StringTokenizer;

public class BOJ2042 {
    public static long[] tree;
    public static long[] arr;
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int N = Integer.parseInt(st.nextToken()); // 수의 개수
        int M = Integer.parseInt(st.nextToken()); // 수의 변경이 일어나는 횟수
        int K = Integer.parseInt(st.nextToken()); // 구간의 합을 구하는 횟수

        arr = new long[N+1];
        for(int i = 1; i <= N; i++){
            arr[i] = Integer.parseInt(br.readLine());
        }

        // 1. 구간 합을 가지는 세그먼트 트리 생성하기
        tree = new long[N*4];
        set(1, N, 1);

        for(int i = 0; i < M+K; i++){
            st = new StringTokenizer(br.readLine());
            int flag = Integer.parseInt(st.nextToken());
            if(flag == 1){
                int from = Integer.parseInt(st.nextToken());
                long to = Long.parseLong(st.nextToken());
                update(1, N, 1, from, to-arr[from]);
                arr[from] = to;
            } else {
                int left = Integer.parseInt(st.nextToken());
                int right = Integer.parseInt(st.nextToken());
                bw.write(sum(1, N, left, right, 1) + "\n");
            }
        }
        bw.flush();
    }

    public static void update(int start, int end, int node, int target, long diff){
        if(target < start || target > end) return;
        tree[node] += diff;
        if(start == end) return;
        int mid = (start+end)/2;
        update(start, mid, node*2, target, diff);
        update(mid+1, end, node*2+1, target, diff);
    }

    public static long sum(int start, int end, int left, int right, int node){
        // 범위를 벗어남
        if(right < start || left > end) return 0;
        // tree node의 합이 구간합 범위 안에 있음
        if(left <= start && end <= right ) {
            return tree[node];
        }

        // 필요한 구간 합보다, tree node의 구간합 범위가 더 클 경우, 쪼개서 다시 진행
        int mid = (start + end)/2;
        return sum(start, mid, left, right, node*2) +  sum(mid+1, end, left, right, node*2+1);
    }

    public static long set(int start, int end, int node){
        if(start == end) return tree[node] = arr[start];
        int mid = (start + end)/2;
        return tree[node] = set(start, mid, node*2) + set(mid+1, end, node*2+1);
    }
}