#!/usr/bin/python3import sysimport randomif len(sys.argv) != 4: print(f'Usage: {sys.argv[0]} n q seed') sys.exit(-1)MAXA = 1000MAXC = 10**9n = int(sys.argv[1])q = int(sys.argv[2])random.seed(sys.argv[3])print(n, q)a = (random.randint(1, MAXA) for i in range(n - 1))c = (random.randint(1, MAXC) for i in range(n))print(*a)print(*c)for i in range(q): x, y = (random.randint(1, n) for i in range(2)) print(x, y)
解析
我们认为求解乘法逆元是常数复杂度的。
定义 s(i)=0≤j<i∑ai,即 a 的前缀和
初始的 n 三方做法
定义 l(x,y) 表示 x 到 y 的期望距离。首先,当 x=y 时,l(x,y)=0,因为两点重合。由于 l(x,y) 显然等于 l(y,x),所以我们可以假设 x<y。因为我们是按照从 0 到 n−1 的顺序加的点,所以 y 不可能是 x 的祖先,也就是说 lca(x,y) 一定不等于 y。枚举 y 的所有父亲 i,则 l(x,i) 再加上 i 到 y 的边权就是 l(x,y)。
原来的 l(x,y),就状态定义就是 O(n2) 种,比较难优化,考虑拆分一下。我们知道树上 x 到 y 的距离等于 x 到根的距离加 y 到根的距离减去 2 倍
lca(x,y) 到根的距离。放到期望也是一样的,我们定义 d(x) 表示
x 到根的期望距离,f(x,y) 表示 lca(x,y) 到根的距离,则
l(x,y)=d(x)+d(y)−2f(x,y)。这样拆分降低了耦合度。
求解 d(x)
和求 l(x,y) 的方法类似,枚举 x 的父亲 i,然后把 d(i) 加上 i 到 x
的边权。容易得到:
d(x)=⎩⎨⎧0s(i)0≤i<x∑ai×[d(i)+ci]+cxx=0其他
可以维护 0≤i<x∑ai×[d(i)+ci] 做到 O(n) 求解。
求解 f(x, y)
仍然和求 l(x,y) 的方法类似,仍然是假设 x<y,枚举 y 的父亲 i,由于
y 不是 x 的祖先,所以如果 y 的父亲为 i,则 lca 到根距离的期望就是 f(x,i)。容易得到:
using mint = static_modint<1'000'000'007>;auto get_d(const auto &a, const auto &sa, const auto &c){ const int n = a.size(); std::vector<mint> d(n); mint sd = 0; d[0] = 0; sd += mint{c[0]} * a[0]; for (int i = 1; i < n; i++) { d[i] = sd / sa[i] + c[i]; sd += (d[i] + c[i]) * a[i]; } return d;}// g(i) = f(i, i + 1)auto get_g(const auto &a, const auto &sa, const auto &d){ const int n = a.size(); std::vector<mint> g(n); mint s = 0; for (int i = 0; i + 1 < n; i++) { g[i] = (d[i] * a[i] + s) / sa[i + 1]; s += g[i] * a[i]; } return g;};int main(){ int n, q; scanf("%d%d", &n, &q); std::vector<int> a(n), c(n); for (int i = 0; i < n - 1; i++) scanf("%d", &a[i]); for (auto &i : c) scanf("%d", &i); std::vector<mint> sa(n); for (int i = 0; i < n - 1; i++) sa[i + 1] = sa[i] + a[i]; const auto d = get_d(a, sa, c); const auto g = get_g(a, sa, d); for (int i = 0; i < q; i++) { int x, y; scanf("%d%d", &x, &y); x--; y--; if (x == y) { puts("0"); } else { int t = std::min(x, y); printf("%d\n", (d[x] + d[y] - g[t] * 2).val()); } }}