题意使用 openssl enc -aes256 -pbkdf2 加密。

U2FsdGVkX19EeCsRDVkgV517q9kufMjLb3BAGUUlIuVlCpx2alUOsTER1O0lkP76
N04/H6f+tFkm/arIBlq1bRLh5NDN/Qn1z4hEppfWhwE7Djhvo0Evypa8/Tw6iMKc
pc4k5ZHI99Cbl8EmLv3NhLA4r6rKTaKdvL0VCq+5EKzut+zzmhNcJ/o8uXN4JnyX
O/mL3XwK+njG4ON2sUPl5RclbnHGy/xpXmZkc/sQ70nGYnmpdDnCq6fe0kOtRvVJ

generator

import random
import sys

# Seed the PRNG with the last command-line argument (the script name itself
# when no argument is given), so the same argument reproduces the same output.
random.seed(sys.argv[-1])

def min_randrange(l, r, step):
    """Return the smallest of `step` draws of ``random.randrange(l, r)``.

    Taking the minimum of several draws biases the result toward `l`.
    """
    draws = [random.randrange(l, r) for _ in range(step)]
    return min(draws)
def max_randrange(l, r, step):
    """Return the largest of `step` draws of ``random.randrange(l, r)``.

    Taking the maximum of several draws biases the result toward `r - 1`.
    """
    draws = [random.randrange(l, r) for _ in range(step)]
    return max(draws)

# Number of vertices in the generated tree.
n = 3000
print(n)

# One colour per vertex, drawn 0-based and written 1-based. Every value is
# followed by a single space, so the line ends with a trailing space.
c = []
for _ in range(n):
    c.append(random.randrange(0, n))
print("".join("{} ".format(colour + 1) for colour in c))

# Vertex i (i >= 1) is attached to a uniformly chosen earlier vertex, so the
# edge list always forms a tree on vertices 0..n-1.
adj = []
for child in range(1, n):
    adj.append((random.randrange(0, child), child))
# Alternative: adj = [(i - 1, i) for i in range(1, n)] builds a single chain.
random.shuffle(adj)
print("\n".join("{} {}".format(u + 1, v + 1) for u, v in adj))
print()

可以用 `min_randrange` 和 `max_randrange` 来生成不均匀的随机数,以构造较长的链或较扁平的树。

解析

很容易想到 $O(n^4)$ 的 dp。先枚举颜色 $c$,把不是这个颜色的点赋上 $-1$,是这个颜色的点赋上 $+1$,这样题意就变为了求有多少连通块的点权和大于 $0$。定义 $f(i, j)$ 表示以 $i$ 为根节点的子树中,和为 $j$ 的连通块个数。用类似树上背包的方式转移(假设可以用负数下标):

// Tree-knapsack style transition: fold the table of every child v into the
// table of its parent u.
for (auto v : edges[u]) {
	// Skip the edge that leads back to the parent.
	if (v == fa) continue;
	// Recurse into the child before combining its table with f[u].
	self(self, v, u, f);
	// NOTE(review): i presumably ranges over the candidate colours — confirm
	// against the definition of MAXC.
	for (int i = 0; i < MAXC; i++) {
		auto g = convolution(f[u][i], f[v][i], SHIFT);
		// Sums can be negative, so j covers [-SHIFT, SHIFT]; this relies on
		// the containers accepting negative indices.
		for (int j = -SHIFT; j <= SHIFT; j++)
			f[u][i][j] += g[j];
	}
}

以下是几个可以优化的点:

复杂度玄学。可以参考:ouuan 树上背包的上下界优化

// Modular integer type used for the DP tables and the final answer; the
// modulus is 998244353.
using mint = static_modint<998244353>;

/// Fills res[u] with the number of vertices in the subtree rooted at u.
///
/// @param u   vertex being processed
/// @param fa  vertex the traversal arrived from; it is skipped in adj[u]
///            (the root is called with fa == u)
/// @param adj adjacency lists of the tree
/// @param res output array indexed by vertex; must already have one slot per vertex
void dfs_size(int u,
	      int fa,
	      const std::vector<std::vector<int>> &adj,
	      std::vector<int> &res)
{
	// Start from 1 to count u itself, then add every child subtree.
	int total = 1;
	for (const int child : adj[u]) {
		if (child != fa) {
			dfs_size(child, u, adj, res);
			total += res[child];
		}
	}
	res[u] = total;
}

