The 2019 ICPC Asia Shanghai Regional Contest F.A Simple Problem On A Tree (线段树&树链剖分)

题目描述

We have met so many problems on the tree, so today we will have a simple problem on a tree.
You are given a tree (an acyclic undirected connected graph) with $\mathbf{N}$ nodes. The tree nodes are numbered from $1$ to $\mathbf{N}$. Each node has a weight $\mathbf{W_i}$. We will have four kinds of operations on it, and you should handle them efficiently. Have fun!

输入描述:

The first line of the input gives the number of test cases, $\mathbf{T}$ ($1 \leq \mathbf{T} \leq 10$). $\mathbf{T}$ test cases follow.
For each case, the first line contains only one integer $\mathbf{N}$ ($1 \leq \mathbf{N} \leq 100{,}000$). Each of the next $\mathbf{N}-1$ lines contains two integers $x$, $y$, which means there is an edge between them. It also means we will give you one tree initially.
The next line contains $\mathbf{N}$ integers: the initial weight $\mathbf{W_i}$ of each node ($0 \leq \mathbf{W_i} \leq 1{,}000{,}000{,}000$).

The next line contains an integer $\mathbf{Q}$ ($1 \leq \mathbf{Q} \leq 10{,}000$). Each of the next $\mathbf{Q}$ lines starts with an integer 1, 2, 3 or 4 that gives the kind of the operation.

1. Given three integers $u$, $v$, $w$: for $u$, $v$ and all nodes on the path from $u$ to $v$ (inclusive), you should update their weight to $w$. ($1 \leq u, v \leq \mathbf{N}$, $0 \leq w \leq 1{,}000{,}000{,}000$)
2. Given three integers $u$, $v$, $w$: for $u$, $v$ and all nodes on the path from $u$ to $v$ (inclusive), you should increase their weight by $w$. ($1 \leq u, v \leq \mathbf{N}$, $0 \leq w \leq 1{,}000{,}000{,}000$)
3. Given three integers $u$, $v$, $w$: for $u$, $v$ and all nodes on the path from $u$ to $v$ (inclusive), you should multiply their weight by $w$. ($1 \leq u, v \leq \mathbf{N}$, $0 \leq w \leq 1{,}000{,}000{,}000$)
4. Given two integers $u$, $v$, you should check the node weights on the path between $u$ and $v$ and output their cubic sum, i.e. output $\sum_{x} \mathbf{W}_x^3$, where $x$ ranges over the nodes on the path from $u$ to $v$ (including $u$ and $v$). ($1 \leq u, v \leq \mathbf{N}$)

输出描述:

For each test case, output one line containing "Case #x:", where x is the test case number (starting from 1).
For operation 4, output a single integer on one line representing the result. The result could be huge, so print it modulo 1,000,000,007 ($10^9+7$).

示例1

输入

1
5
2 1
1 3
5 3
4 3
1 2 3 4 5
6
4 2 4
1 5 4 2
2 2 4 3
3 2 3 4
4 5 4
4 2 4

输出

Case #1:
100
8133
20221

题目大意 :

一棵 N(≤1e5)个节点的带点权树,有四种操作:将 u 到 v 路径上的点权覆盖为 w、增加 w、乘以 w,以及询问 u 到 v 路径上点权的立方和(模 1e9+7)。

Tip:

这道题发出来就为了好玩QAQ 不打算做详细解释了,

因为它是HDU4578套上一个树剖,思路比较简单,就是代码长了点(长的过分)

Accepted code

#include<bits/stdc++.h>
#include<unordered_map>
using namespace std;

// Competitive-programming shorthands. Expression-like macros are fully
// parenthesized so they compose safely inside larger expressions
// (e.g. `ls + 1` or `rs + 1`).
#define sc scanf
#define ls (rt << 1)            // left child of segment-tree node rt (expects a variable named rt in scope)
#define rs (rt << 1 | 1)        // right child of segment-tree node rt
#define Min(x, y) ((x) = min((x), (y)))
#define Max(x, y) ((x) = max((x), (y)))
#define ALL(x) (x).begin(),(x).end()
#define SZ(x) ((int)(x).size())
#define pir pair <int, int>
#define MK(x, y) make_pair(x, y)
#define MEM(x, b) memset(x, b, sizeof(x))
#define MPY(x, b) memcpy(x, b, sizeof(x))
#define lowbit(x) ((x) & -(x))
#define P2(x) ((x) * (x))

