0%

[2023 集训队互测 Round 2] 相等树链

模拟赛题喵。

题目链接

题意

给定两棵点集为 的树

求有多少 的非空子集 满足 上都是一条链。

题解

定义 为两个集合的异或操作即

显然所有 的集合都合法,考虑计算 的总方案数。

表示 这条链上点构成的点集。

进行点分治,假设当前的分治中心是点 ,考虑计算 的方案数。

那么设 ,令

那么显然 ,因为 一定在 的不同子树里面。

此外还有

由于点集 上是 这条链,考虑 在这条链上的位置。

为这条链上,最左侧的, 里的点, 为这条链上最右侧的 里的点。

那么显然

同理对于 有一个

考虑对于 进行一些分讨:

  1. 第一种情况是 的同一棵子树内,这里以 这棵子树为例。

    ,可以发现有

    显然有

  2. 第二种是 不在同一棵子树内。

    ,不妨设

    还是有

考虑对两种情况分别计算方案数。

容易发现最后化成的形式,都是等号左边只和 有关,等号右边只和 有关。

对于第一种情况,我们枚举 ,应该直接能在 上找到对应的 。直接算即可。

对于第二种情况也是同样的,但是同一个 对应了 两个位置,两个都要算进去。

这个时候有一个神秘的情况就是第二种情况里面,会多算上 上位于 同一棵子树的情况,需要去重。

