본문 바로가기
코딩테스트/Leetcode

1530. Number of Good Leaf Nodes Pairs

by Ken out of ken 2024. 7. 18.

링크

Input:

  • <TreeNode*>root = [1,2,3,null,4]
  • <int>distance = 3

Output:

  • <int>1

 

요약:

 

leaf node끼리의 거리가 distance보다 작거나 같은 모든 leaf node pair를 반환하라

 

조건: 

  • 노드의 개수 범위 [1, 2^10] 
  • 1 <= 노드의 값 <= 100 
  • 1 <= 거리 <= 10

풀이 과정:

 

  • 설명

unordered_map으로 각 노드를 key값으로 부모와 depth를 엮어서 모든 Leaf 노드들을 찾아서 모은 후, 각 Leaf 노드 의 pair를 겹치지 않게 서로 비교하여 distance에 맞는 pair 개수를 구하기로 했다

 

 

  • 시간복잡도
    • DFS에 [O(N)]
    • 각 리프노드끼리 비교하는데 [O(L^2)] 
    • pair의 distance를 구하는데 [O(H)]

[O(N + L^2 * H)]

 

  • 코드
더보기
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
    using NodeDepthPair = pair<TreeNode*, int>;
    using ChildOfParentDepth = unordered_map<TreeNode*, NodeDepthPair>;
public:
    int countPairs(TreeNode* root, int distance) {
        ChildOfParentDepth  child_info;
        vector<TreeNode*>   leaf_nodes;

        dfs(root, child_info, leaf_nodes, 1);
        return (getPairNum(child_info, leaf_nodes, distance));
    }

private:
    int getPairNum(ChildOfParentDepth child_info, vector<TreeNode*> leaf_nodes, int distance) {
        int leaf_len = leaf_nodes.size();
        int result = 0;

        for (int i = 0; i < leaf_len; ++i) {
            for (int j = i + 1; j < leaf_len; ++j) {
                if (getDistance(child_info, leaf_nodes[i], leaf_nodes[j], distance) <= distance)    
                    result++;
            }
        }
        return (result);
    }

    int getDistance(ChildOfParentDepth child_info, TreeNode* node1, TreeNode* node2, int distance) {
        int         i, node_distance = 0;
        int         node1_depth = child_info[node1].second, node2_depth = child_info[node2].second;
        
        if (node1_depth > node2_depth) {
            i = node1_depth - node2_depth;
            node_distance += i;
            while (i--) 
                node1 = child_info[node1].first;
        }
        else if (node1_depth < node2_depth) {
            i = node2_depth - node1_depth;
            node_distance += i;
            while (i--) 
                node2 = child_info[node2].first;
        }
        while (node1 != node2) {
            node1 = child_info[node1].first;
            node2 = child_info[node2].first;
            node_distance += 2;
        }
        cout << node_distance << endl;
        return (node_distance);
    }

    void    dfs(TreeNode* node, ChildOfParentDepth &child_info, vector<TreeNode*> &leaf_nodes, int depth) {
        if (!node)
            return ;
        else if (!node->left && !node->right)
            leaf_nodes.emplace_back(node);
        if (node->left)
            child_info.insert(make_pair(node->left, make_pair(node, depth)));
        if (node->right)
            child_info.insert(make_pair(node->right, make_pair(node, depth)));
        dfs(node->left, child_info, leaf_nodes, depth + 1);
        dfs(node->right, child_info, leaf_nodes, depth + 1);
    }
};

 

풀이 결과:

 

성능문제로 실패...

테스트 케이스에 넣어서 돌려보면 답은 나오지만 시간 초과가 나오는 데다 700ms~ 가 나와서 성능문제가 맞는것 같다

심지어 내가 분석한 방식은 2ms인것을 보니 확실히 문제가 있다

 

 


배운 점 및 코드:

 

 

더보기
성능 문제가 온 이유는 unordered_map에서 문제가 있던 것 같다
비록 탐색과 추가에 O(1) 이 든다고 하지만 내부에서 Hash Table을 쓰는만큼 충돌(Collisions) 발생시 그 오버헤드가 굉장히 큰데다 각 노드의 value는 [1 ~ 100] 이지만 노드의 개수는 [1 ~ 2^10] 이므로 동일한 value가 중첩되는 상황을 피하고자 노드의 주소값을 이용했으나 애초에 2^10 개나 되는 노드들의 충돌을 피할 수는 없었던 것 같다

