본문 바로가기

Dev/BOJ

[백준 1761] 정점들의 거리

 

 LCA(Lowest Common Ancestor, 최소 공통 조상) 문제다. 모든 노드에 대해서 조상 노드들까지의 거리를 전처리 해두고, LCA를 찾아 그 거리를 합하면 될 것이라고 쉽게 생각할 수 있다. 대신에 Biased Tree일 경우, 트리의 깊이가 N이 될 수 있고 그렇게 되면 LCA를 찾아내는 알고리즘이 O(트리 높이)인 경우, 전체 시간복잡도가 O(NM)이 되어 시간초과가 발생할 것이다. 풀고 나서 찾아보니까 테스트 케이스가 잘못되었는지 O(N)인 LCA 알고리즘으로도 풀린다고 하지만 문제의 의도와 맞지 않는 풀이라고 생각한다.

 

 LCA를 O(log N)만에 찾을 수 있는 알고리즘은 Sparse Table을 사용하는 것이다. n번째 노드라는 것은 이진수를 떠올리면 알 수 있듯이, 2^k 조합으로 찾아낼 수 있기 때문이다. 유명한 알고리즘인지라 간단하게 요약하면 각 노드들의 부모 노드만 저장해두는게 아니라, 모든 2^k번째 조상노드들을 함께 저장해 두고 탐색할 때 활용해 주면 된다.

 만약 13번째 조상노드에 접근한다면 8번째 조상의 4번째 조상의 1번째 조상노드로 접근할 수 있다(13을 이진법으로 표기하면 1101(2)인 것을 떠올리면 이해하기 쉽다).

 

 이 부분만 이해하고 나면, 나머지 부분은 O(N)인 LCA 알고리즘과 크게 다르지 않다. 쿼리로 주어지는 두 노드 중 깊이가 큰 노드를 작은 노드 쪽에 맞춰주고, 공통 조상을 찾으면 된다. 이 때, 깊이를 맞추는 과정과 공통 조상을 찾는 과정 모두 Sparse Table을 활용하면 O(log N)만에 가능하다. 

 필요한 값이 LCA 그 자체가 아니라, 결국 이를 이용하는 정점 간의 거리 합이므로, 나는 Saprse Table를 구축할 때 각 조상노드까지의 거리를 미리 저장해뒀다. 그래서 LCA를 찾아나가는 과정에서 거리를 계속 더해주고 그 합을 리턴하는 식으로 답을 찾아냈는데, 그냥 주어지는 두 노드에서 루트 노드까지의 거리를 각각 더하고 LCA에서 루트 노드까지의 거리*2를 빼주는 식으로 구해도 상관없다.

 

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int n;
vector<vector<pair<int, int> > > adjList; //{node, cost}
pair<int, long long> ancestors[40001][17]; //[node][k]{num, costSum}
int depthList[40001];

void dfs(int n, int depth) {

    depthList[n] = depth;
    
    if(n!=1) {
        int ind = 1, indDepth = 2;
        int anc = ancestors[n][0].first;
        long long costSum = ancestors[n][0].second;

        while(indDepth<=depth) {
            costSum += ancestors[anc][ind-1].second;
            ancestors[n][ind] = {ancestors[anc][ind-1].first, costSum};
            anc = ancestors[anc][ind-1].first;
            ind++;
            indDepth*=2;
        }
    }
    
    for(auto adj: adjList[n]) {
        if(adj.first != ancestors[n][0].first) { //부모노드 아니면
            ancestors[adj.first][0] = {n, adj.second};
            dfs(adj.first, depth+1);
        }
    }
}

long long lcaSum(int a, int b) {
    if(depthList[a]<depthList[b]) {
        swap(a, b);
    }
    int k = 0;
    while(1<<(k+1) <= depthList[a]) k++;

    long long costSum = 0;
    for(int i=k; i>=0; i--) {
        if(depthList[a] - (1<<i) >= depthList[b]) {
            costSum += ancestors[a][i].second;
            a = ancestors[a][i].first;
        }
    }

    if(a==b) return costSum;

    for(int i=k; i>=0; i--) {
        if(ancestors[a][i].first!=0 && ancestors[a][i].first!=ancestors[b][i].first) {
            costSum += ancestors[a][i].second;
            costSum += ancestors[b][i].second;

            a = ancestors[a][i].first;
            b = ancestors[b][i].first;
        }
    }

    costSum += ancestors[a][0].second;
    costSum += ancestors[b][0].second;

    return costSum;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    adjList = vector<vector<pair<int, int> > >(n+1);

    int a, b, cost;
    for(int i=0; i<n-1; i++) {
        cin >> a >> b >> cost;
        adjList[a].push_back({b, cost});
        adjList[b].push_back({a, cost});
    }

    dfs(1, 0);

    int m;
    cin >> m;

    for(int i=0; i<m; i++) {
        cin >> a >> b;
        cout << lcaSum(a, b) << "\n";
    }
}

'Dev > BOJ' 카테고리의 다른 글

[백준 13334] 철로  (0) 2024.10.31
[백준 11438] LCA 2  (0) 2024.10.30
[백준 7569] 토마토  (1) 2024.10.25
[백준 2150] Strongly Connected Component  (0) 2024.10.24
[백준 11280] 2-SAT - 3  (0) 2024.10.23