/// Computes the subtree size of every vertex, with the tree rooted at vertex 0.
///
/// @param adj adjacency lists of the tree
/// @return vector whose entry v is the number of vertices in the subtree of v;
///         empty when adj is empty
std::vector<int> get_size(const std::vector<std::vector<int>> &adj)
{
	std::vector<int> res(adj.size());
	// An empty tree has no root: dfs_size(0, ...) would read adj[0] out of bounds.
	if (adj.empty()) return res;
	dfs_size(0, 0, adj, res);
	return res;
}

/// Accumulates into res[u] how many vertices w in the subtree of u have c[w] == 1.
///
/// @param u   vertex being processed
/// @param fa  vertex the traversal arrived from; it is skipped in adj[u]
///            (the root is called with fa == u)
/// @param c   per-vertex value; only equality with 1 is tested
/// @param adj adjacency lists of the tree
/// @param res output array indexed by vertex. An entry is only overwritten when
///            c[u] == 1, so it is expected to start at 0 (get_cnt passes a
///            zero-filled vector).
void dfs_cnt(int u,
	     int fa,
	     const std::vector<int> &c,
	     const std::vector<std::vector<int>> &adj,
	     std::vector<int> &res)
{
	if (c[u] == 1) {
		res[u] = 1;
	}
	// Gather the children's counts first, then add them to u in one step.
	int below = 0;
	for (const int child : adj[u]) {
		if (child != fa) {
			dfs_cnt(child, u, c, adj, res);
			below += res[child];
		}
	}
	res[u] += below;
}

/// Counts, for every vertex, how many vertices in its subtree have value 1,
/// with the tree rooted at vertex 0.
///
/// @param c   per-vertex value; vertices with c[v] == 1 are counted
/// @param adj adjacency lists of the tree
/// @return vector whose entry v is the number of counted vertices in the
///         subtree of v; empty when adj is empty
std::vector<int> get_cnt(const std::vector<int> &c,
			 const std::vector<std::vector<int>> &adj)
{
	std::vector<int> res(adj.size());
	// An empty tree has no root: dfs_cnt(0, ...) would read c[0] and adj[0]
	// out of bounds.
	if (adj.empty()) return res;
	dfs_cnt(0, 0, c, adj, res);
	return res;
}

/// Fills f[u] for the subtree of u: f[u][s] receives the contribution of the
/// vertex sets made of u plus a selection of child tables, indexed by the sum
/// s of their c values. Sums outside [f[u].l, f[u].r) are discarded.
///
/// @param u    vertex being processed
/// @param fa   vertex the traversal arrived from; it is skipped in adj[u]
/// @param size subtree sizes; not read here, only forwarded to the recursion
/// @param cnt  per-subtree counts; not read here, only forwarded to the recursion
/// @param c    per-vertex weight (main passes +1 / -1)
/// @param adj  adjacency lists of the tree
/// @param f    one table per vertex, already constructed with its index range
void dfs_solve(int u,
	       int fa,
	       const std::vector<int> &size,
	       const std::vector<int> &cnt,
	       const std::vector<int> &c,
	       const std::vector<std::vector<int>> &adj,
	       std::vector<range_vector<mint>> &f)
{
	auto &cur = f[u];
	const int w = c[u];

	// The set consisting of u alone has sum c[u]; record it if that sum is tracked.
	if (cur.l <= w && w < cur.r) cur[w] = 1;

	// Solve every child and remember the one whose table is widest.
	int heavy = -1;
	int heavy_width = 0;
	for (const int child : adj[u]) {
		if (child == fa) continue;
		dfs_solve(child, u, size, cnt, c, adj, f);
		if (f[child].size() > heavy_width) {
			heavy_width = f[child].size();
			heavy = child;
		}
	}
	if (heavy == -1) return;

	// Within this call f[u] has only received the single entry at index c[u],
	// so combining it with the widest child is just that child's table shifted
	// by c[u]: one linear pass instead of a quadratic product.
	auto &big = f[heavy];
	const int first = std::max(big.l, cur.l - w);
	const int last = std::min(big.r, cur.r - w);
	for (int i = first; i < last; i++) {
		cur[i + w] += big[i];
	}

	// Every remaining child is merged with a range-clipped product.
	for (const int child : adj[u]) {
		if (child == fa || child == heavy) continue;
		auto &sub = f[child];
		range_vector<mint> g(cur.l, cur.r);
		for (int i = cur.l; i < cur.r; i++) {
			// Keep only pairs whose combined index i + j stays inside f[u]'s range.
			const int first_j = std::max(sub.l, cur.l - i);
			const int last_j = std::min(sub.r, cur.r - i);
			for (int j = first_j; j < last_j; j++) {
				g[i + j] += cur[i] * sub[j];
			}
		}
		// Adding g (rather than replacing) keeps the option of taking nothing
		// from this child.
		for (int i = cur.l; i < cur.r; i++) {
			cur[i] += g[i];
		}
	}
}

