🏷️ 카테고리
`#자료구조`, `#트리`
⏳ 시간복잡도
`O(N log N)`
📒 해설
이진 탐색 트리에서 `insert`의 재귀 횟수를 계산하는 문제입니다. 하지만 N이 30만 번이라는 것이 걸리는 문제였습니다. N이 30만 번이라면, BST에서의 Worst Case의 경우 `O(N^2)`이기 때문에 시간 초과가 발생하기 때문입니다. 그렇다면 이 문제의 핵심은 BST에서의 `insert` 횟수를, 실제 `BST`를 구현하지 않고 계산해야 한다는 것인데 어떻게 하면 이것이 가능할까요? 이 방법을 같이 찾아봅시다.
📜 문제 조건
문제의 조건은 다음과 같습니다.
- 입력으로 들어올 수열의 크기 `N (1 <= N <= 300,000)`
- 수열의 값 `X (1 <= X <= N)` , 중복되지 않음.
🔍 문제 접근
앞선 해설 부분에서도 말했듯이, '이 문제에서는 실제 BST를 구현하지 않고 insert의 횟수를 알아야한다' 는 점과 '이 동작을 `O(N log N)` 내에 끝내야 한다.'가 핵심이고, 해당 방법을 찾는다면 크게 어려움 없이 풀 수 있는 문제입니다.
우선 insert의 횟수를 쉽게 알아내는 방법으로는 `DP`를 생각했습니다. 만약 새로운 수 X가 어떤 노드의 자식이 될지를 안다면, 해당 노드 +1의 insert가 되기 때문입니다.
하지만, 새로운 X가 어떤 노드의 자식이 될지를 알아내는 방법이 중요하겠죠? 저는 이 방법에 대해서 새로운 X에 대해서 X보다 작은 노드 중에서 가장 큰 노드, 그리고 X보다 큰 노드들 중에서 가장 작은 노드를 찾고, 작은 노드의 Right child 혹은 큰 노드의 left child 중에 빈자리가 X의 자리라고 생각했습니다. X에 대해서 두 노드가 존재한다면, 무조건 둘 중 한 자리는 비어있기 때문입니다. 이 부분은 BST의 특성을 생각하면 쉽게 알 수 있었습니다.
그래서 생각했던 몇 가지 방법은 다음과 같습니다.
1. X위치에 대해서 왼쪽, 오른쪽을 순차 탐색을 한다.
이 경우 결국 순차탐색이 발생하기 때문에 `O(N^2)`라고 생각했습니다.
2. 바이너리 서치를 이용해 위치를 찾는다.
이 경우 일차적으로는 바이너리 서치가 `O(log N)`이기 때문에 가능할 것이라고 생각했었습니다. 하지만 이 경우 지속적으로 정렬된 상태의 배열 혹은 컬렉션이 필요했습니다.
순간 좀 혹하는 생각이 들었습니다.
'어? 리스트에 저장하고 바이너리 서치로 X가 있을 위치를 찾아내고, left, right를 결정한 이후 그 위치에 삽입하면 되는것 아닌가?' 뭐 이런 느낌이었죠.
하지만 다시 생각해봤을때, `ArrayList`의 경우 삽입 연산에서 shift 동작으로 인해 `O(N^2)`이 발생하고, `LinkedList`는 말할 필요도 없었죠.
그래서 지속적인 정렬을 하지만 정렬 비용이 `O(log N)`인 자료구조가 있는가에 대해서 생각해 봤고, Java에서 `TreeOOO` 시리즈가 Red-Black Tree로 구현된다는 것을 생각해 냈습니다. 그리고 `TreeMap`을 이용해 문제를 해결하게 되었습니다.
🔑 문제 풀이
사용 변수
TreeMap<Integer, Node> map = new TreeMap<>();
static class Node{
int depth;
boolean hasLeft;
boolean hasRight;
public Node(int depth) {
this.depth = depth;
}
}
앞에서 말했던 것처럼, `TreeMap`을 이용했고, 각 노드에 대한 표현으로 dp를 위한 depth 변수, 그리고 left, right 자식이 있는지만을 체크해주기 위한 boolean 변수를 사용했습니다.
X보다 작은 노드 중에서 큰 녀석, 큰 노드 중에서 작은 녀석 찾기
이 부분 역시 풀이 과정에서 중요한 부분이라고 생각했는데요, `TreeMap`을 이용하면 정말 쉽게 구현이 가능합니다. `TreeMap`에는 `lowerXXX()`, `higherXXX()` 과 같은 메서드가 존재해 파라미터로 key를 넣어주면 key보다 바로 작은(큰) 키 혹은 엔트리를 리턴해줍니다. 저의 경우 value의 값이 필요했기 때문에 엔트리를 리턴 받아서 이용했습니다.
코드로 치면 다음과 같습니다.
int k = Integer.parseInt(br.readLine());
Map.Entry<Integer, Node> lowerEntry = map.lowerEntry(k);
Map.Entry<Integer, Node> higherEntry = map.higherEntry(k);
insert 카운팅하기
이제 모든 준비가 끝났으니, 본래 문제였던 insert의 호출 횟수를 카운팅 하면 됩니다. 이 방법에 대해서는 처음 설명과 같이 lower와 higher value에 대해서 BST의 insert처럼 해주면 됩니다. 코드를 보시면 바로 이해가 가실 겁니다.
그리고 재귀 횟수(nextDepth)를 전체 카운팅(c)에서 더해주면 됩니다.
int nextDepth;
// higher만 있는 경우 즉, 현재 값이 higher의 left가 되는 경우
if (lowerEntry == null) {
Node higherNode = higherEntry.getValue();
higherNode.hasLeft = true;
nextDepth = higherNode.depth+1;
} else if (higherEntry == null) { // lower만 있는 경우, 즉 현재 값이 lower의 right가 되어야 하는 경우
Node lowerNode = lowerEntry.getValue();
lowerNode.hasRight = true;
nextDepth = lowerNode.depth+1;
} else {
// 둘 다 있는 경우
Node lNode = lowerEntry.getValue();
Node rNode = higherEntry.getValue();
// 둘 중에서 right나 left가 없는 곳이 내 자리임.
if(!rNode.hasLeft) {
rNode.hasLeft = true;
nextDepth = rNode.depth+1;
} else {
lNode.hasRight = true;
nextDepth = lNode.depth+1;
}
}
map.put(k, new Node(nextDepth));
c += nextDepth;
전체 코드
import java.util.*;
import java.io.*;
public class Main {
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static StringBuilder answer = new StringBuilder();
public static void main(String[] args) throws Exception {
TreeMap<Integer, Node> map = new TreeMap<>();
int n = Integer.parseInt(br.readLine());
int root = Integer.parseInt(br.readLine());
map.put(root, new Node(0));
long c = 0;
answer.append(c).append("\n");
for (int i = 1; i < n; i++) {
int k = Integer.parseInt(br.readLine());
Map.Entry<Integer, Node> lowerEntry = map.lowerEntry(k);
Map.Entry<Integer, Node> higherEntry = map.higherEntry(k);
int nextDepth;
// higher만 있는 경우 즉, 현재 값이 higher의 left가 되는 경우
if (lowerEntry == null) {
Node higherNode = higherEntry.getValue();
higherNode.hasLeft = true;
nextDepth = higherNode.depth+1;
} else if (higherEntry == null) { // lower만 있는 경우, 즉 현재 값이 lower의 right가 되어야 하는 경우
Node lowerNode = lowerEntry.getValue();
lowerNode.hasRight = true;
nextDepth = lowerNode.depth+1;
} else {
// 둘 다 있는 경우
Node lNode = lowerEntry.getValue();
Node rNode = higherEntry.getValue();
// 둘 중에서 right나 left가 없는 곳이 내 자리임.
if(!rNode.hasLeft) {
rNode.hasLeft = true;
nextDepth = rNode.depth+1;
} else {
lNode.hasRight = true;
nextDepth = lNode.depth+1;
}
}
map.put(k, new Node(nextDepth));
c += nextDepth;
answer.append(c).append("\n");
}
System.out.println(answer);
}
static class Node{
int depth;
boolean hasLeft;
boolean hasRight;
public Node(int depth) {
this.depth = depth;
}
}
}
결과
Java의 라이브러리들에 대한 이해도가 중요했던 문제였던 거 같네요. 플레티넘 이상 문제에서는 뭔가 한번 꼬아서 생각해된다는 점이 재밌는 것 같습니다.
다른 사람들의 풀이를 봤을 때, 문제 접근에서의 2번 방법으로 풀어도 통과가 되는 것을 확인했습니다. 해당 방법으로도 한번 풀어보셔도 좋을 것 같아요.