using ll = long long;
const int Mod = 1e9 + 7;
const int N = 1e5 + 100;
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;

// Modular exponentiation: returns a^b mod Mod (in [0, Mod)) by binary
// exponentiation. b is expected to be >= 0; a non-positive b yields 1.
// The base is reduced into [0, Mod) first, so any long long a (negative or
// >= Mod) works without overflowing the 64-bit products.
inline ll dpow(ll a, ll b){
	ll base = a % Mod;
	if (base < 0)
		base += Mod;
	ll r = 1;
	for (; b > 0; b >>= 1) {
		if (b & 1)
			r = r * base % Mod;
		base = base * base % Mod;
	}
	return r;
}
// Plain (non-modular) power a^b for b >= 0; the caller must ensure that a^b
// itself fits in a long long.
inline ll fpow(ll a, ll b){
	ll r = 1, t = a;
	while (b > 0) {
		if (b & 1)
			r *= t;
		b >>= 1;
		// Skip the squaring after the last bit: it is unused and could overflow
		// even when a^b fits.
		if (b)
			t *= t;
	}
	return r;
}

// 树剖
// Heavy-light decomposition state. fz/son/sz/idx/dep/top/a are indexed by
// node id 1..n; ev is indexed by dfn position.
vector <int> G[N];        // adjacency lists; only node ids 1..n (< N) are used
int fz[N], son[N], sz[N]; // parent, heavy child (0 = none), subtree size
int idx[N], dfn, n, q;    // dfn position of each node, dfn counter, node count, query count
int dep[N], top[N];       // depth, head node of the node's heavy chain
ll a[N], ev[N];           // a[i]: weight of node i; ev[i]: weight of the node at dfn position i

// Reset per-test-case decomposition state: the dfn counter, plus the
// adjacency list and heavy-child marker of every node 1..n.
void init() {
	dfn = 0;
	for (int node = n; node >= 1; --node) {
		G[node].clear();
		son[node] = 0;
	}
}
// Heavy-light decomposition, pass 1. Starting from root x (parent fa, depth
// dist), fill fz (parent), dep (depth), sz (subtree size) and son (heavy
// child: the first child in adjacency order with the strictly largest subtree).
// Written iteratively: a path-shaped tree with N = 1e5 nodes would otherwise
// recurse 1e5 frames deep and can overflow a small stack.
// son[] is only written for nodes that have a child, so it must be zeroed
// beforehand (init() does this).
void dfs1(int x, int fa, int dist) {
	std::vector<int> order;   // BFS order: every parent precedes its children
	order.push_back(x);
	fz[x] = fa;
	dep[x] = dist;
	for (std::size_t i = 0; i < order.size(); i++) {
		const int u = order[i];
		for (int v : G[u]) {
			if (v == fz[u])
				continue;
			fz[v] = u;
			dep[v] = dep[u] + 1;
			order.push_back(v);
		}
	}
	// Reverse BFS order handles children before their parent, so each child's
	// sz is final by the time its parent accumulates it.
	for (std::size_t i = order.size(); i-- > 0;) {
		const int u = order[i];
		sz[u] = 1;
		int mx = 0;   // largest child subtree size seen so far
		for (int v : G[u]) {
			if (v == fz[u])
				continue;
			sz[u] += sz[v];
			if (sz[v] > mx)
				mx = sz[v], son[u] = v;
		}
	}
}
// Heavy-light decomposition, pass 2. Assign dfn positions (idx), chain heads
// (top) and the dfn-ordered weights (ev) for the subtree of x, whose chain head
// is topfz. The heavy child is numbered right after its parent, so every heavy
// chain occupies a contiguous dfn range. Requires fz/son from dfs1.
// Uses an explicit stack instead of recursion to avoid deep call stacks on
// path-shaped trees.
void dfs2(int x, int topfz) {
	std::vector<std::pair<int, int>> stk;   // (node, head of its chain)
	stk.emplace_back(x, topfz);
	while (!stk.empty()) {
		const int u = stk.back().first;
		const int head = stk.back().second;
		stk.pop_back();

		idx[u] = ++dfn;
		top[u] = head;
		ev[dfn] = a[u];

		if (!son[u])
			continue;   // no children
		// Push the light children in reverse adjacency order, then the heavy child
		// last so it is popped first. This reproduces the recursive preorder:
		// heavy child first, then the light children in adjacency order, each
		// starting its own chain.
		for (auto it = G[u].rbegin(); it != G[u].rend(); ++it) {
			const int v = *it;
			if (v != fz[u] && v != son[u])
				stk.emplace_back(v, v);
		}
		stk.emplace_back(son[u], head);
	}
}