하지만 이뿐만이 문제는 아닌것 같지만 내 지식으로는 지금 당장 알아낸 바는 이뿐이다
좀더 알아보야겠다


그렇게 결국 답을 확인할 수 밖에 없었고...
굉장히 특이한 방법을 찾게 되었는데 이 방식을 해석하는데 꽤나 애먹었다
이부분은 인상깊은 방법이고 나중에 이진트리 말고도 다른 곳에서도 쓰일 수 있는 방식이라고 생각이 되어 분석했다 

더보기
private:
    vector<int> postOrder(TreeNode* currentNode, int distance) {
        if (!currentNode)
            return vector<int>(12);
        else if (!currentNode->left && !currentNode->right) {
            vector<int> current(12);
            current[0] = 1;
            return current;
        }

        vector<int> left = postOrder(currentNode->left, distance);
        vector<int> right = postOrder(currentNode->right, distance);

        vector<int> current(12);

        for (int i = 0; i < 10; i++) {
            current[i + 1] = left[i] + right[i];
        }

        current[11] += left[11] + right[11];

        for (int d1 = 0; d1 <= distance; d1++) {
            for (int d2 = 0; d2 <= distance; d2++) {
                if (2 + d1 + d2 <= distance) {
                    current[11] += left[d1] * right[d2];
                }
            }
        }

        return current;
    }

public:
    int countPairs(TreeNode* root, int distance) {
        return postOrder(root, distance)[11];
    }
};​

 

하나하나 뜯어서 분리해보자

if (!currentNode)
    return vector<int>(12);
else if (!currentNode->left && !currentNode->right) {
    vector<int> current(12);
    current[0] = 1;
    return current;
}

vector<int> left = postOrder(currentNode->left, distance);
vector<int> right = postOrder(currentNode->right, distance);

 

 DFS의 후위순회 방식으로 Leaf node의 경우 크기 12의 vector<int>를 선언해 0번 인덱스를 1로 초기화하고 반환한다

이렇게되면 우리가 처음으로 만나는 Leaf node의 Parent node가 받는 left, right의 종류는 두가지가 있다

  • Leaf node:             [1, 0, 0, ..., 0]
  • Non-Leaf node:    [0, 0, 0, ..., 0]

이는 0번째 Distance(Index)내에 Leaf node가 하나 있다는 뜻이다

이렇게 받아든 결과를 아래의 코드를 이용하면

vector<int> current(12);

for (int i = 0; i < 10; i++) {
    current[i + 1] = left[i] + right[i];
}

새로운 벡터를 생성하여 전의 자식노드와의 Distance가 하나 늘었으니 인덱스를 늘리게 되며 기존의 left, right를 추가한다

그렇게되면 현재 node(Parent node)가 가지는 Leaf node 목록은

  • [0, 1, 0, ..., 0]

 이 되게 되는데 이는 1 Distance(Index)내에 Leaf node가 하나 있다는 뜻이 된다

current[11] += left[11] + right[11];

for (int d1 = 0; d1 <= distance; d1++) {
    for (int d2 = 0; d2 <= distance; d2++) {
        if (2 + d1 + d2 <= distance) {
            current[11] += left[d1] * right[d2];
        }
    }
}

return current;

current[11] 에 더해주는 것은 곧 현재 node의 Sub Tree에서 계산된 Leaf node pair 개수를 누적하는 것이다

그 증거로는 바로 아래의 반복문을 들여다 보면 확인 될 것이다

그동안 d1, d2의 값을 0 ~ Distance(Index)로 주며 그 합이 Distance 보다 같거나 작은 경우에만 current[11]에 넣는데, 이는 Distance에 합당한 거리 내에 있는 pair의 개수를 current[11]에 넣어주는 것이다

 

이렇게 만들어진 vector를 반환하고 

public:
    int countPairs(TreeNode* root, int distance) {
        return postOrder(root, distance)[11];
    }
}

 

 반환된 vector의 11번 인덱스를 참조하여 최종 반환한다

 

  • 시간복잡도
한번의 DFS 탐색 [O(n)]
매 node 마다의 distance 수 만큼 pair의 개수를 구하기 위해 [O(d^2)] 만큼 반복문을 돌리기 때문에
최종적으로는 [O(n * d^2)]가 되겠다