/// Runs the DP for one weight assignment and returns the sum of f[v][j] over
/// all vertices v and all strictly positive indices j.
///
/// @param size subtree sizes with the tree rooted at vertex 0
/// @param c    per-vertex weight (main passes +1 for the colour under
///             consideration and -1 for every other vertex)
/// @param adj  adjacency lists of the tree
/// @return the accumulated count as a mint
mint solve(const std::vector<int> &size,
	   const std::vector<int> &c,
	   const std::vector<std::vector<int>> &adj)
{
	const int n = adj.size();
	const auto cnt = get_cnt(c, adj);

	// Table v covers the index range [lowest, cnt[v] + 1). The upper end is the
	// number of 1-valued vertices in the subtree; the lower end is additionally
	// clamped to -cnt[0].
	std::vector<range_vector<mint>> f;
	for (int v = 0; v < n; v++) {
		const int others = size[v] - cnt[v];
		const int lowest = std::max(std::min(-cnt[v], -others), -cnt[0]);
		f.emplace_back(lowest, cnt[v] + 1);
	}

	dfs_solve(0, 0, size, cnt, c, adj, f);

	// Only strictly positive sums contribute to the result.
	mint ans = 0;
	for (auto &table : f) {
		for (int j = 1; j < table.r; j++) {
			ans += table[j];
		}
	}
	return ans;
}

/// Reads the tree (vertex count, one colour per vertex, n - 1 edges with
/// 1-based endpoints), then for every distinct colour runs solve() on the
/// +1 / -1 weighting of that colour and prints the sum of the results.
///
/// Exits with status 1 and a message on stderr when the input is malformed.
int main()
{
	int n;
	if (scanf("%d", &n) != 1 || n < 0) {
		fprintf(stderr, "invalid vertex count\n");
		return 1;
	}
	if (n == 0) {
		// Empty tree: nothing to count, and the rooted helpers below expect
		// vertex 0 to exist.
		printf("%d\n", 0);
		return 0;
	}

	CC<int> colors;
	std::vector<int> c(n);
	std::vector<std::vector<int>> adj(n);
	for (auto &i : c) {
		if (scanf("%d", &i) != 1) {
			fprintf(stderr, "invalid colour\n");
			return 1;
		}
		colors.add(i);
	}
	// Replace every colour by its compressed index.
	for (auto &i : c) i = colors(i);
	for (int i = 0; i < n - 1; i++) {
		int u, v;
		// Endpoints index adj directly, so reject anything outside [1, n].
		if (scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) {
			fprintf(stderr, "invalid edge\n");
			return 1;
		}
		u--;
		v--;
		adj[u].emplace_back(v);
		adj[v].emplace_back(u);
	}

	const auto size = get_size(adj);

	mint ans = 0;
	// One weight buffer reused for every colour; each pass overwrites all entries.
	std::vector<int> d(n);
	const int color_count = static_cast<int>(colors.size());
	for (int i = 0; i < color_count; i++) {
		for (int j = 0; j < n; j++) {
			d[j] = (c[j] == i) ? 1 : -1;
		}

		ans += solve(size, d, adj);
	}

	printf("%d\n", ans.val());
}

省略了:static_modint(自动取模),CC(离散化),range_vector(支持负数下标的 vector)。