题意使用 openssl enc -aes256 -pbkdf2 加密。
U2FsdGVkX19EeCsRDVkgV517q9kufMjLb3BAGUUlIuVlCpx2alUOsTER1O0lkP76
N04/H6f+tFkm/arIBlq1bRLh5NDN/Qn1z4hEppfWhwE7Djhvo0Evypa8/Tw6iMKc
pc4k5ZHI99Cbl8EmLv3NhLA4r6rKTaKdvL0VCq+5EKzut+zzmhNcJ/o8uXN4JnyX
O/mL3XwK+njG4ON2sUPl5RclbnHGy/xpXmZkc/sQ70nGYnmpdDnCq6fe0kOtRvVJ
generator
import random
import sys
random.seed(sys.argv[-1])
def min_randrange(l, r, step):
return min((random.randrange(l, r) for _ in range(step)))
def max_randrange(l, r, step):
return max((random.randrange(l, r) for _ in range(step)))
n = 3000
print(n)
c = [random.randrange(0, n) for _ in range(n)]
for i in c:
print(i + 1, end = ' ')
print()
adj = [(random.randrange(0, i), i) for i in range(1, n)]
# adj = [(i - 1, i) for i in range(1, n)]
random.shuffle(adj)
for u, v in adj:
print(u + 1, v + 1)
print()
可以用 min_randrange 和 max_randrange 来生成有不均匀的随机数,以构造较长的链或叫扁平的树。
解析
很容易想到 dp。先枚举颜色 ,把不是这个颜色的点赋上 ,是这个颜色的点赋上 ,这样题意就变为了求有多少连通块的点的和大于 。定义 表示以 为根节点的子树,和为 的连通块个数。用类似树上背包的方式转移(假设可以用负数下标):
for (auto v : edges[u]) {
if (v == fa) continue;
self(self, v, u, f);
for (int i = 0; i < MAXC; i++) {
auto g = convolution(f[u][i], f[v][i], SHIFT);
for (int j = -SHIFT; j <= SHIFT; j++)
f[u][i][j] += g[j];
}
}
以下是几个可以优化的点:
- 上下界优化:假设 表示以 为根的子树出现相同颜色数量。 表示以 为根的子树大小。则对于 这个和的取值范围可以是 。
- 重儿子优化:注意到我们可以 合并第一个儿子,所以我们可以把重儿子当成第一个儿子。(这里定义重儿子是取值范围最大的)。
- 卷积优化:用到了卷积。可以考虑 NTT 优化。(然而我没写)
复杂度玄学。可以参考:ouuan 树上背包的上下界优化
using mint = static_modint<998244353>;
void dfs_size(int u,
int fa,
const std::vector<std::vector<int>> &adj,
std::vector<int> &res)
{
res[u] = 1;
for (auto v : adj[u]) {
if (v == fa) continue;
dfs_size(v, u, adj, res);
res[u] += res[v];
}
}
std::vector<int> get_size(const std::vector<std::vector<int>> &adj)
{
std::vector<int> res(adj.size());
dfs_size(0, 0, adj, res);
return res;
}
void dfs_cnt(int u,
int fa,
const std::vector<int> &c,
const std::vector<std::vector<int>> &adj,
std::vector<int> &res)
{
if (c[u] == 1) res[u] = 1;
for (auto v : adj[u]) {
if (v == fa) continue;
dfs_cnt(v, u, c, adj, res);
res[u] += res[v];
}
}
std::vector<int> get_cnt(const std::vector<int> &c,
const std::vector<std::vector<int>> &adj)
{
std::vector<int> res(adj.size());
dfs_cnt(0, 0, c, adj, res);
return res;
}
void dfs_solve(int u,
int fa,
const std::vector<int> &size,
const std::vector<int> &cnt,
const std::vector<int> &c,
const std::vector<std::vector<int>> &adj,
std::vector<range_vector<mint>> &f)
{
if (f[u].l <= c[u] && c[u] < f[u].r) f[u][c[u]] = 1;
int max_child = -1, max_child_size = 0;
for (auto v : adj[u]) {
if (v == fa) continue;
dfs_solve(v, u, size, cnt, c, adj, f);
if (f[v].size() > max_child_size) {
max_child_size = f[v].size();
max_child = v;
}
}
if (max_child != -1) {
for (int i = std::max(f[max_child].l, f[u].l - c[u]);
i < std::min(f[max_child].r, f[u].r - c[u]);
i++) {
f[u][i + c[u]] += f[max_child][i];
}
for (auto v : adj[u]) {
if (v == fa || v == max_child) continue;
range_vector<mint> g(f[u].l, f[u].r);
for (int i = f[u].l; i < f[u].r; i++) {
for (int j = std::max(f[v].l, f[u].l - i);
j < std::min(f[v].r, f[u].r - i);
j++) {
g[i + j] += f[u][i] * f[v][j];
}
}
for (int i = f[u].l; i < f[u].r; i++) {
f[u][i] += g[i];
}
}
}
}
mint solve(const std::vector<int> &size,
const std::vector<int> &c,
const std::vector<std::vector<int>> &adj)
{
const int n = adj.size();
const auto cnt = get_cnt(c, adj);
std::vector<range_vector<mint>> f;
for (int i = 0; i < n; i++) {
f.emplace_back(std::max(std::min(-cnt[i], -(size[i] - cnt[i])), -cnt[0]), cnt[i] + 1);
}
dfs_solve(0, 0, size, cnt, c, adj, f);
mint ans = 0;
for (int i = 0; i < n; i++) {
for (int j = 1; j < f[i].r; j++) {
ans += f[i][j];
}
}
return ans;
}
int main()
{
int n;
scanf("%d", &n);
CC<int> colors;
std::vector<int> c(n);
std::vector<std::vector<int>> adj(n);
for (auto &i : c) {
scanf("%d", &i);
colors.add(i);
}
for (auto &i : c) i = colors(i);
for (int i = 0; i < n - 1; i++) {
int u, v;
scanf("%d%d", &u, &v);
u--;
v--;
adj[u].emplace_back(v);
adj[v].emplace_back(u);
}
const auto size = get_size(adj);
mint ans = 0;
for (int i = 0; i < colors.size(); i++) {
std::vector<int> d(n);
for (int j = 0; j < n; j++) {
if (c[j] == i) d[j] = 1;
else d[j] = -1;
}
ans += solve(size, d, adj);
}
printf("%d\n", ans.val());
}
省略了:static_modint(自动取模),CC(离散化),range_vector(支持负数下标的 vector)。