0%

[ZJOI2013]K大数查询

有N个位置,M个操作。
1 a b c形式,表示在第a个位置到第b个位置,每个位置加入一个数c。
2 a b c形式,表示询问从第a个位置到第b个位置,第C大的数是多少。

区间的第k大值有一种二分的做法。
二分答案mid,计算出区间内>mid的值有多少个。
若数量小于C,则ans<mid,否则ans>=mid。

考虑离线做法,将所有询问一起二分。
每次二分扫描所有操作,维护一个线段树记录在一段区间内有多少个>mid的数。
对于一个询问,直接查询在区间内的>a的数量即可知道ans<mid或>=mid。
但是操作数量也很多,如果每次二分扫描所有操作也会超时。

操作也可以分治,根据c与mid的大小关系进行分治。将操作序列整体二分。
把c<=mid的修改放到左边,反之放在右边。
将ans<mid的的询问放到左边,反之放到右边。
实现时,若用树状数组,清空是把加上去的减回来;若用线段树,打标记清空。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
struct node {
int op,L,R,idx;
LL val;
} qu[50005],q1[50005],q2[50005];
int ans[50005],cnt;
LL c1[50005],c2[50005];
int n,m;
inline LL read() {
LL sum=0;char ch=0;int f=1;
while ((ch>'9' || ch<'0')&&(ch!='-')) ch=getchar();
if (ch=='-') f=-1,ch=getchar();
while (ch>='0' && ch<='9') sum=sum*10+ch-'0',ch=getchar();
return sum*f;
}
inline int lowbit(int x) {return x&(-x);}
inline void update(int x,LL val) {
LL tmp=x*val;
for (;x<=n;x+=lowbit(x)) c1[x]+=val,c2[x]+=tmp;
}
inline LL getsum(int x) {
LL a=0,b=0,tmp=x+1;
for (;x>0;x-=lowbit(x)) a+=c1[x],b+=c2[x];
return a*tmp-b;
}
inline void Update(int l,int r,LL val) {
update(l,val);
update(r+1,-val);
}
inline LL query(int l,int r) {
return getsum(r)-getsum(l-1);
}
void solve(int l,int r,int s,int t) {
if (s>t) return;
if (l==r) {
for (int i=s;i<=t;i++)
if (qu[i].op==2) ans[qu[i].idx]=l;
return;
}
int mid=(l+r)>>1;
int s1=0,s2=0;
for (int i=s;i<=t;i++) {
if (qu[i].op==1) {
if (qu[i].val>mid) {
Update(qu[i].L,qu[i].R,1);
q2[++s2]=qu[i];
}
else q1[++s1]=qu[i];
}
else {
LL tmp=query(qu[i].L,qu[i].R);
if (tmp<qu[i].val) {
qu[i].val-=tmp;
q1[++s1]=qu[i];
}
else q2[++s2]=qu[i];
}
}
for (int i=s;i<=t;i++) {
if (qu[i].op==1 && qu[i].val>mid) Update(qu[i].L,qu[i].R,-1);
}
for (int i=1;i<=s1;i++) qu[s+i-1]=q1[i];
for (int i=1;i<=s2;i++) qu[t-s2+i]=q2[i];
solve(l,mid,s,s+s1-1);
solve(mid+1,r,s+s1,t);
}
int main() {
n=read();m=read();int mx=0,mn=1e9;
for (int i=1;i<=m;i++) {
qu[i].op=read();qu[i].L=read();
qu[i].R=read();qu[i].val=read();
if (qu[i].op==2) qu[i].idx=++cnt;
else mx=max(mx,(int)qu[i].val),mn=min(mn,(int)qu[i].val);
}
solve(mn,mx,1,m);
for (int i=1;i<=cnt;i++) printf("%d\n",ans[i]);
return 0;
}