SDOI2018 原题识别(主席树)

题目链接

题目大意

给定 n n 个节点的树,其中包含一条非随机生成的长度为 k k 的链,剩下的节点均随机父节点连边。每个节点有一个随机的颜色,维护:
1.给定 x , y x,y ,求 x , y x,y 之间不同颜色数。
2.给定 x , y x,y ,对于所有满足分别在 x , y x,y 到根的路径上的点 a , b a,b ,求其询问1的答案之和。
n 1 0 5 , m 2 × 1 0 5 n\le 10^5,m\le 2\times 10^5

题解

码量比较大qwq……
我们先从链上的情况入手考虑。

链的情况

对于第一问,这是经典二维数点题。考虑 p i p_i 表示 i i 之前第一个和它颜色相同的位置。我们以 ( i , p i ) (i,p_i) 为坐标建点,询问不同颜色数就相当于询问 x x 坐标位于 [ l , r ] [l,r] y y 坐标小于 l l 的点个数。直接主席树维护即可。
对于第二问,我们考虑点 i i 对答案的贡献。不妨设 x y x\le y ,我们分三种情况讨论:
1. x < i y x<i\le y ,此时贡献应该是 [ p i x ] ( x p i ) ( y i + 1 ) [p_i\le x](x-p_i)(y-i+1)
2. 1 i x 1\le i\le x a b a\le b ,此时贡献应该是 ( i p i ) ( y i + 1 ) (i-p_i)(y-i+1)
3. 1 i x 1\le i\le x a b a\ge b ,此时贡献应该是 ( i p i ) ( x i + 1 ) (i-p_i)(x-i+1)
如果直接把三种答案加起来的话会发现2,3两种情况中 a = b a=b 的部分算重了,减1即可。于是我们就需要维护上面的东西(2,3两个情况其实可以合起来):
第一种是
i = x + 1 , p i x y ( x p i ) ( y i + 1 ) = i = x + 1 , p i x y x ( y + 1 ) p i ( y + 1 ) x i + p i i \sum_{i=x+1,p_i\le x}^y (x-p_i)(y-i+1)\\ =\sum_{i=x+1,p_i\le x}^y x(y+1)-p_i(y+1)-xi+p_ii
这个东西可以通过主席树维护四个值来计算:个数, i i 的和, p i p_i 的和, p i i p_ii 的和。
我们再来看第二种。
i = 1 x ( i p i ) ( x + y + 2 2 i ) = i = 1 x ( x + y + 2 ) ( i p i ) 2 i ( i p i ) \sum_{i=1}^x(i-p_i)(x+y+2-2i)\\ =\sum_{i=1}^x(x+y+2)(i-p_i)-2i(i-p_i)
这个东西没有了对 p i p_i 的限制条件,因此直接前缀和维护即可。(当然如果你非要主席树的话我也不能拦着qwq)
到此为止,链的情况被我们在 O ( n l o g n ) O(nlogn) 的时间内做完了。

推广到树

注意到树除了那条链其它都是随机的,因此每个点到链距离的期望是 O ( l o g n ) O(logn) 的。每个颜色也是随机的,因此每个颜色出现次数的期望是 O ( 1 ) O(1) 的。
也就是说,对于两个点 x , y x,y 的LCA,记为 l l ,必有一个点到其距离为 O ( l o g n ) O(logn) 。不妨就设这个点为 x x ,考虑第一问怎么做。
我们先计算出 [ l , y ] [l,y] 中不同的颜色数(注意下面的区间都指的是一条链),这个可以直接主席树。接下来做的事就是暴力枚举 [ x , l ) [x,l) 中的每个颜色,看看它是否在 [ l , y ] [l,y] 中出现了,直接统计。判断方法就是暴力枚举所有颜色和它相同的点即可。
因此第一问的复杂度也是 O ( n l o g n ) O(nlogn) 的。
考虑第二问,我们可以划分成如下三个子问题:
1. a [ 1 , l ) , b [ 1 , y ] a\in [1,l),b\in [1,y] 。这实际上就是链的情况,主席树统计即可。
2. a [ l , x ] , b [ 1 , l ) a\in [l,x],b\in [1,l) 。这其实也是一条链,我们可以稍微转化一下,先求出 a [ 1 , x ] , b [ 1 , l ) a\in [1,x],b\in [1,l) 的答案,然后减去多算的。
多算的东西是 2 ( i p i ) ( l i ) 1 \sum 2(i-p_i)(l-i)-1 ,直接前缀和就能维护。
3. a [ l , x ] , b [ l , y ] a\in [l,x],b\in [l,y] 。这个情况很难算,我们也考虑分开计算贡献。考虑存在于 [ l , y ] [l,y] 中的点 i i 的贡献为 [ p i < l ] ( y i + 1 ) ( x l + 1 ) [p_i<l](y-i+1)(x-l+1) ,主席树维护即可。
再考虑存在于 [ l , x ] [l,x] 中点 i i 的贡献,首先它必须是所有与它颜色相同的点中第一个在 [ l , x ] [l,x] 中出现的,它不能在 [ l , y ] [l,y] 中包含和它颜色相同的点。不妨令 j j [ l , y ] [l,y] 中第一个和它颜色相同的点(如果不存在则为 y + 1 y+1 ),那么其贡献为 [ p i < l ] ( x i + 1 ) ( j l ) [p_i<l](x-i+1)(j-l)
暴力枚举点是 O ( l o g n ) O(logn) 的,找第一次出现时 O ( 1 ) O(1) 的,因此总复杂度还是 O ( n l o g n ) O(nlogn) 的,只是常数比较大。

