洛谷传送门
BZOJ传送门
题目描述
风见幽香非常喜欢玩一个叫做 osu!的游戏,其中她最喜欢玩的模式就是接水果。由于她已经DT FC 了The big black, 她觉得这个游戏太简单了,于是发明了一个更加难的版本。
首先有一个地图,是一棵由 个顶点、 条边组成的树(例如图 1给出的树包含 个顶点、 条边)。
这颗树上有 个盘子,每个盘子实际上是一条路径(例如图 1 中顶点 6 到顶点 8 的路径),并且每个盘子还有一个权值。第 个盘子就是顶点 到顶点 的路径(由于是树,所以从 到 的路径是唯一的),权值为 。
接下来依次会有 个水果掉下来,每个水果本质上也是一条路径,第 个水果是从顶点 到顶点 的路径。
幽香每次需要选择一个盘子去接当前的水果:一个盘子能接住一个水果,当且仅当盘子的路径是水果的路径的子路径(例如图 中从 到 的路径是从 到 的路径的子路径)。这里规定:从 到 的路径与从 到 的路径是同一条路径。
当然为了提高难度,对于第 个水果,你需要选择能接住它的所有盘子中,权值第 小的那个盘子,每个盘子可重复使用(没有使用次数的上限:一个盘子接完一个水果后,后面还可继续接其他水果,只要它是水果路径的子路径)。幽香认为这个游戏很难,你能轻松解决给她看吗?
输入输出格式
输入格式:
第一行三个数 和 和 ,表示树的大小和盘子的个数和水果的个数。 接下来 行,每行两个数 、 ,表示树上的 和 之间有一条边。树中顶点按 到 标号。 接下来 行,每行三个数 、 、 ,表示路径为 到 、权值为 的盘子,其中 , 不等于 。 接下来 行,每行三个数 、 、 ,表示路径为 到 的水果,其中 不等于 ,你需要选择第 小的盘子,第 小一定存在。
输出格式:
对于每个果子,输出一行表示选择的盘子的权值。
输入输出样例
输入样例#1:
10 10 10
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
3 2 217394434
10 7 13022269
6 7 283254485
6 8 333042360
4 6 442139372
8 3 225045590
10 4 922205209
10 8 808296330
9 2 486331361
4 9 551176338
1 8 5
3 8 3
3 8 4
1 8 3
4 8 1
2 3 1
2 3 1
2 3 1
2 4 1
1 4 1
输出样例#1:
442139372
333042360
442139372
283254485
283254485
217394434
217394434
217394434
217394434
217394434
说明
。
解题分析
又是一道类似区间 大的问题, 还可以离线, 考虑整体二分。
那么关键的问题在如何判断一个水果是否能被盘子接住, 这样我们就可以在二分的时候计算出有几个盘子接到了同一个水果, 进而得出答案。
那么对于一个盘子, 我们分情况考虑:
-
两个端点为父子关系:(例如图中1、 3)
那么显然水果的一端应该在 的子树内, 一端在 的子树外, 那么就可以通过DFS的入栈出栈序分为两段: 。
- 两个端点不为父子关系: (例如图中2、 4)
那么显然水果的一端应该在 的子树内, 另一端在 的子树内, 也可以分别对应到一段区间上。
这样我们把每个盘子对应的水果的 序两端的可行区间可以转化为一个或两个矩形, 然后把水果视为一个点, 每次查询覆盖点的矩形有多少个。 这个可以用扫描线很方便地完成。
为了统一统计方式, 可以把 序较小的那段区间设为 轴上的一段, 序较大的一段设为 坐标区间。
代码如下:
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#define R register
#define IN inline
#define W while
#define lbt(i) ((i) & (-(i)))
#define MX 40500
IN char gc()
{
static const int buflen = 1e6;
static char buf[buflen], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, buflen, stdin), p1 == p2) ? EOF : *p1++;
}
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc();
for (; !isdigit(c); c = gc());
for (; isdigit(c); c = gc())
x = (x << 1) + (x << 3) + c - 48;
}
int n, p, f, cnt, ct, dif, tot;
template <class T> IN T max(T a, T b) {return a > b ? a : b;}
template <class T> IN T min(T a, T b) {return a < b ? a : b;}
int lb[MX], rb[MX], head[MX], fat[MX][17], dep[MX], buf[MX], id[MX], ans[MX], tree[MX << 2];
struct Edge {int to, nex;} edge[MX << 1];
struct Plate {int a, b, val;} plt[MX];
struct Fruit {int a, b, kth;} frt[MX];
struct INFO {int typ, x, y[2], id, val;} line[MX << 3], buf1[MX << 3], buf2[MX << 3]; // 0 for modify, 1 for query
IN bool operator < (const INFO &x, const INFO &y) {return x.x == y.x ? x.typ < y.typ : x.x < y.x;}
IN void add(R int from, R int to) {edge[++cnt] = {to, head[from]}, head[from] = cnt;}
void DFS(R int now)
{
lb[now] = ++ct;
for (R int i = 1; i <= 16; ++i)
{
fat[now][i] = fat[fat[now][i - 1]][i - 1];
if (!fat[now][i]) break;
}
for (R int i = head[now]; i; i = edge[i].nex)
{
if (edge[i].to == fat[now][0]) continue;
fat[edge[i].to][0] = now;
dep[edge[i].to] = dep[now] + 1;
DFS(edge[i].to);
}
rb[now] = ct;
}
IN int LCA(R int x, R int y)
{
if (dep[x] < dep[y]) std::swap(x, y);
R int del = dep[x] - dep[y];
for (R int i = 0; i <= 16; ++i)
if (del & (1 << i)) x = fat[x][i];
if (x == y) return y;
for (R int i = 16; ~i; --i)
if (fat[x][i] ^ fat[y][i]) x = fat[x][i], y = fat[y][i];
return fat[x][0];
}
IN int Get(R int deep, R int shallow)
{
R int del = dep[deep] - dep[shallow] - 1;
for (R int i = 0; i <= 16; ++i)
if (del & (1 << i)) deep = fat[deep][i];
return deep;
}
namespace BIT
{
IN void add(R int pos, R int del) {for (; pos <= n; pos += lbt(pos)) tree[pos] += del;}
IN int query(R int pos)
{
int ret = 0;
for (; pos; pos -= lbt(pos)) ret += tree[pos];
return ret;
}
IN void modify(R int lef, R int rig, R int val) {add(lef, val), add(rig + 1, -val);}
}
void solve(R int lef, R int rig, R int lb, R int rb)
{
int cnt1 = 0, cnt2 = 0, res;
if (lef > rig || lb > rb) return;
if (lb == rb)
{
for (R int i = lef; i <= rig; ++i)
if (line[i].typ) ans[line[i].id] = buf[lb];
return;
}
int mid = lb + rb >> 1;
for (R int i = lef; i <= rig; ++i)
{
if (!line[i].typ)//Modify
{
if (line[i].val <= mid)
{
BIT::modify(line[i].y[0], line[i].y[1], line[i].id);
buf1[++cnt1] = line[i];
}
else buf2[++cnt2] = line[i];
}
else
{
res = BIT::query(line[i].y[0]);
if (res >= line[i].val) buf1[++cnt1] = line[i];
else buf2[++cnt2] = line[i], buf2[cnt2].val -= res;
}
}
for (R int i = lef; i <= rig; ++i)
{
if (!line[i].typ)
{
if (line[i].val <= mid)
BIT::modify(line[i].y[0], line[i].y[1], -line[i].id);
}
}
for (R int i = 1; i <= cnt1; ++i) line[lef + i - 1] = buf1[i];
for (R int i = 1; i <= cnt2; ++i) line[lef + cnt1 + i - 1] = buf2[i];
solve(lef, lef + cnt1 - 1, lb, mid), solve(lef + cnt1, rig, mid + 1, rb);
}
int main(void)
{
int a, b;
in(n), in(p), in(f);
for (R int i = 1; i < n; ++i) in(a), in(b), add(a, b), add(b, a);
for (R int i = 1; i <= p; ++i) in(plt[i].a), in(plt[i].b), in(plt[i].val), buf[i] = plt[i].val;
for (R int i = 1; i <= f; ++i) in(frt[i].a), in(frt[i].b), in(frt[i].kth);
std::sort(buf + 1, buf + 1 + p); DFS(1); ++n;
dif = std::unique(buf + 1, buf + 1 + p) - buf - 1;
for (R int i = 1; i <= p; ++i)
{
id[i] = std::lower_bound(buf + 1, buf + 1 + dif, plt[i].val) - buf;
if (dep[plt[i].a] > dep[plt[i].b]) std::swap(plt[i].a, plt[i].b);
//shallow : a; deep : b
if (LCA(plt[i].a, plt[i].b) == plt[i].a)
{
a = Get(plt[i].b, plt[i].a);
line[++tot] = {0, 1, {lb[plt[i].b], rb[plt[i].b]}, 1, id[i]};
line[++tot] = {0, lb[a], {lb[plt[i].b], rb[plt[i].b]}, -1, id[i]};
line[++tot] = {0, lb[plt[i].b], {rb[a] + 1, n}, 1, id[i]};
line[++tot] = {0, rb[plt[i].b] + 1, {rb[a] + 1, n}, -1, id[i]};//加1的原因是我们把修改排在的查询的前面, 而在边界上的点也要算进去, 这样我们的右边界会先被删掉。
}
else
{
if (lb[plt[i].a] > lb[plt[i].b]) std::swap(plt[i].a, plt[i].b);
line[++tot] = {0, lb[plt[i].a], {lb[plt[i].b], rb[plt[i].b]}, 1, id[i]};
line[++tot] = {0, rb[plt[i].a] + 1, {lb[plt[i].b], rb[plt[i].b]}, -1, id[i]};
}
}
for (R int i = 1; i <= f; ++i)
{
if (lb[frt[i].a] > lb[frt[i].b]) std::swap(frt[i].a, frt[i].b);
line[++tot] = {1, lb[frt[i].a], {lb[frt[i].b], 0}, i, frt[i].kth};
}
std::sort(line + 1, line + 1 + tot);
solve(1, tot, 1, dif);
for (R int i = 1; i <= f; ++i) printf("%d\n", ans[i]);
}