题意
U2FsdGVkX1/FvStGG50vT3KlzSCj3Uj8LLAehvYJ+o1VCZSfyApuI+SzRwyNrrr8
X+gve1m4SmqCkZXrc7y1MIG2FBVdXX4Ru+LZElv2hyT/XBo8bpAQRw09HogV4cBF
BhkAhDgo4ezM0LiW49a0Qqw59fkEcCo+G2+ByX6S0XW3ovfStQW1DY5SZ7wWzFsO
GO7TUTGgTcx35kNw31WmDkyfzf7kqxGOpD3kdt4yKncdhu0Lub6DTG9uWWyAnQRC
gGGqDujvfzKCLMgJysdEYA==
解析
首先考虑解决只有一个点的情况,求与 $u$ 距离不超过 $d$ 的点的数量,这个可以用点分治离线解决。把距离 $u$ 不超过 $d$ 的点的数量和距离 $v$ 不超过 $d$ 的点的数量加起来,再减去同时距离 $u$ 和 $v$ 都不超过 $d$ 的点的数量,就是答案。而求这个距离 $u$ 和 $v$ 都不超过 $d$ 的点的数量可以转化为距离 $u$ 和 $v$ 的中点不超过某个值(即 $d - \frac{\operatorname{dist}(u,v)}{2}$)的点数。为了保证中点是个点,需要先把每个边中间插入一个点。
实现
/// Number of binary-lifting levels stored per vertex (level i holds the 2^i-th ancestor).
constexpr int MAX_FA = 20;

/**
 * Rooted tree (or forest) with binary-lifting ancestor queries and an offline
 * solver, based on centroid decomposition, for queries of the form
 * "how many counted vertices lie within distance d of vertex u".
 *
 * Only vertices with index < N are counted by the distance queries; vertices
 * with index >= N contribute to distances but never to the counts.
 */
struct graph_t
{
    int N;                                    ///< Vertices with index < N are the counted ones.
    int n;                                    ///< Number of allocated vertex slots.
    std::vector<std::vector<int>> adj;        ///< Undirected adjacency lists.
    std::vector<std::array<int, MAX_FA>> fa;  ///< fa[v][i]: 2^i-th ancestor of v (filled by dfs_info).
    std::vector<int> dep;                     ///< Depth of each vertex (filled by dfs_info).
    std::vector<int> size;                    ///< Scratch subtree sizes used by the centroid search.
    std::vector<bool> removed;                ///< True once a vertex has been used as a centroid.

    graph_t(int N, int n) : N(N), n(n), adj(n), fa(n), dep(n), size(n), removed(n), q(n) {}

    /// Adds the undirected edge u - v.
    void add_edge(int u, int v)
    {
        adj[u].emplace_back(v);
        adj[v].emplace_back(u);
    }

    /**
     * Roots the component of u at u and fills dep[] and fa[] for every vertex
     * below it. dep[u] itself is left as it is (0 after construction).
     *
     * The root is made its own ancestor at every level. That keeps lca() and
     * jump() clamped at the root and, unlike relying on the zero-initialised
     * table, works even when the root is adjacent to vertex 0.
     */
    void dfs_info(int u)
    {
        fa[u].fill(u);
        dfs_info_from(u);
    }

    /// Lowest common ancestor of u and v; both must have been reached by dfs_info.
    int lca(int u, int v)
    {
        if (dep[u] < dep[v]) std::swap(u, v);
        // Lift the deeper vertex until both are at the same depth.
        for (int i = MAX_FA - 1; i >= 0; i--) {
            if (dep[u] - dep[v] >= (1 << i)) {
                u = fa[u][i];
            }
        }
        if (u == v) return u;
        // Lift both while their ancestors differ; they end just below the LCA.
        for (int i = MAX_FA - 1; i >= 0; i--) {
            if (fa[u][i] != fa[v][i]) {
                u = fa[u][i];
                v = fa[v][i];
            }
        }
        return fa[u][0];
    }

