hdu3974题解——线段树,dfs构建结构
题意:一个公司有N个人,编号1-n,除公司老总外,每个人都有一个上司,每个成员都有若干个或0个下属(一个成员的下属的下属还是他的下属),形成树结构,为公司的人分配任务i,若分配给成员x(可以是老总),则x和他的下属的任务都变为i;m次询问,两种情况,情况一询问某个成员的任务,情况二将成员x和他的下属的任务更改为y;
解法:将树结构构建为线性结构,用dfs求出成员x的下属区间(dfs可以用来求该节点管理的区间范围),然后就直接用线段树解。
dfs建立线性结构的时候可以先找到根结点,从根结点开始dfs,这样每个结点就只访问一次,可以减少时间和内存,但这题数据比较水,相差不大。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e5 + 10;
int head[MAXN], to[MAXN], nex[MAXN], cent;
int vis[MAXN], st[MAXN], ed[MAXN], cnt;
void init()
{
memset(head, -1, sizeof(head));
cent = 0;
}
void add(int x, int y)
{
to[cent] = y;
nex[cent] = head[x];
head[x] = cent++;
}
void dfs(int x)
{
cnt++;
st[x] = cnt;
vis[x] = 1;
for (int i = head[x]; ~i; i = nex[i])
dfs(to[i]);
ed[x] = cnt;
}
struct node{
int l, r, val;
int lazy;
}t[MAXN * 4];
void pushdown(int n)
{
if(t[n].l == t[n].r) return;
if(t[n].lazy)
{
t[n * 2].lazy = t[n * 2 + 1].lazy = t[n].lazy;
t[n * 2].val = t[n * 2 + 1].val = t[n].lazy;
t[n].lazy = 0;
}
}
void build(int n, int l, int r)
{
t[n].l = l;
t[n].r = r;
t[n].lazy = 0;
t[n].val = -1;
if(l == r) return;
int mid = (l + r) >> 1;
build(n * 2, l, mid);
build(n * 2 + 1, mid + 1, r);
}
void update(int n, int L, int R, int val)
{
int l = t[n].l, r = t[n].r;
if(L <= l && R >= r)
{
t[n].val = val;
t[n].lazy = val;
return;
}
pushdown(n);
int mid = (l + r) >> 1;
if(L <= mid) update(n * 2, L, R, val);
if(R > mid) update(n * 2 + 1, L, R, val);
}
int query(int n, int pos)
{
int l = t[n].l, r = t[n].r;
if(l == r) return t[n].val;
pushdown(n);
int mid = (l + r) >> 1;
if(pos <= mid) return query(n * 2, pos);
else return query(n * 2 + 1, pos);
}
int main()
{
int tt, n, m, x, y, c = 1;
scanf("%d", &tt);
while (tt--)
{
scanf("%d", &n);
cnt = 0;
init();
memset(vis, 0, sizeof(vis));
memset(st, 0, sizeof(st));
memset(ed, 0, sizeof(ed));
for (int i = 1; i < n; i++)
{
scanf("%d %d", &x, &y);
add(y, x);
}
for (int i = 1; i <= n; i++)
{
if(!vis[i])
{
dfs(i);
}
}
for (int i = 1; i <= n; i++)
printf("%d %d\n", st[i], ed[i]);
build(1, 1, cnt);
printf("%d\n", cnt);
scanf("%d", &m);
char s[3];
printf("Case #%d:\n", c++);
while (m--)
{
scanf("%s", s);
if(s[0] == 'C')
{
scanf("%d", &x);
printf("%d\n", query(1, st[x]));
}
else
{
scanf("%d %d", &x, &y);
update(1, st[x], ed[x], y);
}
}
}
}
优化版本
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e5 + 10;
int head[MAXN], to[MAXN], nex[MAXN], cent;
int st[MAXN], ed[MAXN], cnt, f[MAXN];
void init()
{
memset(head, -1, sizeof(head));
cent = 0;
}
void add(int x, int y)
{
to[cent] = y;
nex[cent] = head[x];
head[x] = cent++;
}
int find(int x)
{
return x == f[x] ? x : f[x] = find(f[x]);
}
void dfs(int x)
{
cnt++;
st[x] = cnt;
for (int i = head[x]; ~i; i = nex[i])
dfs(to[i]);
ed[x] = cnt;
}
struct node{
int l, r, val;
int lazy;
}t[MAXN * 4];
void pushdown(int n)
{
if(t[n].l == t[n].r) return;
if(t[n].lazy)
{
t[n * 2].lazy = t[n * 2 + 1].lazy = t[n].lazy;
t[n * 2].val = t[n * 2 + 1].val = t[n].lazy;
t[n].lazy = 0;
}
}
void build(int n, int l, int r)
{
t[n].l = l;
t[n].r = r;
t[n].lazy = 0;
t[n].val = -1;
if(l == r) return;
int mid = (l + r) >> 1;
build(n * 2, l, mid);
build(n * 2 + 1, mid + 1, r);
}
void update(int n, int L, int R, int val)
{
int l = t[n].l, r = t[n].r;
if(L <= l && R >= r)
{
t[n].val = val;
t[n].lazy = val;
return;
}
pushdown(n);
int mid = (l + r) >> 1;
if(L <= mid) update(n * 2, L, R, val);
if(R > mid) update(n * 2 + 1, L, R, val);
}
int query(int n, int pos)
{
int l = t[n].l, r = t[n].r;
if(l == r) return t[n].val;
pushdown(n);
int mid = (l + r) >> 1;
if(pos <= mid) return query(n * 2, pos);
else return query(n * 2 + 1, pos);
}
int main()
{
int tt, n, m, x, y, c = 1;
scanf("%d", &tt);
while (tt--)
{
scanf("%d", &n);
cnt = 0;
init();
memset(st, 0, sizeof(st));
memset(ed, 0, sizeof(ed));
for (int i = 1; i <= n; i++) f[i] = i;
for (int i = 1; i < n; i++)
{
scanf("%d %d", &x, &y);
add(y, x);
int xx = find(x), yy = find(y);
if(xx != yy) {
f[xx] = yy;
}
}
set<int> sp;
for (int i = 1; i <= n; i++) {
x = find(i);
sp.insert(x);
}
for (set<int>::iterator vip = sp.begin(); vip != sp.end(); vip++)
{
dfs(*vip);
}
build(1, 1, cnt);
scanf("%d", &m);
char s[3];
printf("Case #%d:\n", c++);
while (m--)
{
scanf("%s", s);
if(s[0] == 'C')
{
scanf("%d", &x);
printf("%d\n", query(1, st[x]));
}
else
{
scanf("%d %d", &x, &y);
update(1, st[x], ed[x], y);
}
}
}
}