#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>

// Reads one whitespace-delimited string s (stored 1-indexed) from stdin and
// prints dp[n] modulo 998244353, where dp is the table built in main().
// NOTE(review): the combinatorial meaning of dp is not recorded here; confirm
// it against the original problem statement before changing the recurrence.

namespace {

using ll = long long;

// 998244353 = 119 * 2^23 + 1, so it supports power-of-two NTTs up to 2^23;
// 3 is a primitive root modulo it.
constexpr int mod = 998244353;
constexpr int G = 3;
constexpr int NG = (mod + 1) / G;  // inverse of G: 3 * 332748118 == mod + 1

// Fixed-capacity buffers. Together they support n <= 200002: jc/inv need
// index 2n, and the transform length 1 << L (> 2n) must fit in 530001.
int wh[530001];   // bit-reversal permutation for transforms of length 1 << L
ll f[530001];     // convolution operand; holds the product after convolve()
ll g[530001];     // convolution operand (left in the transformed domain)
ll jc[400005];    // jc[i]  = i! mod p
ll inv[400005];   // inv[i] = (i!)^{-1} mod p
ll dp[200005];    // DP table, indexed by position in s
char s[200005];   // input string in s[1..n]; s[0] and s[n + 1] are '\0'
int n;            // length of s
int L;            // log2 of the transform length

// a = (a + b) mod p, for a and b already reduced to [0, p).
void Add(ll &a, ll b)
{
    a += b;
    if (a >= mod) a -= mod;
}

// Returns a^b mod p by binary exponentiation (expects 0 <= a < p, b >= 0).
ll qow(ll a, ll b)
{
    ll ans = 1;
    while (b)
    {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}

// Binomial coefficient C(a, b) mod p; 0 when b < 0 or b > a.
// Requires init() to have filled jc/inv up to index a.
ll C(int a, int b)
{
    if (a < b || b < 0) return 0;
    return jc[a] * inv[b] % mod * inv[a - b] % mod;
}

// Fills jc[0..lim] with factorials and inv[0..lim] with their inverses.
// One modular exponentiation suffices: (i-1)!^{-1} = i!^{-1} * i.
void init(int lim)
{
    jc[0] = 1;
    for (int i = 1; i <= lim; i++) jc[i] = jc[i - 1] * i % mod;
    inv[lim] = qow(jc[lim], mod - 2);
    for (int i = lim; i >= 1; i--) inv[i - 1] = inv[i] * i % mod;
}

// In-place iterative NTT of a[0 .. (1 << L) - 1]. With dft == false it
// computes the inverse transform, including the 1 / (1 << L) scaling.
// Requires wh[] to hold the bit-reversal permutation for the current L.
void NTT(ll *a, bool dft = true)
{
    const int N = 1 << L;
    for (int i = 0; i < N; i++)
        if (i < wh[i]) std::swap(a[i], a[wh[i]]);
    for (int len = 1; len <= L; len++)
    {
        const int P = 1 << len;
        const int half = P >> 1;
        // Primitive P-th root of unity G^((p - 1) / P), or its inverse.
        const ll step = qow(dft ? G : NG, (mod - 1) / P);
        for (int l = 0; l < N; l += P)
        {
            ll now = 1;
            for (int i = l, j = l + half; j < l + P; i++, j++)
            {
                const ll num = a[j] * now % mod;
                a[j] = (a[i] - num + mod) % mod;
                a[i] = (a[i] + num) % mod;
                now = now * step % mod;
            }
        }
    }
    if (!dft)
    {
        const ll scale = qow(N, mod - 2);
        for (int i = 0; i < N; i++) a[i] = a[i] * scale % mod;
    }
}

// f <- f * g as a cyclic convolution of length 1 << L (acyclic whenever
// deg f + deg g < 1 << L). g is left in the transformed domain.
void convolve()
{
    NTT(f);
    NTT(g);
    const int N = 1 << L;
    for (int i = 0; i < N; i++) f[i] = f[i] * g[i] % mod;
    NTT(f, false);
}

}  // namespace

int main()
{
    if (std::scanf("%s", s + 1) != 1) s[1] = '\0';  // no token: treat s as empty
    n = static_cast<int>(std::strlen(s + 1));

    // k = last index of the longest non-decreasing prefix of s.
    int k = 1;
    while (k < n && s[k] <= s[k + 1]) k++;
    dp[k] = 1;
    int r = k;
    // Seed dp = 1 across the run of characters equal to s[k] that ends at k,
    // leaving k at the start of that run.
    while (k > 1 && s[k] == s[k - 1])
    {
        k--;
        dp[k] = 1;
    }

    init(2 * n);                    // every C(a, b) used below has a <= 2n
    while ((1 << L) <= 2 * n) L++;  // length 1 << L > 2n: product never wraps
    const int N = 1 << L;
    for (int i = 0; i < N; i++)
        wh[i] = (wh[i >> 1] >> 1) | ((i & 1) ? (1 << (L - 1)) : 0);

    while (k != 1)
    {
        // Step one position left, seed it, and push dp[i] into dp[i + 1]
        // wherever s[i + 1] > s[k].
        // NOTE(review): this sweep is O(n - k) per iteration, so the loop is
        // O(n^2) in the worst case; presumably the intended inputs keep it fast.
        k--;
        dp[k] = 1;
        for (int i = k; i < n; i++)
            if (s[i + 1] > s[k]) Add(dp[i + 1], dp[i]);

        const int mem = k;
        // Skip (and seed) the further characters equal to s[mem] to its left.
        while (k > 1 && s[k] == s[k - 1])
        {
            k--;
            dp[k] = 1;
        }
        // Extend r while the next character is strictly greater than s[k].
        while (r < n && s[r + 1] > s[k]) r++;

        if (mem != k)
        {
            // Convolving with C(i + m - 1, i), the coefficients of
            // (1 - x)^{-m}, replaces dp[mem..r] by its m-fold prefix sum
            // taken from index mem.
            const int m = mem - k;
            for (int i = mem; i <= r; i++) f[i] = dp[i];
            for (int i = 0; i <= n; i++) g[i] = C(i + m - 1, i);
            convolve();
            for (int i = mem; i <= r; i++) dp[i] = f[i];
            // f/g are written only inside this branch and only in their first
            // N entries, so clearing exactly that range restores all-zero.
            std::fill(f, f + N, 0LL);
            std::fill(g, g + N, 0LL);
        }
    }

    std::printf("%lld\n", dp[n]);
    return 0;
}