0%

[QOJ9406] Triangle

一个超级难写的大字符串题。

题目链接

题意

给定 个字符串,第 个字符串为

称三个字符串 能组成一个三角形当且仅当:

对于任意的 ,有 至少满足一个。

题解

表示有多少字符串 表示有多少字符串恰好等于

假设有三个字符串 满足 ,考虑他们什么时候能组成三角形。

分情况讨论,第一种情况是

容易发现对于这种情况,一定能形成一个三角形。

那么考虑 的情况,此时需要满足

那么可以发现,必然有一个是 的前缀

否则因为 ,则 ,同理 ,不满足条件。

假设 的前缀,设 ,则 非空。

此时需要有 ,即 时才可以满足条件。

考虑枚举 ,然后枚举 的长度,则合法的 数量为

这部分枚举显然是 复杂度的。

但是考虑,有一些情况会被算重。

考虑若 同时为 的前缀,且每一个字符串都大于另一个的后缀,则这个三元组会被算两次。

,则我们还需要统计 的情况数。

对于 的情况单独处理,这部分直接减去 即可。

对于 的情况,这相当于一个二维偏序问题,可以先拉出所有 的前缀和后缀排序然后跑计数。

最后有一个问题, 怎么求。

考虑 怎么求,直接所有串拉一起建一个后缀数组,然后扫一遍。

小的字符串后缀数组里一定在 的前面,但是注意到前面的字符串比一定都比 小。

考虑先算出每一个后缀的哈希值,然后用『 前面串的个数』减去『 前面与 相等的个数』,就是比 小的字符串个数。

的话,哈希值都求出来了,直接统计就可以了。

云一下大概能感觉出来,很难写。