// 线段树
ll lzy[N * 4][3]; // pending tags per node: [0] = add (0 = none), [1] = multiply (1 = none), [2] = cover value (0 = none)
ll wt[N * 4][3]; // per node: wt[rt][p] = sum of w^(p+1) over the node's range, mod Mod

// Rebuild node rt's power sums (sum of w, w^2, w^3, mod Mod) from its two children.
void Pushup(int rt) {
	for (int p = 0; p < 3; ++p)
		wt[rt][p] = (wt[ls][p] + wt[rs][p]) % Mod;
}
void Build(int rt, int L, int R) {
	lzy[rt][0] = lzy[rt][2] = 0;
	lzy[rt][1] = 1;

	if (L == R) {
		for (ll i = 0; i < 3; i++)
			wt[rt][i] = dpow((ll)ev[L], i + 1);
		return;
	}
	int mid = (L + R) >> 1;
	Build(ls, L, mid);
	Build(rs, mid + 1, R);
	Pushup(rt);
}
// Push node rt's pending tags down to its two children and clear them on rt.
// len is the number of positions rt covers. The left child covers
// len - len/2 of them and the right child len/2, matching mid = (L + R) >> 1.
// Tags are applied in the order cover -> multiply -> add. That is the order
// they compose in at a node: a cover resets the node's add/multiply tags, and
// a later multiply also scales the pending add.
// Note: a pending cover is detected by lzy[rt][2] being non-zero, so a stored
// cover value of 0 is not pushed by this function.
void Pushdown(int rt, ll len) {

	ll Llen = len - (len >> 1);
	ll Rlen = len >> 1;
	if (lzy[rt][2]) {  // cover with c: child sums become (child length) * c^p, and child add/multiply tags reset
		wt[ls][2] = dpow(lzy[rt][2], 3) * Llen % Mod;
		wt[rs][2] = dpow(lzy[rt][2], 3) * Rlen % Mod;

		wt[ls][1] = dpow(lzy[rt][2], 2) * Llen % Mod;
		wt[rs][1] = dpow(lzy[rt][2], 2) * Rlen % Mod;

		wt[ls][0] = lzy[rt][2] * Llen % Mod;
		wt[rs][0] = lzy[rt][2] * Rlen % Mod;

		lzy[ls][0] = lzy[rs][0] = 0;
		lzy[ls][1] = lzy[rs][1] = 1;
		lzy[ls][2] = lzy[rs][2] = lzy[rt][2];

		lzy[rt][2] = 0;
	}
	if (lzy[rt][1] != 1) {  // multiply by m: child sum of x^p scales by m^p; child multiply and add tags scale by m
		wt[ls][2] = wt[ls][2] * dpow(lzy[rt][1], 3) % Mod;
		wt[rs][2] = wt[rs][2] * dpow(lzy[rt][1], 3) % Mod;

		wt[ls][1] = wt[ls][1] * dpow(lzy[rt][1], 2) % Mod;
		wt[rs][1] = wt[rs][1] * dpow(lzy[rt][1], 2) % Mod;

		wt[ls][0] = wt[ls][0] * lzy[rt][1] % Mod;
		wt[rs][0] = wt[rs][0] * lzy[rt][1] % Mod;

		lzy[ls][0] = lzy[ls][0] * lzy[rt][1] % Mod;
		lzy[rs][0] = lzy[rs][0] * lzy[rt][1] % Mod;

		lzy[ls][1] = lzy[ls][1] * lzy[rt][1] % Mod;
		lzy[rs][1] = lzy[rs][1] * lzy[rt][1] % Mod;

		lzy[rt][1] = 1;
	}
	if (lzy[rt][0]) {  // add d: expand (x + d)^3 and (x + d)^2 over the child's range
		// sum (x+d)^3 = S3 + 3d*S2 + 3d^2*S1 + len*d^3, written as S3 + (S2 + d*S1)*3d + len*d^3.
		// [2] is updated before [1] and [0] because it needs their old values.
		wt[ls][2] = (wt[ls][2] + ((wt[ls][1] + wt[ls][0] * lzy[rt][0] % Mod) % Mod) * (3 * lzy[rt][0] % Mod) % Mod
			+ (Llen * dpow(lzy[rt][0], 3)) % Mod) % Mod;
		wt[rs][2] = (wt[rs][2] + ((wt[rs][1] + wt[rs][0] * lzy[rt][0] % Mod) % Mod) * (3 * lzy[rt][0] % Mod) % Mod
			+ (Rlen * dpow(lzy[rt][0], 3)) % Mod) % Mod;

		// sum (x+d)^2 = S2 + 2d*S1 + len*d^2; uses the old S1.
		wt[ls][1] = (wt[ls][1] + ((wt[ls][0] * 2) % Mod * lzy[rt][0] % Mod + Llen * dpow(lzy[rt][0], 2) % Mod) % Mod) % Mod;
		wt[rs][1] = (wt[rs][1] + ((wt[rs][0] * 2) % Mod * lzy[rt][0] % Mod + Rlen * dpow(lzy[rt][0], 2) % Mod) % Mod) % Mod;

		wt[ls][0] = (wt[ls][0] + Llen * lzy[rt][0] % Mod) % Mod;
		wt[rs][0] = (wt[rs][0] + Rlen * lzy[rt][0] % Mod) % Mod;

		lzy[ls][0] = (lzy[ls][0] + lzy[rt][0]) % Mod;
		lzy[rs][0] = (lzy[rs][0] + lzy[rt][0]) % Mod;

		lzy[rt][0] = 0;
	}
}
// Apply one operation to dfn positions [l, r]; node rt covers [L, R].
// op uses this segment tree's numbering: 1 = add w, 2 = multiply by w, and
// any other value = cover (assign) w. A fully covered node gets its power sums
// updated immediately and its tag recorded in lzy[rt] for a later Pushdown.
void Update(int rt, int L, int R, int l, int r, ll w, int op) {
	if (L >= l && R <= r) {
		ll len = (ll)(R - L + 1);
		// Normalize w into [0, Mod) so every product below stays within 64 bits.
		w %= Mod;
		if (w < 0)
			w += Mod;
		if (op == 1) {
			// Add w: sum (x+w)^3 = S3 + (S2 + w*S1)*3w + len*w^3 and sum (x+w)^2 = S2 + 2w*S1 + len*w^2.
			// [2] is updated before [1] and [0] because it needs their old values.
			wt[rt][2] = (wt[rt][2] + ((wt[rt][1] + wt[rt][0] * w % Mod) % Mod) * (3 * w % Mod) % Mod + len * dpow(w, 3) % Mod) % Mod;
			wt[rt][1] = (wt[rt][1] + (wt[rt][0] * 2 % Mod) * w % Mod + len * dpow(w, 2) % Mod) % Mod;
			wt[rt][0] = (wt[rt][0] + len * w % Mod) % Mod;

			lzy[rt][0] = (lzy[rt][0] + w) % Mod;
		}
		else if (op == 2) {
			// Multiply by w: each sum of x^p scales by w^p; a pending add is scaled too.
			wt[rt][2] = wt[rt][2] * dpow(w, 3) % Mod;
			wt[rt][1] = wt[rt][1] * dpow(w, 2) % Mod;
			wt[rt][0] = wt[rt][0] * w % Mod;

			lzy[rt][1] = lzy[rt][1] * w % Mod;
			lzy[rt][0] = lzy[rt][0] * w % Mod;
		}
		else {
			// Cover with w: every position becomes w, so the sum of x^p is len * w^p.
			wt[rt][2] = dpow(w, 3) * len % Mod;
			wt[rt][1] = dpow(w, 2) * len % Mod;
			wt[rt][0] = w * len % Mod;

			lzy[rt][0] = 0;
			// Pushdown treats lzy[rt][2] == 0 as "no pending cover". A cover with
			// w == 0 stored only there would never reach the children, and a later
			// Pushup would bring back their stale sums. Encode it as multiply-by-0
			// instead: that zeroes the whole subtree (including older tags below),
			// exactly as a cover with 0 should.
			lzy[rt][1] = (w == 0) ? 0 : 1;
			lzy[rt][2] = w;
		}
		return;
	}
	Pushdown(rt, R - L + 1);
	int mid = (L + R) >> 1;
	if (mid >= l)
		Update(ls, L, mid, l, r, w, op);
	if (mid < r)
		Update(rs, mid + 1, R, l, r, w, op);
	Pushup(rt);
}
// Return the sum of w^3 (mod Mod) over dfn positions [l, r] that lie inside
// node rt's range [L, R], pushing pending tags down as it descends.
ll Query(int rt, int L, int R, int l, int r) {
	if (l <= L && R <= r)
		return wt[rt][2];   // fully covered: the stored cube sum is current

	Pushdown(rt, R - L + 1);
	const int mid = (L + R) >> 1;
	const ll fromLeft = (l <= mid) ? Query(ls, L, mid, l, r) : 0;
	const ll fromRight = (r > mid) ? Query(rs, mid + 1, R, l, r) : 0;
	return (fromLeft + fromRight) % Mod;
}
// Apply segment-tree operation op (Update's numbering) with value w to every
// node on the tree path x..y, inclusive, one heavy-chain segment at a time.
void Range_Update(int x, int y, int op, ll w) {
	for (; top[x] != top[y]; x = fz[top[x]]) {
		// Always lift the endpoint whose chain head is deeper.
		if (dep[top[x]] < dep[top[y]])
			std::swap(x, y);
		Update(1, 1, n, idx[top[x]], idx[x], w, op);
	}
	// Both endpoints now share a heavy chain; the shallower one has the smaller dfn.
	const int upper = (dep[x] > dep[y]) ? y : x;
	const int lower = (dep[x] > dep[y]) ? x : y;
	Update(1, 1, n, idx[upper], idx[lower], w, op);
}
// Return the sum of cubed weights (mod Mod) over the tree path x..y,
// inclusive, by querying one heavy-chain segment at a time.
ll Range_Ask(int x, int y) {
	ll total = 0;
	for (; top[x] != top[y]; x = fz[top[x]]) {
		// Always lift the endpoint whose chain head is deeper.
		if (dep[top[x]] < dep[top[y]])
			std::swap(x, y);
		total = (total + Query(1, 1, n, idx[top[x]], idx[x])) % Mod;
	}
	// Both endpoints now share a heavy chain; the shallower one has the smaller dfn.
	const int upper = (dep[x] > dep[y]) ? y : x;
	const int lower = (dep[x] > dep[y]) ? x : y;
	return (total + Query(1, 1, n, idx[upper], idx[lower])) % Mod;
}

