0%

[QOJ4215] Easiest Sum

要讲的题喵。

似乎和 114514 年前的模拟赛题撞了但我并没有发现

题目链接

题意

对于一个序列,定义其权值为最大非空子段和。

给定一个长度为 的序列

现在可以进行 次操作,每一次可以给一个位置的数字减去

求操作之后整个序列权值的最小值,输出所有 的答案的和。

题解

的前缀和序列。

表示,最少要用几次操作,使得整个序列权值 。那么首先考虑 怎么求。

表示前 个位置一共操作了多少次,那么首先有 ,以及

其次,对于任何一个区间 都有

那么考虑直接贪心,我们令

看起来不太好求,那么我们令 ,得到

那么这样就可以在 时间内求出 了。

我们还可以看出一些性质。

首先,,这个比较显然。

考虑如果有一个权值 的序列,那给每一个位置都减掉 ,新序列显然权值

然后还可以发现 是凸的。(这个在后面的推导中可以直接看出来)

那么怎么求答案呢,再来考虑这么一个问题。

如果我们强制限定需要走 边(那么首先有 的贡献),此时 最大值是多少。

不难发现,这相当于需要找到序列 恰好 个不相交子区间,使得子区间的和最大。

这个问题比较容易通过数据结构模拟费用流的方法解决。

这个东西还是有原题的,具体做法是每一次用线段树找到最大子段和,然后直接给这一段全部取负。

但是可以发现这个只能找出来 段,并不是恰好 段。

但是我们只需要给这个过程稍微修改一下。

对于前面的若干次,我们找到的最大子段和一定 ,这个时候直接加上去就好了。

对于后面的次数,先把前面的一个一个拆开,再把负数从大到小加进去就好了。

对于一个固定的 ,设其恰好走 边的代价为

求出所有的 之后,我们就可以直接知道

我们直接就可以使用李超树支持查询,但其实没有必要。

我们求出这 条直线之后,直接求一个直线的下凸壳,然后直接在凸壳上算一算答案就好了。

时间复杂度是

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f3f3f3f3fll
#define debug(x) cerr<<#x<<"="<<x<<endl
using namespace std;
using ll=long long;
using ld=long double;
using pli=pair<ll,int>;
using pi=pair<int,int>;
template<typename A>
using vc=vector<A>;
template<typename A,const int N>
using aya=array<A,N>;
inline int read()
{
int s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
inline ll lread()
{
ll s=0,w=1;char ch;
while((ch=getchar())>'9'||ch<'0') if(ch=='-') w=-1;
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
const int mod=998244353;
const int inv2=(mod+1)/2;
struct node
{
int l,r;ll v;
node(int l=0,int r=0,ll v=0):l(l),r(r),v(v){}
node operator + (node b)
{
node ans;
ans.v=v+b.v;
ans.l=l,ans.r=b.r;
return ans;
}
bool operator < (node b) const { return v<b.v;}
};
struct Tree
{
node v0,v1,ans;
node v2,v3,wtf;
node all,rev;
inline void swap()
{
::swap(all,rev);
::swap(v0,v2);
::swap(v1,v3);
::swap(ans,wtf);
}
inline void init(int w,int v)
{
v0=v1=ans=all=node(w,w,v);
v2=v3=wtf=rev=node(w,w,-v);
}
}num[400005];
Tree operator + (Tree l,Tree r)
{
Tree ans;
ans.all=l.all+r.all;
ans.rev=l.rev+r.rev;
ans.v0=max(l.v0,l.all+r.v0);
ans.v1=max(r.v1,l.v1+r.all);
ans.v2=max(l.v2,l.rev+r.v2);
ans.v3=max(r.v3,l.v3+r.rev);
ans.ans=max(max(l.ans,r.ans),l.v1+r.v0);
ans.wtf=max(max(l.wtf,r.wtf),l.v3+r.v2);
return ans;
}
int sta[100005],top;
ll fr[100005];
ll t[100005];
bool tag[400005];
int a[100005];
ll ans,all;
int n;
inline void T(int p)
{
tag[p]^=1;
num[p].swap();
}
inline void push_down(int p)
{
if(tag[p]) T(p*2),T(p*2|1),tag[p]=0;
}
void build(int p,int pl,int pr)
{
if(pl==pr)
{
num[p].init(pl,a[pl]);
return ;
}
int mid=(pl+pr)>>1;
build(p*2,pl,mid);
build(p*2|1,mid+1,pr);
num[p]=num[p*2]+num[p*2|1];
}
void change(int p,int pl,int pr,int l,int r)
{
if(l<=pl&&pr<=r){ T(p);return ;}
int mid=(pl+pr)>>1;push_down(p);
if(l<=mid) change(p*2,pl,mid,l,r);
if(mid<r) change(p*2|1,mid+1,pr,l,r);
num[p]=num[p*2]+num[p*2|1];
}
inline ll get(int w,ll x)
{
return t[w]-w*x;
}
inline ll SUM(ll R,ll d,ll c)
{
ll L=R-(c-1)*d;
L%=mod,R%=mod;
return (L+R)*(c%mod)%mod*inv2%mod;
}
int main()
{
n=read();
priority_queue<int>que;
for(int i=1;i<=n;i++)
{
a[i]=read();all+=a[i];
if(a[i]<=0) que.push(-a[i]);
}

build(1,1,n);
int now=0;ll sum=0;
while(num[1].ans.v>0)
{
t[++now]=sum=sum+num[1].ans.v;
change(1,1,n,num[1].ans.l,num[1].ans.r);
}
while(now<n) t[++now]=sum;
while(que.size()) t[now--]=all,all+=que.top(),que.pop();


for(int i=n;i>=0;i--)
{
while(top&&get(sta[top],fr[top])<=get(i,fr[top])) top--;

if(!top) sta[++top]=i,fr[top]=-2e13;
else
{
ll L=fr[top],R=2e13;
while(L<R)
{
ll mid=(L+R)>>1;
if(get(sta[top],mid)<=get(i,mid)) R=mid;
else L=mid+1;
}
sta[++top]=i,fr[top]=L;
}
}

ll R=lread();fr[top+1]=2e13+1;
for(int i=1;i<top;i++)
{
ll L=get(sta[i+1],fr[i+1])+1;assert(L>=1);
if(L>R) continue;

ll S=get(sta[i],fr[i+1]-1);
if(R<S) (ans+=(R-L+1)%mod*(fr[i+1]%mod))%=mod;
else
{
(ans+=(S-L)%mod*(fr[i+1]%mod))%=mod;
ll len=R-S+1;
(ans+=(fr[i+1]-1-(len-1)/sta[i])%mod*(len%sta[i]))%=mod;
(ans+=SUM(fr[i+1]-1,1,len/sta[i])*sta[i])%=mod;
}
R=L-1;
}
ans=(ans%mod+mod)%mod;
printf("%lld\n",ans);
return 0;
}