写完了,确实很难写

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f3f3f3f3fll
#define debug(x) cerr<<#x<<"="<<x<<endl
using namespace std;
using ll=long long;
using ld=long double;
using pli=pair<ll,int>;
using pi=pair<int,int>;
template<typename A>
using vc=vector<A>;
using pl=pair<ll,ll>;
inline int read()
{
int s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
inline ll lread()
{
ll s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
const int mod1=998244853,mod2=1000000009;
const int base=1145141;
struct node{ int a,b,c;}v[300001],vv[300001];
int son[300005][26];
int num[300005];//这个节点对应多少字符串
int sa[600005],rk[1200005];
int num1[600005],num2[600005];
int tmp[600005],cnt[600005];
ll p1[600005],p2[600005];
ll h1[600005],h2[600005];
ll ff[600005],cc[600005];
int ed[300005],id[300005];
int hi[600005],st[21][600005];
char t[600005];
char s[300005];
int c,tot,n,rt;
ll ans;
inline void clear()
{
c=tot=rt=0;
}
inline int ins(int &p,char *s)
{
if(!p) p=++tot,num[p]=0,memset(son[p],0,sizeof(son[p]));
if(!s[0]){ num[p]++;return p;}
t[++c]=s[0];return ins(son[p][s[0]-'a'],s+1);
}
inline void SA()
{
int lim=max(27,c);
for(int i=1;i<=c;i++) rk[i]=t[i]=='$'?1:t[i]-'a'+2;
for(int i=c+1;i<=2*c;i++) rk[i]=0;
for(int P=0;(1<<P)<=c;P++)
{
memset(cnt,0,sizeof(int)*(lim+2));
for(int i=1;i<=c;i++) num1[i]=rk[i],num2[i]=rk[i+(1<<P)],cnt[num2[i]+1]++;
for(int i=1;i<=lim;i++) cnt[i]+=cnt[i-1];
for(int i=1;i<=c;i++) tmp[++cnt[num2[i]]]=i;

memset(cnt,0,sizeof(int)*(lim+2));
for(int i=1;i<=c;i++) cnt[num1[i]+1]++;
for(int i=1;i<=lim;i++) cnt[i]+=cnt[i-1];
for(int i=1;i<=c;i++) sa[++cnt[num1[tmp[i]]]]=tmp[i];

for(int i=1;i<=c;i++)
{
int p=sa[i-1],q=sa[i];
if(num1[p]==num1[q]&&num2[p]==num2[q]) rk[q]=rk[p];
else rk[q]=rk[p]+1;
}
}
// printf("%s c=%d\n",t+1,c);
// for(int i=1;i<=c;i++) printf("%2d%c",sa[i]," \n"[i==c]);
// for(int i=1;i<=c;i++) printf("%2d%c",rk[i]," \n"[i==c]);

for(int i=1,j=0;i<=c;i++)
{
if(j) j--;
int p=sa[rk[i]-1];
while(t[i+j]==t[p+j]) j++;
hi[rk[i]]=j;
// printf("%d and %d : %d\n",i,p,j);
}
// for(int i=1;i<=c;i++) printf("%2d%c",hi[i]," \n"[i==c]);

for(int i=1;i<=c;i++) st[0][i]=hi[i];
for(int j=1;(1<<j)<=c;j++) for(int i=1;i+(1<<j)-1<=c;i++) st[j][i]=min(st[j-1][i],st[j-1][i+(1<<(j-1))]);
}
inline int get(int w1,int w2)
{
if(w1==w2) return c-w1+1;
w1=rk[w1],w2=rk[w2];if(w1>w2) swap(w1,w2);
int num=31-__builtin_clz(w2-w1);
return min(st[num][w1+1],st[num][w2-(1<<num)+1]);
}
inline bool check(int l1,int r1,int l2,int r2)
{
//是否有 t[l1,r1] > t[l2,r2]
if(l1==l2) return r1>r2;
int lcp=min(get(l1,l2),min(r1-l1+1,r2-l2+1));
// printf("check %d %d %d %d lcp=%d\n",l1,r1,l2,r2,lcp);
if(l1+lcp<=r1&&l2+lcp<=r2) return t[l1+lcp]>t[l2+lcp];
return l1+lcp<=r1;
}
int tt[600001];
inline int lowbit(int i){ return i&(-i);}
inline void add(int x,int y){ while(x<=2*c) tt[x]+=y,x+=lowbit(x);}
inline int get(int x){ int ans=0;while(x) ans+=tt[x],x-=lowbit(x);;return ans;}
inline void solve()
{
n=read();
for(int i=1;i<=n;i++)
{
scanf("%s",s+1);t[++c]='$';
id[i]=ins(rt,s+1),ed[i]=c;
}
t[c+1]=0;SA();

p1[0]=p2[0]=1;
for(int i=1;i<=c;i++) p1[i]=p1[i-1]*base%mod1,p2[i]=p2[i-1]*base%mod2;

int len=-1;ll now1=0,now2=0;
for(int i=c;i;i--)
{
if(t[i]=='$') len=-1,now1=now2=h1[i]=h2[i]=0;
else
{
len++;
h1[i]=now1=(now1+p1[len]*(t[i]-'a'+1))%mod1;
h2[i]=now2=(now2+p2[len]*(t[i]-'a'+1))%mod2;
}
}

map<pl,int>vis;int P=0;
for(int i=1;i<=c;i++)
{
int p=sa[i];ff[p]=cc[p]=0;
if(t[p]=='$') continue;
if(t[p-1]=='$') P++,vis[pl(h1[p],h2[p])]++;
ff[p]=P-vis[pl(h1[p],h2[p])];
}
for(int i=1;i<=c;i++)
{
int p=sa[i];
if(t[p]!='$') cc[p]=vis[pl(h1[p],h2[p])];
}

// for(int i=1;i<=c;i++) printf("%lld%c",ff[i]," \n"[i==c]);
// for(int i=1;i<=c;i++) printf("%lld%c",cc[i]," \n"[i==c]);

ans=0;ll v1=0,v2=0;
for(int i=1;i<=n;i++)//枚举x
{
int now=rt,st=ed[i-1]+2,ed=::ed[i];
vc<pi>V;
for(int j=st;j<ed;j++)
{
now=son[now][t[j]-'a'];
ans+=(ll)num[now]*max(0ll,ff[st]-ff[j+1]-cc[j+1]);
// printf("st=%d ed=%d j=%d : %d %lld\n",st,ed,j,num[now],max(0ll,ff[st]-ff[j+1]-cc[j+1]));
if(num[now]&&check(st,j,j+1,ed))
{
// printf("%d ~ %d : %d\n",st,j,num[now]);
ans-=(ll)num[now]*(num[now]+1)/2;
}

V.push_back(pi(st,j));
V.push_back(pi(j+1,ed));
v[j-st+1].c=num[now];
}
int f=ff[st],c=cc[st],all=ed-st;
v1+=(ll)(c-1)*(c-2)/2;
v2+=(ll)(c-1)*f;

sort(V.begin(),V.end(),[](pi a,pi b)
{
return check(b.first,b.second,a.first,a.second);
});
int rk=0;
for(unsigned i=0;i<V.size();i++)
{
if(!i||check(V[i].first,V[i].second,V[i-1].first,V[i-1].second)) rk++;
if(V[i].first==st) v[V[i].second-st+1].a=rk;
else v[V[i].first-st].b=rk;
}
memcpy(vv,v,sizeof(node)*(all+1));
sort(v+1,v+all+1,[](node a,node b){ return a.a<b.a;});
sort(vv+1,vv+all+1,[](node a,node b){ return a.b<b.b;});

// for(int i=1;i<=all;i++) if(v[i].c) printf("%d : %d %d %d\n",i,v[i].a,v[i].b,v[i].c);
// putchar('\n');
// for(int i=1;i<=all;i++) if(vv[i].c) printf("%d : %d %d %d\n",i,vv[i].a,vv[i].b,vv[i].c);

ll val=0,sum=0;now=1;
for(int j=1;j<=all;j++)
{
while(now<=all&&vv[now].b<v[j].a) add(vv[now].a,vv[now].c),sum+=vv[now].c,now++;
val+=v[j].c*(sum-get(v[j].b));
}
// printf("val=%lld\n",val);
while(now>1) now--,add(vv[now].a,-vv[now].c);
for(int i=1;i<=all;i++) if(v[i].a>v[i].b) val-=(ll)v[i].c*v[i].c;
ans-=val/2;// printf("val=%lld ans=%lld\n",val,ans);
}
ans+=v1/3+v2/2;
printf("%lld\n",ans);
}
int main()
{
int T=read();
while(T--) clear(),solve();
return 0;
}
/*
1
3
aaa
aaa
aaaa
ans=1
*/