【树形dp】acwing758

传送门

题意是给一棵树,每个节点有颜色,只有白色和黑色,切成若干部分,使得每部分只有一个白色点。
求切割方案。

状态设计

f[i][0]代表i节点的儿子都符合条件,且 i 点所在的连通块白色点。
f[i][1]代表i节点的儿子都符合条件,且 i 点所在的连通块没有白色点

  1. 当 i 点是白色点的时候。f[i][1] = 0因为他不可能存在一个连通块内并且不包含白色。f[i][0]由其所有的儿子的方案累乘。具体来说,对于一个儿子j,可以选择切割这条边,+f[j][0],也可以不切这条边 +f[j][1]

f[i][1] = 0

f[i][0]=\prod{(f[j][0] + f[j][1])}

  1. 当 i 点是黑色点的时候。f[i][1]也是同理由其儿子选择切割与不切割转移来,f[i][0]则是由贡献白点的儿子转移,由于只能有一个儿子贡献白点,其他儿子贡献的是全黑点的连通块,所以各个儿子贡献之间是加法关系

f[i][1] = \prod{(f[j][0] + f[j][1])}

f[i][0] = \sum f[j][1] * \prod\limits_{i != j}{f[j][0]}

这里的由于需要 i != j 的乘法积,可以直接先全部乘起来,再分别除f[i][0]。但是由于有取模操作,并且模数很大且是一个质数,所以用逆元来代替除法。

不过xc大佬用的前缀后缀积,更简单但是开始的确是没想到。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 233;
const ll mod = 1000000007; 
int e[maxn * 2], h[maxn], pre[maxn * 2], c[maxn], cnt;
void add(int u, int v)
{
    e[cnt] = v; pre[cnt] = h[u]; h[u] = cnt++;
}
ll ksm(ll n, ll k)
{
    ll res = 1;
    while(k)
    {
        if(k & 1) res = (res * n) % mod;
        n = (n * n) % mod;
        k >>= 1;
    }
    return res;
}
ll f[maxn][2], vis[maxn];
void dfs(int x, int fa)
{
    vis[x] = 1;
    if(c[x] == 0)
    {
        f[x][0] = 1;
        f[x][1] = 0;
        for(int i = h[x]; ~i; i = pre[i])
        {
            int y = e[i];
            if(vis[y]) continue;
            dfs(y, x);
            f[x][0] *= f[y][1] + f[y][0];
            f[x][0] %= mod;
        }
    }
    else 
    {
        f[x][0] = 0;
        f[x][1] = 1;
        int val = 1;
        for(int i = h[x]; ~i; i = pre[i])
        {
            int y = e[i];
            if(vis[y]) continue;
            dfs(y, x);
            f[x][1] *= f[y][0] + f[y][1];
            f[x][1] %= mod;
        }
        for(int i = h[x]; ~i; i = pre[i])
        {
            int y = e[i];
            ll k = f[y][0] + f[y][1];
            k = ksm(k, mod - 2);
            if(y != fa) f[x][0] += (((f[y][0] * f[x][1]) % mod) * k) % mod, f[x][0] %= mod;
        }
    }

}
int main()
{
    int n; cin >> n;
    memset(h, -1, sizeof h);
    for(int i = 2; i <= n; i++)
    {
        int k; scanf("%d", &k);
        k++;
        add(k, i); add(i, k);
    }
    for(int i = 1; i <= n; i++)
    {
        scanf("%d", &c[i]);
    }
    dfs(1, 1);
    cout << f[1][0] % mod;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注