Input:
- <TreeNode*>root = [1,2,3,null,4]
- <int>distance = 3
Output:
- <int>1
요약:
leaf node끼리의 거리가 distance보다 작거나 같은 모든 leaf node pair를 반환하라
조건:
- 노드의 개수 범위 [1, 2^10]
- 1 <= 노드의 값 <= 100
- 1 <= 거리 <= 10
풀이 과정:
- 설명
unordered_map으로 각 노드를 key값으로 부모와 depth를 엮어서 모든 Leaf 노드들을 찾아서 모은 후, 각 Leaf 노드 의 pair를 겹치지 않게 서로 비교하여 distance에 맞는 pair 개수를 구하기로 했다
- 시간복잡도
- DFS에 [O(N)]
- 각 리프노드끼리 비교하는데 [O(L^2)]
- pair의 distance를 구하는데 [O(H)]
[O(N + L^2 * H)]
- 코드
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
using NodeDepthPair = pair<TreeNode*, int>;
using ChildOfParentDepth = unordered_map<TreeNode*, NodeDepthPair>;
public:
int countPairs(TreeNode* root, int distance) {
ChildOfParentDepth child_info;
vector<TreeNode*> leaf_nodes;
dfs(root, child_info, leaf_nodes, 1);
return (getPairNum(child_info, leaf_nodes, distance));
}
private:
int getPairNum(ChildOfParentDepth child_info, vector<TreeNode*> leaf_nodes, int distance) {
int leaf_len = leaf_nodes.size();
int result = 0;
for (int i = 0; i < leaf_len; ++i) {
for (int j = i + 1; j < leaf_len; ++j) {
if (getDistance(child_info, leaf_nodes[i], leaf_nodes[j], distance) <= distance)
result++;
}
}
return (result);
}
int getDistance(ChildOfParentDepth child_info, TreeNode* node1, TreeNode* node2, int distance) {
int i, node_distance = 0;
int node1_depth = child_info[node1].second, node2_depth = child_info[node2].second;
if (node1_depth > node2_depth) {
i = node1_depth - node2_depth;
node_distance += i;
while (i--)
node1 = child_info[node1].first;
}
else if (node1_depth < node2_depth) {
i = node2_depth - node1_depth;
node_distance += i;
while (i--)
node2 = child_info[node2].first;
}
while (node1 != node2) {
node1 = child_info[node1].first;
node2 = child_info[node2].first;
node_distance += 2;
}
cout << node_distance << endl;
return (node_distance);
}
void dfs(TreeNode* node, ChildOfParentDepth &child_info, vector<TreeNode*> &leaf_nodes, int depth) {
if (!node)
return ;
else if (!node->left && !node->right)
leaf_nodes.emplace_back(node);
if (node->left)
child_info.insert(make_pair(node->left, make_pair(node, depth)));
if (node->right)
child_info.insert(make_pair(node->right, make_pair(node, depth)));
dfs(node->left, child_info, leaf_nodes, depth + 1);
dfs(node->right, child_info, leaf_nodes, depth + 1);
}
};
풀이 결과:
성능문제로 실패...
테스트 케이스에 넣어서 돌려보면 답은 나오지만 시간 초과가 나오는 데다 700ms~ 가 나와서 성능문제가 맞는것 같다
심지어 내가 분석한 방식은 2ms인것을 보니 확실히 문제가 있다
배운 점 및 코드:
비록 탐색과 추가에 O(1) 이 든다고 하지만 내부에서 Hash Table을 쓰는만큼 충돌(Collisions) 발생시 그 오버헤드가 굉장히 큰데다 각 노드의 value는 [1 ~ 100] 이지만 노드의 개수는 [1 ~ 2^10] 이므로 동일한 value가 중첩되는 상황을 피하고자 노드의 주소값을 이용했으나 애초에 2^10 개나 되는 노드들의 충돌을 피할 수는 없었던 것 같다
하지만 이뿐만이 문제는 아닌것 같지만 내 지식으로는 지금 당장 알아낸 바는 이뿐이다
좀더 알아보야겠다
그렇게 결국 답을 확인할 수 밖에 없었고...
굉장히 특이한 방법을 찾게 되었는데 이 방식을 해석하는데 꽤나 애먹었다
이부분은 인상깊은 방법이고 나중에 이진트리 말고도 다른 곳에서도 쓰일 수 있는 방식이라고 생각이 되어 분석했다
private:
vector<int> postOrder(TreeNode* currentNode, int distance) {
if (!currentNode)
return vector<int>(12);
else if (!currentNode->left && !currentNode->right) {
vector<int> current(12);
current[0] = 1;
return current;
}
vector<int> left = postOrder(currentNode->left, distance);
vector<int> right = postOrder(currentNode->right, distance);
vector<int> current(12);
for (int i = 0; i < 10; i++) {
current[i + 1] = left[i] + right[i];
}
current[11] += left[11] + right[11];
for (int d1 = 0; d1 <= distance; d1++) {
for (int d2 = 0; d2 <= distance; d2++) {
if (2 + d1 + d2 <= distance) {
current[11] += left[d1] * right[d2];
}
}
}
return current;
}
public:
int countPairs(TreeNode* root, int distance) {
return postOrder(root, distance)[11];
}
};
하나하나 뜯어서 분리해보자
if (!currentNode)
return vector<int>(12);
else if (!currentNode->left && !currentNode->right) {
vector<int> current(12);
current[0] = 1;
return current;
}
vector<int> left = postOrder(currentNode->left, distance);
vector<int> right = postOrder(currentNode->right, distance);
DFS의 후위순회 방식으로 Leaf node의 경우 크기 12의 vector<int>를 선언해 0번 인덱스를 1로 초기화하고 반환한다
이렇게되면 우리가 처음으로 만나는 Leaf node의 Parent node가 받는 left, right의 종류는 두가지가 있다
- Leaf node: [1, 0, 0, ..., 0]
- Non-Leaf node: [0, 0, 0, ..., 0]
이는 0번째 Distance(Index)내에 Leaf node가 하나 있다는 뜻이다
이렇게 받아든 결과를 아래의 코드를 이용하면
vector<int> current(12);
for (int i = 0; i < 10; i++) {
current[i + 1] = left[i] + right[i];
}
새로운 벡터를 생성하여 전의 자식노드와의 Distance가 하나 늘었으니 인덱스를 늘리게 되며 기존의 left, right를 추가한다
그렇게되면 현재 node(Parent node)가 가지는 Leaf node 목록은
- [0, 1, 0, ..., 0]
이 되게 되는데 이는 1 Distance(Index)내에 Leaf node가 하나 있다는 뜻이 된다
current[11] += left[11] + right[11];
for (int d1 = 0; d1 <= distance; d1++) {
for (int d2 = 0; d2 <= distance; d2++) {
if (2 + d1 + d2 <= distance) {
current[11] += left[d1] * right[d2];
}
}
}
return current;
current[11] 에 더해주는 것은 곧 현재 node의 Sub Tree에서 계산된 Leaf node pair 개수를 누적하는 것이다
그 증거로는 바로 아래의 반복문을 들여다 보면 확인 될 것이다
그동안 d1, d2의 값을 0 ~ Distance(Index)로 주며 그 합이 Distance 보다 같거나 작은 경우에만 current[11]에 넣는데, 이는 Distance에 합당한 거리 내에 있는 pair의 개수를 current[11]에 넣어주는 것이다
이렇게 만들어진 vector를 반환하고
public:
int countPairs(TreeNode* root, int distance) {
return postOrder(root, distance)[11];
}
}
반환된 vector의 11번 인덱스를 참조하여 최종 반환한다
- 시간복잡도
한번의 DFS 탐색 [O(n)]
매 node 마다의 distance 수 만큼 pair의 개수를 구하기 위해 [O(d^2)] 만큼 반복문을 돌리기 때문에
최종적으로는 [O(n * d^2)]가 되겠다