    /// Returns the vertex reached by moving d steps from u towards the root.
    int jump(int u, int d)
    {
        for (int i = MAX_FA - 1; i >= 0; i--) {
            if (d >= (1 << i)) {
                u = fa[u][i];
                d -= (1 << i);
            }
        }
        return u;
    }

    /**
     * Returns {m, h} where h = (dist(u, v) + 1) / 2 (integer division) and m is
     * the vertex h steps away from the deeper endpoint along the u-v path.
     * When both endpoints are equally deep, m is their LCA.
     */
    std::pair<int, int> midpoint(int u, int v)
    {
        const int top = lca(u, v);
        const int up_u = dep[u] - dep[top];
        const int up_v = dep[v] - dep[top];
        const int half_dis = (up_u + up_v + 1) / 2;
        if (up_u == up_v) return {top, half_dis};
        return {jump(up_u > up_v ? u : v, half_dis), half_dis};
    }

    /// One pending distance query: radius d, answer stored in ans[id].
    struct query_t
    {
        int d, id;
        query_t(int d, int id) : d(d), id(id) {}
    };
    std::vector<std::vector<query_t>> q;  ///< Pending queries, grouped by their centre vertex.
    std::vector<int> ans;                 ///< ans[id] is the result of query id after solve_query().

    /**
     * Registers the query "count vertices x < N with dist(u, x) <= d".
     * @return the id to pass to operator[] once solve_query() has run.
     */
    int add_query(int u, int d)
    {
        int id = ans.size();
        q[u].emplace_back(d, id);
        ans.emplace_back(0);
        return id;
    }

    /**
     * Answers every registered query.
     *
     * Each connected component is decomposed separately using its real size,
     * so unused vertex slots and forests are handled and the centroid search
     * always succeeds. Vertices are marked removed as they are processed, so
     * a second call is a no-op.
     */
    void solve_query()
    {
        for (int root = 0; root < n; root++) {
            if (removed[root]) continue;
            const int tot = count_component(root, -1);
            calc(dfs_centroid(root, -1, tot));
        }
    }

    /**
     * Searches the not-yet-removed subtree of u (entered from `parent`, or -1
     * for none) for a vertex whose removal leaves no part larger than tot / 2.
     *
     * @param tot number of vertices in the current component.
     * @return the first such vertex in post-order, or -1 if this subtree has
     *         none. On success size[] holds, for every neighbour of the
     *         returned vertex, the size of that neighbour's part.
     */
    int dfs_centroid(int u, int parent, int tot)
    {
        int max_size = 0;
        size[u] = 1;
        for (int v : adj[u]) {
            if (v == parent || removed[v]) continue;
            const int res = dfs_centroid(v, u, tot);
            if (res != -1) return res;
            size[u] += size[v];
            max_size = std::max(max_size, size[v]);
        }
        max_size = std::max(max_size, tot - size[u]);
        if (max_size <= tot / 2) {
            // Everything outside u's subtree hangs off the parent's side.
            if (parent != -1) size[parent] = tot - size[u];
            return u;
        }
        return -1;
    }

    /**
     * Uses u as the centroid of its component: accounts for every
     * (query centre, counted vertex) pair whose path passes through u, then
     * recurses into the parts left after removing u.
     */
    void calc(int u)
    {
        removed[u] = true;

        struct branch_t
        {
            int root;                               // neighbour of u that leads into this part
            std::vector<std::pair<int, int>> node;  // (distance to u, vertex) for every vertex
            std::vector<int> path;                  // sorted distances to u of the counted vertices
        };
        std::vector<branch_t> branches;
        std::vector<int> all_path;  // sorted distances to u of all counted vertices, u included

        if (u < N) all_path.emplace_back(0);
        for (int v : adj[u]) {
            if (removed[v]) continue;
            branches.emplace_back();
            branch_t &b = branches.back();
            b.root = v;
            dfs_path(v, u, 1, b.node, b.path);
            std::sort(b.path.begin(), b.path.end());
            all_path.insert(all_path.end(), b.path.begin(), b.path.end());
        }
        std::sort(all_path.begin(), all_path.end());

        // Number of entries <= limit in a sorted distance list.
        const auto count_within = [](const std::vector<int> &sorted, int limit) {
            return static_cast<int>(std::upper_bound(sorted.begin(), sorted.end(), limit) - sorted.begin());
        };

        // Queries centred on u itself.
        for (const auto &[d, id] : q[u]) {
            ans[id] += count_within(all_path, d);
        }

        // Queries centred inside a branch: count everything reachable through
        // u, minus the vertices of the same branch (their paths do not pass
        // through u and are handled deeper in the recursion).
        for (const branch_t &b : branches) {
            for (const auto &[dis, x] : b.node) {
                for (const auto &[d, qid] : q[x]) {
                    const int rest = d - dis;
                    if (rest < 0) continue;
                    ans[qid] += count_within(all_path, rest) - count_within(b.path, rest);
                }
            }
        }

        for (const branch_t &b : branches) {
            if (removed[b.root]) continue;
            calc(dfs_centroid(b.root, u, size[b.root]));
        }
    }