int main()
{
	int T;
	cin >> T;
	int Case = 0;
	while (T--) {
		// Read the tree: n nodes and n - 1 undirected edges, then the node weights.
		sc("%d", &n);
		init();
		for (int e = 1; e < n; e++) {
			int x, y;
			sc("%d %d", &x, &y);
			G[x].push_back(y);
			G[y].push_back(x);
		}
		for (int node = 1; node <= n; node++)
			sc("%lld", &a[node]);

		// Decompose with node 1 as the root, then build the segment tree over dfn order.
		dfs1(1, 0, 1);
		dfs2(1, 1);
		Build(1, 1, n);

		sc("%d", &q);
		printf("Case #%d:\n", ++Case);
		while (q--) {
			int op;
			sc("%d", &op);
			if (op > 3) {
				// Path query: cube sum of the weights.
				int x, y;
				sc("%d %d", &x, &y);
				printf("%lld\n", Range_Ask(x, y));
				continue;
			}
			int x, y;
			ll w;
			sc("%d %d %lld", &x, &y, &w);
			// The statement numbers cover/add/multiply as 1/2/3, but the segment
			// tree's Update uses 3/1/2 for them.
			const int segOp = (op == 1) ? 3 : (op == 2) ? 1 : 2;
			Range_Update(x, y, segOp, w);
		}
	}
	return 0;  // Reminder: resize the arrays if the limits change; if pairs are needed, update the pir macro.
}

猜你喜欢

转载自blog.csdn.net/weixin_43851525/article/details/107140112