学习了一下换根DP,主要是根据一下两位大佬的博客学习的~
换根DP其实是树形DP的一种延伸技巧或者说是方法。
它的使用范围是,对树上的每个点跑树形DP。这样的话,不用换根DP一点一点跑的复杂度就是 ,必炸。那么换根DP应运而生。简单来讲,就是我们会通过推理发现,我们先以一个选定节点跑出来的最优解,通过另一个转移方程,就可以得出与他有关系的其他节点的答案。也就是说,我们相当于进行了两次DP,第一次的树形DP可以算作一种预处理,第二次的DP就是换根DP。其根本奥义就是用 的复杂度完成了 的问题。
题目练习:
传送门: P3478 [POI2008] STA-Station
AC代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e6 + 10;
LL f[N], cnt[N];
int h[N], e[N * 2], ne[N * 2], idx;
int n, ans;
LL mx;
void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs1(int u, int p){
cnt[u] = 1;
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == p) continue;
dfs1(j, u);
cnt[u] += cnt[j];
f[u] += f[j] + cnt[j];
}
}
void dfs2(int u, int p){
if(f[u] > mx){
mx = f[u];
ans = u;
}
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == p) continue;
int t1 = f[j] + cnt[j], t2 = cnt[j];
//开始换根
f[u] -= t1, cnt[u] -= t2;
f[j] += f[u] + cnt[u], cnt[j] += cnt[u];
dfs2(j, u);
f[u] += t1, cnt[u] += t2; //恢复
}
}
void solve(){
scanf("%d", &n);
memset(h, -1, sizeof h);
for(int i = 0; i < n - 1; i++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
dfs1(1, 0);
dfs2(1, 0);
printf("%d\n", ans);
}
int main(){
int T = 1;
// scanf("%d", &T);
while(T -- ){
solve();
}
return 0;
}
传送门: POJ 3585 Accumulation Degree
AC代码:
//#include <bits/stdc++.h>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200010;
int h[N], e[N * 2], w[N * 2], ne[N * 2], idx;
int du[N], f[N];
int n, ans;
void add(int a, int b, int c){
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void dfs1(int u, int p){
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == p) continue;
dfs1(j, u);
if(du[j] == 1) f[u] += w[i]; //leaf,因为leaf没有子结点,f[leaf_id] = 0,不能用于取min,否则结果就是0了;
else f[u] += min(w[i], f[j]);
}
}
void dfs2(int u, int p){
ans = max(ans, f[u]);
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == p) continue;
//换根,leaf需要特殊判断
//如果不特判,也用min(),那么由于f[leaf_id] = 0, 所以答案就被错误地更新成了0
if(du[u] == 1) {
f[j] += w[i]; //u是leaf,所以把根换成j时,w[i]必定能加上
dfs2(j, u);
}
else{
int t = min(w[i], f[j]);
f[u] -= t;
f[j] += min(w[i], f[u]);
dfs2(j, u);
f[u] += t;
}
}
}
void solve(){
scanf("%d", &n);
for(int i = 1; i <= n; i++){
h[i] = -1, du[i] = 0, f[i] = 0;
}
idx = 0, ans = 0;
for(int i = 0; i < n - 1; i++){
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
du[a]++, du[b]++;
}
dfs1(1, 0);
dfs2(1, 0);
printf("%d\n", ans);
}
int main(){
int T = 1;
scanf("%d", &T);
while(T -- ){
solve();
}
return 0;
}