dp + 树状数组/线段树 优化
由于每次查询的是从1开始前缀的max,不涉及其他区间,因此可以用树状数组来维护区间最值。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010;
const LL INF = 1e18;
LL tr[N], a[N], s[N], f[N];
vector<LL> num;
int n;
int lowbit(int x){
return x & -x;
}
void modify(int x, LL c){
//这里需要是nums.size(),不能是n,因为加入-INF和0后可能导致nums.size() > n
for(int i = x; i <= num.size(); i += lowbit(i)){
tr[i] = max(tr[i], c);
}
}
LL query(int x){
LL ans = -INF;
for(int i = x; i; i -= lowbit(i)){
ans = max(ans, tr[i]);
}
return ans;
}
int find(LL x){
int l = 1, r = num.size() - 1;
while(l < r){
int mid = l + r >> 1;
if(num[mid] >= x) r = mid;
else l = mid + 1;
}
return l;
}
void solve(){
scanf("%d", &n);
num.push_back(-INF); //将下标调整到从1开始
num.push_back(0); //f[0]对应的初值情况
for(int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
s[i] = s[i - 1] + a[i];
num.push_back(s[i]);
}
sort(num.begin(), num.end());
num.erase(unique(num.begin(), num.end()), num.end());
f[0] = 0;
for(int i = 0; i < N; i++) tr[i] = -INF; //注意需要赋初值为-INF,因为modify的c可能为负数
modify(find(0), f[0] - 0);
for(int i = 1; i <= n; i++){
f[i] = max(f[i - 1], 1LL* query(find(s[i])) + i);
modify(find(s[i]), f[i] - i);
}
printf("%lld\n", f[n]);
}
int main(){
int T = 1;
// scanf("%d", &T);
while(T -- ){
solve();
}
return 0;
}