题意
U2FsdGVkX18T278dBirMbWpY+y+yG0mooqC+jJM9ya7VIROpGnZxEEQgrWtQsjbA
DsQyQjhS3U0XUQih8b7ha79KxO8NW6YvusKeAg6XNR6eZVrk1pBAxeso8UWZFvXA
Fsqf4cTh4GXYqARxnS8SOrcskt8a8oI9py3W7AvFuLTm6zZlOi+nTnJF46QUzgit
wl3CFGCaa5H/v84J/UotFQaYmWyac8Dh594MZ3A9zXbqCQsWuN/jLNhczAUjkWm7
LWnxtUVMLa+hI1D6kmj9ac4Wxmu3PDn8qwZcn4oF8GShezmziu+GQUTRMsek2x7+
nDqtdTql95xEgdcNPEz/GVuQ10/SK/4WPgqnIGRGHyBSx17XklDUW8vcLezpjpuq
sx0EUHy76atPk3pufi/Ta8CFzRJzoC0/sYfEG7+ig4SauMm9rxgYpV85DWnBYf/x
解析
先考虑求解数量。令 表示以 为根节点的子树, 表示 处选了一个完整的四元组,或者没有选, 表示 处选了 个节点, 表示 处选了 个节点的链。
对于 ,容易得到 。对于 要求儿子中恰有一个选择是选 个节点,即:。而 则需要恰有一个选 且恰有一个选 ,或者全部不选,即: 。为了方便,可以用一个辅助的 dp 数组 表示当前是否选了 1 和是否选了 2 的方案数。
然后考虑求解和。一个简便的方法是,沿用上面的 dp 方法,将每个元素当成一个二元组 ,表示数量和和。考虑在二元组上定义 和 。由定义,容易得到:
然后根据上面的运算律计算即可。
实现
struct Solve
{
struct node_t
{
mint cnt, sum;
node_t() : cnt(0), sum(0) {}
node_t(mint cnt, mint sum) : cnt(cnt), sum(sum) {}
node_t operator+=(node_t a)
{
cnt += a.cnt;
sum += a.sum;
return *this;
};
node_t operator*=(node_t a)
{
sum = sum * a.cnt + a.sum * cnt;
cnt *= a.cnt;
return *this;
};
node_t operator*=(mint x)
{
cnt *= x;
sum *= x;
return *this;
}
node_t operator/=(mint x)
{
cnt /= x;
sum /= x;
return *this;
}
node_t operator+(node_t a) const { return node_t{*this} += a; }
node_t operator*(node_t a) const { return node_t{*this} *= a; }
node_t operator*(mint x) const { return node_t{*this} *= x; }
node_t operator/(mint x) const { return node_t{*this} /= x; }
};
int n;
std::vector<std::vector<int>> adj;
std::vector<int> fa;
std::vector<std::array<node_t, 3>> f;
Solve(int n) : n(n), adj(n), fa(n), f(n)
{
for (int i = 0; i < n - 1; i++) {
int u, v;
std::cin >> u >> v;
u--;
v--;
adj[u].emplace_back(v);
adj[v].emplace_back(u);
}
}
void dfs_init(int u)
{
for (auto v : adj[u]) {
if (v == fa[u]) continue;
fa[v] = u;
dfs_init(v);
}
}
void dfs_dp(int u)
{
f[u][0] = f[u][2] = node_t{0, 0};
f[u][1] = node_t{1, 0};
for (auto v : adj[u]) {
if (v == fa[u]) continue;
dfs_dp(v);
}
for (auto v : adj[u]) {
if (v == fa[u]) continue;
f[u][1] *= f[v][0];
}
std::array<std::array<node_t, 2>, 2> g;
g[0][0] = node_t{1, 0};
g[0][1] = g[1][0] = g[1][1] = node_t{0, 0};
for (auto v : adj[u]) {
if (v == fa[u]) continue;
std::array<std::array<node_t, 2>, 2> h;
h[0][0] += g[0][0] * f[v][0];
h[0][1] += g[0][0] * f[v][1] + g[0][1] * f[v][0];
h[1][0] += g[0][0] * f[v][2] + g[1][0] * f[v][0];
h[1][1] += g[0][1] * f[v][2] + g[1][0] * f[v][1] + g[1][1] * f[v][0];
g = std::move(h);
}
f[u][2] = g[0][1];
f[u][0] = f[u][1] + g[1][1] * node_t{1, 1};
}
void solve()
{
dfs_init(0);
dfs_dp(0);
}
};
int main()
{
set_io("tree");
int n, type;
std::cin >> n >> type;
Solve s(n);
s.solve();
auto res = s.f[0][0];
std::cout << res.cnt.val() << std::endl;
if (type) std::cout << res.sum.val() << std::endl;
}