P4689 [Ynoi2016] 这是我自己的发明 与 P5268 [SNOI2017] 一个简单的询问0

cnblogs 2024-08-18 16:39:00 阅读 50

P4689 [Ynoi2016] 这是我自己的发明 与 P5268 [SNOI2017] 一个简单的询问0

讲解 P4689 [Ynoi2016] 这是我自己的发明 与 P5268 [SNOI2017] 一个简单的询问。

先将树转化为 dfn 序,然后根据特殊性质得到区间,然后使用莫队算法,

思路:

首先可以先考虑没有换根的情况。

先将树拍到 dfn 序上,那么一个子树 \(u\) 的所有点的 dfn 序区间为 \([dfn_u,dfn_u+siz_u-1]\)

那么询问变为:

    <li>

    每次给定两个区间 \([l_1,r_1],[l_2,r_2]\),对于在第一个区间内的点 \(x\) 和在第二个区间的点 \(y\),若 \((x,y)\) 有贡献,当且仅当 \(w_x=w_y\)

  • 询问有贡献的点对数量。

即 P5268 [SNOI2017] 一个简单的询问。

\(F(l_1,r_1,l_2,r_2)\) 表示 \([l_1,r_1]\)\([l_2,r_2]\) 的贡献,那么:

\[F(l_1,r_1,l_2,r_2) = F(1,r_1,1,r_2) - F(1,l_1-1,1,r_2) - F(1,r_1,1,l_2-1) - F(1,l_1-1,1,l_2-1)

\]

那么一个询问就都转化为了四个 \(F(1,x,1,y)\) 的形式,考虑如何求 \(F(1,x,1,y)\),先钦定 \(x \le y\),那么考虑莫队:

  • 设当前 \(p_{1,x},p_{2,x}\) 分别表示两个区间 \(x\) 的出现次数。

  • \(x \gets x+1\) 时,贡献会增加 \(p_{2,a_{x+1}}\)

  • \(x \gets x-1\) 时,贡献会减少 \(p_{2,a_x}\)

  • \(y \gets y+1\) 时,贡献会增加 \(p_{1,a_{y+1}}\)

  • \(y \gets y-1\) 时,贡献会减少 \(p_{1,a_y}\)

现在再考虑换根操作,若当前以 \(rt\) 为根:

  • \(rt\) 不在初始以 \(1\) 为根时 \(x\) 的子树内,则不好造成影响。

  • 否则 \(x\) 子树内的点即为除了\((x \to rt)\) 路径上最接近 \(x\) 的点 \(y\) 子树内的点的全部点。

因为 \(x\) 在原始树上始终是 \(rt\) 的父亲,则 \(y\)\(rt\)\(dep_{rt}-dep_{x}-1\) 级祖先,直接倍增即可。

时间复杂度为 \(O(N\sqrt{M}+M \log N+M)\)

完整代码:

<code>#include<bits/stdc++.h>

#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)

#define lowbit(x) x&(-x)

#define pi pair<ll,ll>

#define pii pair<ll,pair<ll,ll>>

#define iip pair<pair<ll,ll>,ll>

#define ppii pair<pair<ll,ll>,pair<ll,ll>>

#define fi first

#define se second

#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x

#define Full(a) memset(a,0,sizeof(a))

#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);

#define For(i,l,r) for(int i=l;i<=r;i++)

#define _For(i,l,r) for(int i=r;i>=l;i--)

using namespace std;

typedef double db;

typedef unsigned long long ull;

typedef long long ll;

bool Begin;

const ll N=1e5+10,M=4e6+10,K=17;

inline ll read(){

ll x=0,f=1;

char c=getchar();

while(c<'0'||c>'9'){

if(c=='-')

f=-1;

c=getchar();

}

while(c>='0'&&c<='9'){

x=(x<<1)+(x<<3)+(c^48);

c=getchar();

}

return x*f;

}

inline void write(ll x){

if(x<0){

putchar('-');

x=-x;

}

if(x>9)

write(x/10);

putchar(x%10+'0');

}

