题意

U2FsdGVkX18T278dBirMbWpY+y+yG0mooqC+jJM9ya7VIROpGnZxEEQgrWtQsjbA
DsQyQjhS3U0XUQih8b7ha79KxO8NW6YvusKeAg6XNR6eZVrk1pBAxeso8UWZFvXA
Fsqf4cTh4GXYqARxnS8SOrcskt8a8oI9py3W7AvFuLTm6zZlOi+nTnJF46QUzgit
wl3CFGCaa5H/v84J/UotFQaYmWyac8Dh594MZ3A9zXbqCQsWuN/jLNhczAUjkWm7
LWnxtUVMLa+hI1D6kmj9ac4Wxmu3PDn8qwZcn4oF8GShezmziu+GQUTRMsek2x7+
nDqtdTql95xEgdcNPEz/GVuQ10/SK/4WPgqnIGRGHyBSx17XklDUW8vcLezpjpuq
sx0EUHy76atPk3pufi/Ta8CFzRJzoC0/sYfEG7+ig4SauMm9rxgYpV85DWnBYf/x

解析

先考虑求解数量。令 f(i,0/1/2)f(i, 0/1/2) 表示以 ii 为根节点的子树,00 表示 ii 处选了一个完整的四元组,或者没有选,11 表示 ii 处选了 11 个节点,22 表示 ii 处选了 22 个节点的链。

对于 f(i,1)f(i, 1),容易得到 f(i,1)=vf(v,0)f(i, 1) = \prod_{v} f(v, 0)。对于 f(i,2)f(i, 2) 要求儿子中恰有一个选择是选 11 个节点,即:f(i,2)=vf(v,1)vvf(v,0)f(i, 2) = \sum_{v} f(v, 1) \prod_{v' \ne v} f(v, 0)。而 f(i,0)f(i, 0) 则需要恰有一个选 11 且恰有一个选 22,或者全部不选,即: f(i,0)=vf(v,0)+vvf(v,1)f(v,2)vvvvf(v,0)f(i, 0) = \prod_{v} f(v, 0) + \sum_{v \ne v'} f(v, 1)f(v', 2) \prod_{v'' \ne v \land v'' \ne v'} f(v'', 0)。为了方便,可以用一个辅助的 dp 数组 g(0/1,0/1)g(0/1, 0/1) 表示当前是否选了 1 和是否选了 2 的方案数。

然后考虑求解和。一个简便的方法是,沿用上面的 dp 方法,将每个元素当成一个二元组 (c,s)(c, s),表示数量和和。考虑在二元组上定义 ++×\times。由定义,容易得到:

(c1,s1)+(c2,s2)=(c1+c2,s1+s2)(c1,s1)×(c2,s2)=(c1c2,s1c2+s2c1)\begin{aligned} (c_{1}, s_{1}) + (c_{2}, s_{2}) &= (c_{1} + c_{2}, s_{1} + s_{2}) \\ (c_{1}, s_{1}) \times (c_{2}, s_{2}) &= (c_{1}c_{2}, s_{1}c_{2} + s_{2}c_{1}) \\ \end{aligned}

然后根据上面的运算律计算即可。

实现

struct Solve
{
	struct node_t
	{
		mint cnt, sum;
		node_t() : cnt(0), sum(0) {}
		node_t(mint cnt, mint sum) : cnt(cnt), sum(sum) {}
		node_t operator+=(node_t a) 
		{
			cnt += a.cnt;
			sum += a.sum;
			return *this;
		};
		node_t operator*=(node_t a) 
		{
			sum = sum * a.cnt + a.sum * cnt;
			cnt *= a.cnt;
			return *this;
		};
		node_t operator*=(mint x)
		{
			cnt *= x;
			sum *= x;
			return *this;
		}
		node_t operator/=(mint x)
		{
			cnt /= x;
			sum /= x;
			return *this;
		}
		node_t operator+(node_t a) const { return node_t{*this} += a; }
		node_t operator*(node_t a) const { return node_t{*this} *= a; }
		node_t operator*(mint x) const { return node_t{*this} *= x; }
		node_t operator/(mint x) const { return node_t{*this} /= x; }
	};

	int n;
	std::vector<std::vector<int>> adj;
	std::vector<int> fa;
	std::vector<std::array<node_t, 3>> f;

	Solve(int n) : n(n), adj(n), fa(n), f(n)
	{
		for (int i = 0; i < n - 1; i++) {
			int u, v;
			std::cin >> u >> v;
			u--;
			v--;
			adj[u].emplace_back(v);
			adj[v].emplace_back(u);
		}
	}

	void dfs_init(int u)
	{
		for (auto v : adj[u]) {
			if (v == fa[u]) continue;
			fa[v] = u;
			dfs_init(v);
		}
	}

	void dfs_dp(int u)
	{
		f[u][0] = f[u][2] = node_t{0, 0};
		f[u][1] = node_t{1, 0};
		for (auto v : adj[u]) {
			if (v == fa[u]) continue;
			dfs_dp(v);
		}
		for (auto v : adj[u]) {
			if (v == fa[u]) continue;
			f[u][1] *= f[v][0];
		}
		std::array<std::array<node_t, 2>, 2> g;
		g[0][0] = node_t{1, 0};
		g[0][1] = g[1][0] = g[1][1] = node_t{0, 0};
		for (auto v : adj[u]) {
			if (v == fa[u]) continue;
			std::array<std::array<node_t, 2>, 2> h;
			h[0][0] += g[0][0] * f[v][0];
			h[0][1] += g[0][0] * f[v][1] + g[0][1] * f[v][0];
			h[1][0] += g[0][0] * f[v][2] + g[1][0] * f[v][0];
			h[1][1] += g[0][1] * f[v][2] + g[1][0] * f[v][1] + g[1][1] * f[v][0];
			g = std::move(h);
		}
		f[u][2] = g[0][1];
		f[u][0] = f[u][1] + g[1][1] * node_t{1, 1};
	}

	void solve()
	{
		dfs_init(0);
		dfs_dp(0);
	}
};

int main()
{
	set_io("tree");

	int n, type;
	std::cin >> n >> type;
	Solve s(n);
	s.solve();
	auto res = s.f[0][0];
	std::cout << res.cnt.val() << std::endl;
	if (type) std::cout << res.sum.val() << std::endl;
}