树上分组背包问题

yuanheci 2023年03月28日 681次浏览

  第一次碰到这种树上的背包问题,记录一下,后续会慢慢补充。

题目链接 《P1273 有线电视网》

关键:
  把每个结点看成一个背包,体积就是该结点为根的子树中的用户总数量。并且以该结点为根的子树中的用户选择数量作为不同组别(每一种数量看成是组内的一种物品),进行分组背包求解。

完整的状态定义为三维:

  f[u][i][j]:f[u][i][j]: 表示以 uu 为根的子树,且只选其中的 ii 个用户,所选用户数量不超过 jj 的方案集合的最大价值。

状态转移方程:

  f[u][i][j]=max(f[u][i1][jk]+f[son][size(son)][k]cost[u][son])f[u][i][j] = max(f[u][i - 1][j - k] + f[son][size(son)][k] - cost[u][son])

  • son表示选择的某个子结点.
  • k表示分给son的体积。

  三维会 MLEMLE,因此可以根据转移方程优化掉第一维,由于树形DPDP采用 dfsdfs 形式,是自底向上的,所以在计算以 uu 为根结点时,以 sonson 为根的情况已经算完了,所以 f[son][k]f[son][k] 这一项不受影响。而前一项需要类比 0101 背包,对体积从大到小枚举。

  因此二维状态表示 f[u][j]f[u][j] 对应的转移方程如下:

  f[u][j]=max(f[u][jk]+f[son][k]cost[u][son])f[u][j] = max(f[u][j - k] + f[son][k] - cost[u][son])

答案: 统计一遍,f[1][j]0f[1][j] \ge 0 时最大的 jj 就是答案。

注意: ff 的初始化

AC代码:

  1. #include <bits/stdc++.h>
  2. using namespace std;

  3. const int N = 3010;
  4. int h[N], e[N], w[N], ne[N], idx;
  5. int val[N], f[N][N]; //f[i][j]: 以i为根结点,选择用户数量不超过j的方案的最大价值,答案就是f[1][j] > 0中的Max(j)
  6. int n, m;
  7. int ans;

  8. void add(int a, int b, int c){
  9. e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
  10. }

  11. int dfs(int u, int p){
  12. if(u > n - m){ //已经到达用户终端
  13. f[u][1] = val[u];
  14. return 1;
  15. }
  16. int s = 0;
  17. for(int i = h[u]; ~i; i = ne[i]){
  18. int son = e[i];
  19. if(son == p) continue;
  20. int t = dfs(son, u);
  21. s += t;
  22. for(int j = m; j >= 0; j--){
  23. //分组选取
  24. for(int k = 0; k <= t; k++){
  25. if(j - k >= 0) f[u][j] = max(f[u][j], f[u][j - k] + f[son][k] - w[i]);
  26. }
  27. }
  28. }
  29. return s;
  30. }

  31. void solve(){
  32. scanf("%d%d", &n, &m);
  33. memset(h, -1, sizeof h);
  34. memset(f, -0x3f, sizeof f);
  35. for(int i = 1; i <= n - m; i++){
  36. int k;
  37. scanf("%d", &k);
  38. for(int j = 0; j < k; j++){
  39. int x, c;
  40. scanf("%d%d", &x, &c);
  41. add(i, x, c);
  42. }
  43. }
  44. for(int i = n - m + 1; i <= n; i++) scanf("%d", &val[i]);
  45. for(int i = 1; i <= n; i++) f[i][0] = 0;
  46. dfs(1, -1);

  47. for(int j = m; j >= 0; j--){
  48. if(f[1][j] >= 0){
  49. ans = j;
  50. break;
  51. }
  52. }
  53. printf("%d\n", ans);
  54. }

  55. int main(){
  56. int T = 1;
  57. // scanf("%d", &T);
  58. while(T -- ){
  59. solve();
  60. }
  61. return 0;
  62. }