题意

U2FsdGVkX18U1KtTi/DDTe12ohjQ7HCysDqhYK/SjCsWkMxn7BDk+2a5h7EMYyOQ
iHif1yN2BSOC/W1NNmCNucSDrcakL+CWjWiK9qWdaGAyWU6u/p6huelX/Y+h2CI/
dVFJfYMj3JXwZb5vhMSBvNsRwg1c/HNhrcE14M0KW6XMW44YlHaRuZRmOP8rvrxt
LKIDjHMMj4P+Q2H0zB3pIx66SAvygUih6CyvRHGVaJYKhPDCukNd9P3t6X8LaIkt

解析

考虑点分治。假设当前分治中心为 i,解决所有包含 i 的连通块的答案。

定义 f(i, j) 表示考虑 dfn 序前 i 个点,乘积为 j 的连通块数量。考虑转移,如果将 dfn 序为 i 的点加入连通块,则 f(i,j)f(i+1,j×ai)f(i, j) \to f(i + 1, j \times a_{i}),如果不加入连通块,则 f(i,j)f(i+size(i),j)f(i, j) \to f(i + size(i), j)(i 的儿子都无法再加入联通块,直接跳过 i 所在的子树)。这样时间复杂度是 O(nmlogn)O(nm\log n) 的。

考虑优化。对于两个乘积 j 和 j’,如果满足 mj=mj\left\lfloor\frac{m}{j} \right\rfloor = \left\lfloor\frac{m}{j'}\right\rfloor,即他们最大能乘的数相同,那么我们可以把 j 和 j’ 合并。所有可以合并的数,乘上一个相同数之后,仍然可以合并,因为有 a,b,cZ,abc=abc\forall a,b,c\in\mathbb{Z},\left\lfloor\frac{a}{bc} \right\rfloor = \left\lfloor\frac{\left\lfloor\frac{a}{b}\right\rfloor}{c} \right\rfloor。我们记下 g(i, j) 表示考虑 dfn 序前 i 个点,乘积为 k,而 j=mkj = \left\lfloor\frac{m} {k}\right\rfloor 的方案数。时间复杂度 O(nmlogn)O(n\sqrt m \log n)

实现

struct tree
{
	int n, m;
	std::vector<int> a;
	std::vector<std::vector<int>> adj;
	std::vector<int> id, map;
	std::vector<int> size;
	std::vector<bool> removed;

	tree(const std::vector<int> &a,
	     const std::vector<std::pair<int, int>> &edges,
	     int m) 
		: n(a.size()), m(m), a(a), adj(n), id(m + 1), size(n), removed(n)
	{
		for (int i = 0; i < (int)edges.size(); i++) {
			int u = edges[i].first, v = edges[i].second;
			adj[u].emplace_back(v);
			adj[v].emplace_back(u);
		}

		id[1] = 0;
		map.emplace_back(1);
		for (int i = 2; i <= m; i++) {
			if (m / i == m / (i - 1)) {
				id[i] = id[i - 1];
			} else {
				id[i] = map.size();
				map.emplace_back(i);
			}
		}
	}

	int find_centroid(int u, int fa, int tot)
	{
		size[u] = 1;
		int max_size = 0;
		for (auto v : adj[u]) {
			if (v == fa || removed[v]) continue;
			int res = find_centroid(v, u, tot);
			if (res != -1) return res;
			size[u] += size[v];
			max_size = std::max(max_size, size[v]);
		}
		max_size = std::max(max_size, tot - size[u]);
		if (max_size * 2 <= tot) {
			if (fa != -1) size[fa] = tot - size[u];
			return u;
		} else {
			return -1;
		}
	}

	int dfs_subtree(int u, int fa, std::vector<std::pair<int, int>> &child)
	{
		int id = child.size();
		child.emplace_back(u, 1);
		for (auto v : adj[u]) {
			if (v == fa || removed[v]) continue;
			int vid = dfs_subtree(v, u, child);
			child[id].second += child[vid].second;
		}
		return id;
	}

	mint calc(int u)
	{
		removed[u] = true;
		std::vector<std::pair<int, int>> subtree;
		dfs_subtree(u, u, subtree);
		int n = subtree[0].second;
		std::vector<std::vector<mint>> f(n + 1, std::vector<mint>(id[m] + 1));
		f[1][id[a[u]]] = 1;
		for (int i = 1; i < n; i++) {
			auto [v, s] = subtree[i];
			for (int j = 0; j <= id[m]; j++) {
				f[i + s][j] += f[i][j];
				if ((long long)map[j] * a[v] <= m) {
					f[i + 1][id[map[j] * a[v]]] += f[i][j];
				}
			}
		}

		mint res = 0;
		for (int j = 0; j <= id[m]; j++) {
			res += f.back()[j];
		}

		for (auto v : adj[u]) {
			if (removed[v]) continue;
			int centroid = find_centroid(v, -1, size[v]);
			res += calc(centroid);
		}
		
		return res;
	}

	mint solve()
	{
		int centroid = find_centroid(0, -1, n);
		return calc(centroid);
	}
};