    /**
     * Walks the not-yet-removed subtree of u (entered from `parent`), recording
     * (dist, vertex) for every vertex in `node` and dist for every counted
     * vertex (index < N) in `path`. `dist` is the distance of u from the
     * current centroid.
     */
    void dfs_path(int u, int parent, int dist, std::vector<std::pair<int, int>> &node, std::vector<int> &path)
    {
        node.emplace_back(dist, u);
        if (u < N) path.emplace_back(dist);
        for (int v : adj[u]) {
            if (removed[v] || v == parent) continue;
            dfs_path(v, u, dist + 1, node, path);
        }
    }

    /// Result of query x; the sentinel id -1 ("no query was registered") yields 0.
    int operator[](int x) const
    {
        if (x == -1) return 0;
        return ans[x];
    }

private:
    /// Fills dep[] and fa[] below u; fa[u] must already be set.
    void dfs_info_from(int u)
    {
        for (int v : adj[u]) {
            if (v == fa[u][0]) continue;
            dep[v] = dep[u] + 1;
            fa[v][0] = u;
            for (int i = 1; i < MAX_FA; i++) fa[v][i] = fa[fa[v][i - 1]][i - 1];
            dfs_info_from(v);
        }
    }

    /// Number of not-yet-removed vertices reachable from u without passing through `parent`.
    int count_component(int u, int parent) const
    {
        int total = 1;
        for (int v : adj[u]) {
            if (v == parent || removed[v]) continue;
            total += count_component(v, u);
        }
        return total;
    }
};
/// One input query: the two centres, the vertex used as their midpoint, and
/// the ids of the three sub-queries registered for them.
struct query_t
{
    int u;     // first centre
    int v;     // second centre
    int mid;   // midpoint vertex of the u-v path
    int ua;    // sub-query id for u
    int va;    // sub-query id for v
    int mida;  // sub-query id for mid

    query_t() = default;

    query_t(int u_, int v_, int mid_, int ua_, int va_, int mida_)
        : u{u_}, v{v_}, mid{mid_}, ua{ua_}, va{va_}, mida{mida_}
    {
    }
};
int main()
{
set_io("atree");
int n, Q;
std::cin >> n >> Q;
graph_t g(n, n * 2);
for (int i = 0; i < n - 1; i++) {
int u, v;
std::cin >> u >> v;
u--;
v--;
g.add_edge(u, n + i);
g.add_edge(v, n + i);
}
g.dfs_info(0);
std::vector<query_t> q(Q);
for (auto &[u, v, mid, ua, va, mida] : q) {
int d;
std::cin >> u >> v >> d;
u--;
v--;
d *= 2;
auto [uv, dis] = g.midpoint(u, v);
mid = uv;
ua = g.add_query(u, d);
va = g.add_query(v, d);
mida = -1;
if (d - dis >= 0) mida = g.add_query(mid, d - dis);
}
g.solve_query();
for (auto [_, __, ___, ua, va, mida] : q) {
int ans = g[ua] + g[va] - g[mida];
std::cout << ans << std::endl;
}
}