#include <bits/stdc++.h>
namespace IOStream {
	const int MAXR = 1 << 23;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
		read(a), read(x...);
	}
	inline int reads(char *s) {
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() {
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	inline void printc(char c) {
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(char *s) {
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	template<typename T> inline void print(T x, char c = '\n') {
		if (x < 0) printc('-'), x = -x;
		if (x) {
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
#define cls(a) memset(a, 0, sizeof(a))

const int MAXN = 200005, MAXT = 2000005;
struct Edge { int to, next; } edge[MAXN];
int dfn[MAXN], st[20][MAXN], head[MAXN], lg[MAXN], tot, n, m, K, T;
void addedge(int u, int v) {
	edge[++tot] = (Edge) { v, head[u] };
	head[u] = tot;
}
int lst[MAXN], col[MAXN], dep[MAXN], rt[MAXN], app[MAXN], ed[MAXN];
struct Value {
	ll sum1, sum2, sum3, sum4;
	Value() { sum1 = sum2 = sum3 = sum4 = 0; }
	Value& operator+=(const Value &v) {
		sum1 += v.sum1, sum2 += v.sum2, sum3 += v.sum3, sum4 += v.sum4;
		return *this;
	}
	Value& operator-=(const Value &v) {
		sum1 -= v.sum1, sum2 -= v.sum2, sum3 -= v.sum3, sum4 -= v.sum4;
		return *this;
	}
} nd[MAXT];
//sum1=1,sum2=p[i],sum3=i,sum4=p[i]*i
ll pre2[MAXN], pre3[MAXN]; int ptot;
//pre1=1,pre2=i-p[i],pre3=i(i-p[i])
int ls[MAXT], rs[MAXT], par[MAXN], vis[MAXN];

//presistence segment tree
int update(int p, int x, int y, int l = 0, int r = n) {
	int q = ++ptot; nd[q] = nd[p];
	++nd[q].sum1, nd[q].sum2 += y, nd[q].sum3 += x, nd[q].sum4 += (ll)x * y;
	if (l == r) return q;
	int mid = (l + r) >> 1;
	if (y <= mid) ls[q] = update(ls[p], x, y, l, mid), rs[q] = rs[p];
	else rs[q] = update(rs[p], x, y, mid + 1, r), ls[q] = ls[p];
	return q;
}
void query(Value &v, int p, int q, int a, int b, int l = 0, int r = n) {//x in (p,q],y in [a,b]
	if (a > r || b < l || p == q) return;
	if (a <= l && b >= r) { v += nd[q], v -= nd[p]; return; }
	int mid = (l + r) >> 1;
	query(v, ls[p], ls[q], a, b, l, mid);
	query(v, rs[p], rs[q], a, b, mid + 1, r);
}

vector<int> pla[MAXN];
void dfs(int u, int fa) {
	st[0][dfn[u] = ++tot] = u, dep[u] = dep[fa] + 1;
	pla[col[u]].push_back(u), lst[u] = app[col[u]];
	int t = app[col[u]]; app[col[u]] = u;
	pre2[u] = pre2[fa] + dep[u] - dep[lst[u]];
	pre3[u] = pre3[fa] + (ll)(dep[u] - dep[lst[u]]) * dep[u];
	rt[u] = update(rt[fa], dep[u], dep[lst[u]]);
	for (int i = head[u]; i; i = edge[i].next) {
		dfs(edge[i].to, u);
		st[0][++tot] = u;
	}
	app[col[u]] = t, ed[u] = tot;
}
int get_min(int x, int y) { return dep[x] < dep[y] ? x : y; }
int get_lca(int x, int y) {
	x = dfn[x], y = dfn[y];
	if (x > y) swap(x, y);
	int l = lg[y - x + 1];
	return get_min(st[l][x], st[l][y - (1 << l) + 1]);
}
int on_link(int x, int y, int p) {//x is ancestor of y
	return dfn[x] <= dfn[p] && ed[x] >= dfn[p] &&
		dfn[p] <= dfn[y] && ed[p] >= dfn[y];
}

int solve1(int x, int y) {
	++tot;
	int la = get_lca(x, K), lb = get_lca(y, K);
	if (la > lb) swap(la, lb), swap(x, y);
	int l = get_lca(x, y);
	Value v; query(v, rt[par[l]], rt[y], 0, dep[l] - 1);
	int res = v.sum1;
	for (int i = x; i != l; i = par[i]) if (vis[col[i]] != tot) {
		vis[col[i]] = tot;
		int flag = 1;
		for (int j : pla[col[i]])
			if (on_link(l, y, j)) { flag = 0; break; }
		res += flag;
	}
	return res;
}
ll calc_link(int x, int y, const Value &v) {//x is ancestor of y
	int a = dep[x], b = dep[y];
	ll res = (v.sum1 * a - v.sum2) * (b + 1) - v.sum3 * a + v.sum4;
	return res + (a + b + 2) * pre2[x] - 2 * pre3[x] - a;
}
ll solve2(int x, int y) {
	++tot;
	int la = get_lca(x, K), lb = get_lca(y, K);
	if (la > lb) swap(la, lb), swap(x, y);
	int l = get_lca(x, y), pl = par[l], dl = dep[l];
	Value v1, v2;
	query(v1, rt[pl], rt[y], 0, dl - 1);
	query(v2, rt[pl], rt[x], 0, dl - 1);
	ll res = calc_link(pl, y, v1) + calc_link(pl, x, v2) -
		2 * (dl * pre2[pl] - pre3[pl]) + dl - 1;
	res += ((dep[y] + 1) * v1.sum1 - v1.sum3) * (dep[x] - dl + 1);
	int tp = 0;
	for (int i = x; i != l; i = par[i]) app[++tp] = i;
	app[++tp] = l;
	while (tp > 0) {
		int i = app[tp--];
		if (vis[col[i]] != tot) {
			vis[col[i]] = tot;
			int mn = dep[y] + 1;
			for (int j : pla[col[i]])
				if (on_link(l, y, j) && mn > dep[j]) mn = dep[j];
			res += (ll)(mn - dl) * (dep[x] - dep[i] + 1);
		}
	}
	return res;
}

unsigned int SA, SB, SC;
unsigned int rng61(){
    SA ^= SA << 16;
    SA ^= SA >> 5;
    SA ^= SA << 1;
    unsigned int t = SA;
    SA = SB;
    SB = SC;
    SC ^= t ^ SA;
    return SC;
}
void gen(){
    read(n, K, SA, SB, SC);
    for(int i = 2; i <= K; i++) addedge(par[i] = i - 1, i);
    for(int i = K + 1; i <= n; i++)
        addedge(par[i] = rng61() % (i - 1) + 1, i);
    for(int i = 1; i <= n; i++) col[i] = rng61() % n + 1;
}
int main() {
	for (read(T); T--;) {
		tot = 0, cls(head), cls(vis), cls(app);
		gen();
		for (int i = 1; i <= n; i++) pla[i].clear();
		dfs(1, ptot = tot = 0);
		for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
		for (int i = 1; i < 20; i++)
		for (int j = 1; j + (1 << i) - 1 <= tot; j++)
			st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
		tot = 0;
		for (read(m); m--;) {
			int a, b, c; read(a, b, c);
			if (a == 1) print(solve1(b, c));
			else print(solve2(b, c));
		}
	}
	ioflush();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/WAautomaton/article/details/87288466