트리 DP 문제다. 처음 본다면 로직을 떠올리기 조금 힘들 수도 있다. 트리 DP 문제를 풀어본 경험이 있다면 쉽게 접근할 수 있는 문제.
각각의 마을이 일반 마을일 수도 있고, 우수 마을일 수도 있는데 문제에서는 조건을 통해 이를 제약하고 있다.
1. 우수 마을로 선정된 마을 주민 수의 총합을 최대로 해야 한다.
2. 우수 마을끼리는 인접할 수 없다.
3. 일반 마을은 최소한 하나의 우수 마을과 인접해야 한다.
따라서 dp 값을 계산할 때 어떤 마을이 일반 마을인 경우와 우수 마을인 경우로 나눈 다음에 트리 탐색 과정에서 이를 조건에 맞게 처리해 주어야 한다.
나처럼 dp[마을 번호][우수 마을 선정 여부(0: 선정x, 1: 선정)] = 해당 마을 번호를 루트로 하는 서브 트리의 최대 우수 마을 주민 수로 설정한다면, n*2 크기의 배열이나 리스트를 사용하게 될 것이다.(실제 코드에선 간편한 인덱싱을 위해 (n+1)*2 크기로 사용했다)
이제 각각의 dp 값을 어떻게 채워 넣을 수 있을지 고민해봐야 하는데, 이렇게 생각해 볼 수 있다.
1. 먼저 어떤 마을 n을 탐색하는 중이라고 할 때, dp[n][0]은 0으로 초기화하고 dp[n][1]은 마을 n의 주민 수로 초기화한다. 우수 마을로 선정됐다면, 해당 마을의 주민 수가 구하고자 하는 결과 값에 더해져야 하기 때문이다. 그 이후 자식 노드들을 고려하기 시작한다.
2. dp[n][1], 즉 어떤 마을이 우수 마을이라면 그 자식 노드에 해당하는 마을들은 일반 마을이어야 한다. 따라서 dp[n][1]은 n의 모든 자식 m에 대해 dp[m][0]를 더한 값이 된다.
3. dp[n][0]이 묘하게 느껴질 수 있다. 일단 모든 자식 m에 대해 dp[m][1]은 당연히 고려할 수 있고, dp[m][0] 또한 고려할 수 있다. 손자 노드가 우수 마을일 수도 있고, 다른 자식 노드가 우수 마을인 경우가 채택될 수 있으니까. 그래서 dp[n][0]은 모든 자식 m에 대해 max(dp[m][0], dp[m][1])이라고 할 수 있는데, 이때 하나의 의문이 든다.
만약 모든 자식 m에 대해 dp[m][0]이 채택되어 버린다면? 이는 문제 조건 3번에 위배된다고 생각할 수 있다. 하지만 모든 자식 노드가 일반 마을인 경우를 채택하게 된다면 dp[n][1] 값이 무조건 dp[n][0]보다 크다. 즉, 해당 노드가 우수 마을인 경우가 부모 노드에 의해 채택된다는 소리이며, 따라서 그런 예외사항을 고려할 필요가 없다.
이제 위를 구현하기 위해서, DFS를 활용해 리프 노드부터 루트 노드까지 dp 값을 차례로 구해오며, 최종적으로 탐색을 시작한 임의의 루트 노드 r의 dp값인 dp[r][0], dp[r][1] 중 더 큰 값을 출력하면 끝이다. 오랜만에 푼 트리 DP 문제라서 재밌게 푼 것 같다.
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
int n;
vector<vector<int> > dp; //0: 일반 마을, 1: 우수 마을
vector<int> residents;
vector<vector<int> > adjList;
vector<bool> isVisited;
void dfs(int num) {
if(isVisited[num]) return;
isVisited[num] = true;
dp[num][0] = 0;
dp[num][1] = residents[num];
for(auto next: adjList[num]) {
if(isVisited[next]) continue;
dfs(next);
dp[num][0] += max(dp[next][0], dp[next][1]);
dp[num][1] += dp[next][0];
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n;
dp = vector<vector<int> >(n+1, vector<int>(2, 0));
residents = vector<int>(n+1);
isVisited = vector<bool>(n+1, false);
adjList = vector<vector<int> >(n+1);
for(int i=1; i<=n; i++) {
cin >> residents[i];
}
int a, b;
for(int i=0; i<n-1; i++) {
cin >> a >> b;
adjList[a].push_back(b);
adjList[b].push_back(a);
}
dfs(1);
cout << max(dp[1][0], dp[1][1]);
}
'Dev > BOJ' 카테고리의 다른 글
[백준 4195] 친구 네트워크 (1) | 2025.01.09 |
---|---|
[백준 1507] 궁금한 민호 (1) | 2024.12.10 |
[백준 1365] 꼬인 전깃줄 (0) | 2024.12.09 |
[백준 1922] 네트워크 연결 (0) | 2024.12.07 |
[백준 14428] 수열과 쿼리 16 (1) | 2024.12.06 |