题意
U2FsdGVkX18U1KtTi/DDTe12ohjQ7HCysDqhYK/SjCsWkMxn7BDk+2a5h7EMYyOQ
iHif1yN2BSOC/W1NNmCNucSDrcakL+CWjWiK9qWdaGAyWU6u/p6huelX/Y+h2CI/
dVFJfYMj3JXwZb5vhMSBvNsRwg1c/HNhrcE14M0KW6XMW44YlHaRuZRmOP8rvrxt
LKIDjHMMj4P+Q2H0zB3pIx66SAvygUih6CyvRHGVaJYKhPDCukNd9P3t6X8LaIkt
解析
考虑点分治。假设当前分治中心为 i,解决所有包含 i 的连通块的答案。
定义 f(i, j) 表示考虑 dfn 序前 i 个点,乘积为 j 的连通块数量。考虑转移,如果将 dfn 序为 i 的点加入连通块,则 ,如果不加入连通块,则 (i 的儿子都无法再加入联通块,直接跳过 i 所在的子树)。这样时间复杂度是 的。
考虑优化。对于两个乘积 j 和 j’,如果满足 ,即他们最大能乘的数相同,那么我们可以把 j 和 j’ 合并。所有可以合并的数,乘上一个相同数之后,仍然可以合并,因为有 。我们记下 g(i, j) 表示考虑 dfn 序前 i 个点,乘积为 k,而 的方案数。时间复杂度 。
实现
struct tree
{
int n, m;
std::vector<int> a;
std::vector<std::vector<int>> adj;
std::vector<int> id, map;
std::vector<int> size;
std::vector<bool> removed;
tree(const std::vector<int> &a,
const std::vector<std::pair<int, int>> &edges,
int m)
: n(a.size()), m(m), a(a), adj(n), id(m + 1), size(n), removed(n)
{
for (int i = 0; i < (int)edges.size(); i++) {
int u = edges[i].first, v = edges[i].second;
adj[u].emplace_back(v);
adj[v].emplace_back(u);
}
id[1] = 0;
map.emplace_back(1);
for (int i = 2; i <= m; i++) {
if (m / i == m / (i - 1)) {
id[i] = id[i - 1];
} else {
id[i] = map.size();
map.emplace_back(i);
}
}
}
int find_centroid(int u, int fa, int tot)
{
size[u] = 1;
int max_size = 0;
for (auto v : adj[u]) {
if (v == fa || removed[v]) continue;
int res = find_centroid(v, u, tot);
if (res != -1) return res;
size[u] += size[v];
max_size = std::max(max_size, size[v]);
}
max_size = std::max(max_size, tot - size[u]);
if (max_size * 2 <= tot) {
if (fa != -1) size[fa] = tot - size[u];
return u;
} else {
return -1;
}
}
int dfs_subtree(int u, int fa, std::vector<std::pair<int, int>> &child)
{
int id = child.size();
child.emplace_back(u, 1);
for (auto v : adj[u]) {
if (v == fa || removed[v]) continue;
int vid = dfs_subtree(v, u, child);
child[id].second += child[vid].second;
}
return id;
}
mint calc(int u)
{
removed[u] = true;
std::vector<std::pair<int, int>> subtree;
dfs_subtree(u, u, subtree);
int n = subtree[0].second;
std::vector<std::vector<mint>> f(n + 1, std::vector<mint>(id[m] + 1));
f[1][id[a[u]]] = 1;
for (int i = 1; i < n; i++) {
auto [v, s] = subtree[i];
for (int j = 0; j <= id[m]; j++) {
f[i + s][j] += f[i][j];
if ((long long)map[j] * a[v] <= m) {
f[i + 1][id[map[j] * a[v]]] += f[i][j];
}
}
}
mint res = 0;
for (int j = 0; j <= id[m]; j++) {
res += f.back()[j];
}
for (auto v : adj[u]) {
if (removed[v]) continue;
int centroid = find_centroid(v, -1, size[v]);
res += calc(centroid);
}
return res;
}
mint solve()
{
int centroid = find_centroid(0, -1, n);
return calc(centroid);
}
};