时间复杂度

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f3f3f3f3fll
#define debug(x) cerr<<#x<<"="<<x<<endl
using namespace std;
using ll=long long;
using ld=long double;
using pli=pair<ll,int>;
using pi=pair<int,int>;
template<typename A>
using vc=vector<A>;
template<typename A,const int N>
using aya=array<A,N>;
using ull=unsigned long long;
inline int read()
{
int s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
inline ll lread()
{
ll s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
template<const int N,const int M>
struct graph
{
int head[N+5];
int ww[M+5];
int t[M+5];
int x[M+5];
int cntm;
graph(){ cntm=1;}
inline void clear(int n=N)
{
cntm=1;
for(int i=1;i<=n;i++) head[i]=0;
}
inline void ad(int u,int v,int w=0)
{
cntm++;
t[cntm]=v;
x[cntm]=head[u];
ww[cntm]=w;
head[u]=cntm;
}
inline void add(int u,int v,int w=0)
{
ad(u,v,w);
ad(v,u,w);
}
inline int st(int num){ return head[num];}
inline int to(int num){ return t[num];}
inline int nx(int num){ return x[num];}
inline int w(int num){ return ww[num];}
};
graph<200000,400000>g1,g2;
// mt19937_64 _rand(time(0)^clock());
mt19937_64 _rand(0);
bool vis[200005];
int siz[200005];
int tag[200005];
int dep[200005];
int bel[200005];
int ttg[200005];
int a1[200005];
int a2[200005];
ull v1[200005];
ull v2[200005];
ull H[200005];
int f[200005];
int n,S,tim,ttm;
ll ans;int C;
int getroot(int num,int fa)
{
C++;
int ans=0;siz[num]=1,f[num]=0;
for(int i=g1.st(num);i;i=g1.nx(i))
{
int p=g1.to(i);
if(p==fa||vis[p]) continue;
int val=getroot(p,num);
siz[num]+=siz[p];
f[num]=max(f[num],siz[p]);
if(!ans||f[val]<f[ans]) ans=val;
}
f[num]=max(f[num],S-siz[num]);
if(!ans||f[num]<f[ans]) ans=num;
return ans;
}
void dfs1(int num,int fa)
{
tag[num]=tim,v1[num]=v1[fa]^H[num],siz[num]=1;
for(int i=g1.st(num);i;i=g1.nx(i))
{
int p=g1.to(i);
if(p==fa||vis[p]) continue;
dfs1(p,num),siz[num]+=siz[p];
}
}
void dfs2(int num,int fa)
{
v2[num]=v2[fa]^H[num];ttg[num]=ttm;
for(int i=g2.st(num);i;i=g2.nx(i))
{
int p=g2.to(i);
if(p==fa||tag[p]!=tim) continue;
bel[p]=bel[num],dep[p]=dep[num]+1;
dfs2(p,num);
}
}
vc<int>nod;
map<ull,int>m3;
map<ull,int>m4[200005];
using pu=pair<ull,int>;
vc<pu>m1,m2;
int rt;
void dfs3(int num,int fa,int anc)
{
if(ttg[num]!=ttm) return ;
a1[num]=a1[fa],a2[num]=a2[fa];
if(bel[a1[num]]==bel[num]) a1[num]=dep[num]>dep[a1[num]]?num:a1[num];
else if(bel[a2[num]]==bel[num]) a2[num]=dep[num]>dep[a2[num]]?num:a2[num];
else if(a1[num]==rt) a1[num]=num;
else if(a2[num]==rt) a2[num]=num;
else return ;

// printf("dfs3 %d %d : %d %d\n",num,fa,a1[num],a2[num]);

m1.push_back(pu(v1[num],anc));
m2.push_back(pu(v1[num]^v2[a1[num]]^v2[a2[num]],anc));
ans+=m3[v1[num]^v2[a1[num]]]-m4[bel[a1[num]]][v1[num]^v2[a1[num]]];
if(a2[num]!=rt) ans+=m3[v1[num]^v2[a2[num]]]-m4[bel[a2[num]]][v1[num]^v2[a2[num]]];
nod.push_back(num);

for(int i=g1.st(num);i;i=g1.nx(i))
{
int p=g1.to(i);
if(p==fa||vis[p]) continue;
dfs3(p,num,anc);
}
}
void solve(int u)
{
// printf("solve %d\n",u);
vis[u]=1;rt=u;
tim++,v1[u]=0,dfs1(u,u),ttm++;
dep[u]=0,bel[u]=u,v2[u]=H[u];
for(int i=g2.st(u);i;i=g2.nx(i))
{
int p=g2.to(i);
if(tag[p]!=tim) continue;
bel[p]=p,dep[p]=1,dfs2(p,u);
m4[p].clear();
}
m1.clear(),m2.clear(),m3.clear();
m1.push_back(pu(H[u],0));
m2.push_back(pu(H[u],0));
a1[u]=a2[u]=u;
for(int i=g1.st(u);i;i=g1.nx(i))
{
int p=g1.to(i);if(vis[p]) continue;
dfs3(p,u,p);
for(int q:nod)
{
m3[v1[q]^v2[a1[q]]]++,m4[bel[a1[q]]][v1[q]^v2[a1[q]]]++;
if(a2[q]!=u) m3[v1[q]^v2[a2[q]]]++,m4[bel[a2[q]]][v1[q]^v2[a2[q]]]++;
}
nod.clear();
}
sort(m1.begin(),m1.end());
sort(m2.begin(),m2.end());
//case1
unsigned x=0,y=0;
while(x<m1.size()&&y<m2.size())
{
ull mem=min(m1[x].first,m2[y].first);int cx=0,cy=0;
while(x<m1.size()&&m1[x].first==mem) cx++,x++;
while(y<m2.size()&&m2[y].first==mem) cy++,y++;
ans+=(ll)cx*cy;
}
//case2
x=y=0;
while(x<m1.size()&&y<m2.size())
{
pu mem=min(m1[x],m2[y]);int cx=0,cy=0;
while(x<m1.size()&&m1[x]==mem) cx++,x++;
while(y<m2.size()&&m2[y]==mem) cy++,y++;
ans-=(ll)cx*cy;
}
// printf("finsh %d : %lld\n",u,ans);
for(int i=g1.st(u);i;i=g1.nx(i))
{
int p=g1.to(i);
if(vis[p]) continue;
S=siz[p],solve(getroot(p,u));
}
}
int main()
{
S=n=read();
for(int i=2;i<=n;i++) g1.add(read(),i);
for(int i=2;i<=n;i++) g2.add(read(),i);
for(int i=1;i<=n;i++) H[i]=_rand();
solve(getroot(1,1));
printf("%lld\n",ans+n);
cerr<<C<<endl;
return 0;
}