PROBLEM
给定一个n个节点的树,给定一个排列,求所有连续子段的节点的LCA的深度和。
SOLUTION
这题有很多种方法。
分治
考虑跨过区间中点的答案,从中线往两边扫,扫过左半边和右半边的LCA一定是在两条链上,那么合并这两条链上任意点对的答案,扫一遍就可以了。另外考虑O(1)求LCA,用欧拉序与RMQ可以做到总复杂度O(Nlog n),只不过常数巨大。
一个性质
对于一个排列,任意相邻位置求LCA,所有LCA中深度最小的就是整个排列的LCA。因此我们可以把这个问题转化为区间最值问题。先将相邻的LCA求出来,如果一个位置为选定区间的LCA,那么这个区间里就不能有深度比它小的,找到左边第一个小于它的,右边第一个小于等于它的,避免算重,用一个单调栈来维护一下就好了。
3)线段树合并?
单调栈
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define maxn 600010
#define maxm 1200020
#define maxp 20
#define LL long long
using namespace std;
int n,m,i,j,k,x,y,t,tot,lca,p[maxn],le[maxn],ri[maxn],b[maxn];
int em,e[maxm],nx[maxm],ls[maxn],dep[maxn],d[maxn],fa[maxn][maxp+1];
LL ans;
void read(int &x){
x=0; char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar());
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
}
void insert(int x,int y){
em++; e[em]=y; nx[em]=ls[x]; ls[x]=em;
em++; e[em]=x; nx[em]=ls[y]; ls[y]=em;
}
void dfs(int x,int p){
fa[x][0]=p,dep[x]=dep[p]+1;
for(int i=1;i<=maxp;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p)
dfs(e[i],x);
}
int getlca(int x,int y){
if (dep[x]<dep[y]) swap(x,y);
for(int i=maxp;i>=0;i--) if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y) return x;
for(int i=maxp;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int main(){
freopen("easy.in","r",stdin);
freopen("easy.out","w",stdout);
read(n);
for(i=1;i<n;i++){
read(x),read(y);
insert(x,y);
}
dfs(1,0);
for(i=1;i<=n;i++) ans+=dep[i];
for(i=1;i<=n;i++) read(p[i]);
for(i=1;i<=n-1;i++) b[i]=getlca(p[i],p[i+1]);
d[t=1]=0;
for(i=1;i<=n-1;i++){
while (t&&dep[b[d[t]]]>dep[b[i]]) t--;
le[i]=d[t]+1;
d[++t]=i;
}
d[t=1]=0;
for(i=n-1;i>=1;i--){
while (t&&dep[b[d[t]]]>=dep[b[i]]) t--;
if (d[t]==0) ri[i]=n-1;
else ri[i]=d[t]-1;
d[++t]=i;
}
for(i=1;i<=n-1;i++) ans+=(LL)(ri[i]-i+1)*(i-le[i]+1)*dep[b[i]];
printf("%lld",ans);
}
T飞了的RMQ
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define maxn 600010
#define maxm 1200020
#define maxp 20
#define LL long long
using namespace std;
int n,m,i,j,k,x,y,t,tot,lca,p[maxn],le[maxn],ri[maxn];
int em,e[maxm],nx[maxm],ls[maxn],dep[maxn],d[maxn];
int fir[maxn*2],a[maxn*2],f[maxn*2][maxp+1],g[maxn*2][maxp+1],b[maxn];
LL ans;
void read(int &x){
x=0; char ch=getchar();
for(;ch<'0'||ch>'9';ch=getchar());
for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
}
void insert(int x,int y){
em++; e[em]=y; nx[em]=ls[x]; ls[x]=em;
em++; e[em]=x; nx[em]=ls[y]; ls[y]=em;
}
void dfs(int x,int p){
a[++a[0]]=x; fir[x]=a[0]; dep[x]=dep[p]+1;
for(int i=ls[x];i;i=nx[i]) if (e[i]!=p)
dfs(e[i],x),a[++a[0]]=x;
}
void RMQ(){
for(int i=1;i<=a[0];i++) f[i][0]=dep[a[i]],g[i][0]=a[i];
for(int j=1;j<=maxp&&(1<<j)<=a[0]-1;j++)
for(int i=1;i<=a[0]-(1<<j);i++){
int k=i+(1<<(j-1));
if (f[i][j-1]<f[k][j-1])
f[i][j]=f[i][j-1],g[i][j]=g[i][j-1];
else f[i][j]=f[k][j-1],g[i][j]=g[k][j-1];
}
}
int que(int x,int y){
if (x==y) return x;
if (fir[x]>fir[y]) swap(x,y);
k=log2(fir[y]-fir[x]);
if (f[fir[x]][k]<f[fir[y]-(1<<k)][k]) return g[fir[x]][k];
else return g[fir[y]-(1<<k)][k];
}
int main(){
freopen("easy.in","r",stdin);
freopen("easy.out","w",stdout);
read(n);
for(i=1;i<n;i++){
read(x),read(y);
insert(x,y);
}
dfs(1,0);
RMQ();
for(i=1;i<=n;i++) ans+=dep[i];
for(i=1;i<=n;i++) read(p[i]);
for(i=1;i<=n-1;i++) b[i]=que(p[i],p[i+1]);
d[t=1]=0;
for(i=1;i<=n-1;i++){
while (t&&dep[b[d[t]]]>dep[b[i]]) t--;
le[i]=d[t]+1;
d[++t]=i;
}
d[t=1]=0;
for(i=n-1;i>=1;i--){
while (t&&dep[b[d[t]]]>=dep[b[i]]) t--;
if (d[t]==0) ri[i]=n-1;
else ri[i]=d[t]-1;
d[++t]=i;
}
for(i=1;i<=n-1;i++) ans+=(LL)(ri[i]-i+1)*(i-le[i]+1)*dep[b[i]];
printf("%lld",ans);
}