ll op,n,m,t,q,u,v,rt,sum,l1,r1,l2,r2,l,r,cnt;

ll A[N],a[N],b[N],w[N],d[N],siz[N],dfn[N],p1[N],p2[N],ans[M];

ll F[N][K];

vector<pi> X,Y;

vector<ll> E[N];

struct Ques{

ll x,y;

ll id;

ll v;

inline bool operator<(const Ques &rhs)const{

if(A[x]^A[rhs.x])

return A[x]<A[rhs.x];

return y>rhs.y;

}

}Q[M];

inline void add(ll u,ll v){

E[u].push_back(v);

E[v].push_back(u);

}

inline void dfs(ll u,ll fa){

For(i,1,K-1)

F[u][i]=F[F[u][i-1]][i-1];

dfn[u]=++cnt;

w[cnt]=a[u];

siz[u]=1;

for(auto v:E[u]){

if(v==fa)

continue;

F[v][0]=u;

d[v]=d[u]+1;

dfs(v,u);

siz[u]+=siz[v];

}

}

inline ll get_fa(ll u,ll k){

_For(i,0,K-1){

if((k>>i)&1ll){

k-=(1ll<<i);

u=F[u][i];

}

}

return u;

}

inline vector<pi> get(ll x){

vector<pi> ans;

if(x==rt)

ans.push_back({1,n});

else if(dfn[x]<=dfn[rt]&&dfn[rt]<=dfn[x]+siz[x]-1){

ll y=get_fa(rt,d[rt]-d[x]-1);

if(dfn[y]!=1)

ans.push_back({1,dfn[y]-1});

if(dfn[y]+siz[y]<=n)

ans.push_back({dfn[y]+siz[y],n});

}

else

ans.push_back({dfn[x],dfn[x]+siz[x]-1});

return ans;

}

inline void get(ll l1,ll r1,ll l2,ll r2){

Q[++q]={r1,r2,cnt,1};

if(l1-1)

Q[++q]={l1-1,r2,cnt,-1};

if(l2-1)

Q[++q]={r1,l2-1,cnt,-1};

if(l1-1&&l2-1)

Q[++q]={l1-1,l2-1,cnt,1};

}

inline void insert1(ll x){

sum+=p2[w[x]];

p1[w[x]]++;

}

inline void insert2(ll x){

sum+=p1[w[x]];

p2[w[x]]++;

}

inline void del1(ll x){

sum-=p2[w[x]];

p1[w[x]]--;

}

inline void del2(ll x){

sum-=p1[w[x]];

p2[w[x]]--;

}

bool End;

int main(){

n=read(),m=read();

For(i,1,n){

a[i]=read();

b[++cnt]=a[i];

}

sort(b+1,b+cnt+1);

cnt=unique(b+1,b+cnt+1)-(b+1);

For(i,1,n)

a[i]=lower_bound(b+1,b+cnt+1,a[i])-b;

cnt=0;

For(i,1,n-1){

u=read(),v=read();

add(u,v);

}

dfs(1,1);

cnt=0;

For(i,1,m){

op=read(),u=read();

if(op==1){

rt=u;

continue;

}

++cnt;

v=read();

X=get(u);

Y=get(v);

for(auto x:X)

for(auto y:Y)

get(x.fi,x.se,y.fi,y.se);

}

t=max(n/max((ll)sqrt(m),1ll),1ll);

For(i,1,n)

A[i]=(i-1)/t+1;

For(i,1,q)

if(Q[i].x>Q[i].y)

swap(Q[i].x,Q[i].y);

sort(Q+1,Q+q+1);

For(i,1,q){

while(l<Q[i].x)

insert1(++l);

while(l>Q[i].x)

del1(l--);

while(r<Q[i].y)

insert2(++r);

while(r>Q[i].y)

del2(r--);

ans[Q[i].id]+=sum*Q[i].v;

}

For(i,1,cnt){

write(ans[i]);

putchar('\n');

}

//cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";

return 0;

}



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。