kth value in a subarray

UNDER CONSTRUCTION.

本文总结经典的区间第k小值数据结构题。 给定一个长为n的数组。有m个询问:求区间[l,r)中第k小的元素。

一些方法支持扩展问题:有m个操作,或者修改某个位置上的元素,或者询问区间[l,r)中第k小的元素。

归并树(merge sort tree)

用O(n*log(n))时间构建线段树,每个节点存储对应区间的有序数组。 对于一个询问,二分搜索答案ans转化为计数问题:区间[l,r)内小于ans的元素个数是否大于等于k。 对于这个计数问题,把区间[l,r)解构为不超过log(n)个线段树节点。对于每个节点,二分查找这个节点存储的有序数组里小于ans的元素数。

  • static: O(n*log(n)+m*log(n)^3), space complexity: O(n*log(n)), not recommended

若要支持修改元素,把每个节点存储的有序数组改成一棵binary search tree,这种嵌套树形解构俗称树套树。

描述值域的线段树

用O(n*log(n))时间构建一棵线段树,每个节点描述一个值域区间,存储出现的元素的位置序列。 对于静态问题,位置序列可以是一个有序数组。若要支持修改元素,位置序列得是线段树或binary search tree。

询问使用的区间查询支持区间减法,因此外层的线段树也可改成Fenwick tree,减少一半节点。

  • static: O(n*log(n)+m*log(n)^2), space complexity: O(n*log(n)), not recommended

划分树(range tree with functional cascading)

这是描述值域的线段树的一种优化。用O(n*log(n))时间构建一个描述值域的线段树,每个节点存储值域区间里按顺序出现的元素数组,和一个辅助数组表示分到左孩子的元素个数。 对于一个询问,可以O(1)知道[l,r)中落在左孩子值域的元素个数,判断要在左孩子或在右孩子找答案。

  • static: O((n+m)*log(n)), space complexity: O(n*log(n))

有多组修改和询问。每个询问会受到时间序之前的修改的影响,询问目标可以二分搜索。 这类算法将二分答案应用到多组修改和询问上。

在二分答案后,单点修改的影响为commutative monoid,区间询问的目标也是一个commutative monoid。

  • static: O((n+m)*log(n)^2), space complexity: O(n+m)
  • dynamic: O((n+m)*log(n+m)*log(n)), space complexity: O(n+m)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <algorithm>
#include <cstdio>
#include <utility>
using namespace std;

#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define REP(i, n) for (int i = 0; i < (n); i++)

const int N = 200000, M = 200000;

namespace {
int ri() {
int m = 0, s = 0; unsigned c;
while ((c = getchar())-'0' >= 10u) m = c == '-';
for (; c-'0' < 10u; c = getchar()) s = s*10+c-'0';
return m ? -s : s;
}

pair<int, int> a[N];
int ans[M], fenwick[N], n;
struct Query { int id, l, r, k; } q[M], qq[M];

void add(int i, int d) {
for (; i < n; i |= i+1)
fenwick[i] += d;
}

int get_sum(int i) {
int sum = 0;
for (; i; i &= i-1)
sum += fenwick[i-1];
return sum;
}

void conquer(int ml, int mh, int l, int h) {
if (ml == mh-1) {
FOR(i, l, h)
ans[q[i].id] = a[ml].first;
return;
}
int mm = ml+mh >> 1, nl = 0, nh = h-l;
FOR(i, ml, mm)
add(a[i].second, 1);
FOR(i, l, h) {
int t = get_sum(q[i].r)-get_sum(q[i].l);
if (q[i].k <= t)
qq[nl++] = q[i];
else
qq[--nh] = q[i], qq[nh].k -= t;
}
FOR(i, ml, mm)
add(a[i].second, -1);
copy_n(qq, nl, q+l);
copy(qq+nh, qq+h-l, q+l+nl);
if (nl) conquer(ml, mm, l, l+nl);
if (l+nl < h) conquer(mm, mh, l+nl, h);
}
}

int main() {
n = ri();
int m = ri();
REP(i, n)
a[i] = {ri(), i};
REP(i, m)
q[i] = {i, ri()-1, ri(), ri()};
sort(a, a+n);
conquer(0, n, 0, m);
REP(i, m)
printf("%d\n", ans[i]);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// parallel binary search, dynamic
#include <algorithm>
#include <cstdio>
#include <utility>
using namespace std;

#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define REP(i, n) for (int i = 0; i < (n); i++)

const int N = 100000, M = 100000;

namespace {
int ri() {
int m = 0, s = 0; unsigned c;
while ((c = getchar_unlocked())-'0' >= 10u) m = c == '-';
for (; c-'0' < 10u; c = getchar_unlocked()) s = s*10+c-'0';
return m ? -s : s;
}

int a[N+M], ans[M], fenwick[N], n;
struct Op { int id, l, r, k; } q[N+2*M], qq[N+2*M];

void add(int i, int d) {
for (; i < n; i |= i+1)
fenwick[i] += d;
}

int get_sum(int i) {
int sum = 0;
for (; i; i &= i-1)
sum += fenwick[i-1];
return sum;
}

void conquer(int vl, int vh, int l, int h) {
if (vl == vh-1) {
FOR(i, l, h)
if (q[i].id >= 0)
ans[q[i].id] = a[vl];
return;
}
int vm = vl+vh >> 1, nl = 0, nh = h-l;
FOR(i, l, h) {
auto x = q[i];
if (x.id < 0) {
if (x.k < vm)
qq[nl++] = x, add(x.l, x.r);
else
qq[--nh] = x;
} else {
int t = get_sum(x.r)-get_sum(x.l);
if (x.k <= t)
qq[nl++] = x;
else
x.k -= t, qq[--nh] = x;
}
}
REP(i, nl)
if (qq[i].id < 0)
add(qq[i].l, -qq[i].r);
copy_n(qq, nl, q+l);
reverse_copy(qq+nh, qq+h-l, q+l+nl);
if (nl) conquer(vl, vm, l, l+nl);
if (l+nl < h) conquer(vm, vh, l+nl, h);
}
}

int main() {
n = ri();
int m = ri(), nv = n, nop = n, nq = 0;
REP(i, n) {
a[i] = ans[i] = ri();
q[i] = {-1, i, 1, a[i]};
}
REP(i, m) {
char c;
while (c = getchar_unlocked(), c != 'C' && c != 'Q');
if (c == 'C') {
int j = ri()-1;
q[nop++] = {-1, j, -1, a[j]};
a[j] = a[nv++] = ri();
q[nop++] = {-1, j, 1, a[j]};
} else
q[nop++] = {nq++, ri()-1, ri(), ri()};
}
copy_n(ans, n, a);
sort(a, a+nv);
nv = unique(a, a+nv) - a;
REP(i, nop)
if (q[i].id < 0)
q[i].k = lower_bound(a, a+nv, q[i].k) - a;
conquer(0, nv, 0, nop);
REP(i, nq)
printf("%d\n", ans[i]);
}

可持久化线段树(persistent segment tree)

用O(n*log(n))时间构建n+1棵描述值域的线段树。每棵线段树表示一个原数组的一个前缀(共n+1个)。在每棵线段树中,每个节点存储一个值域区间里的元素数。 相邻两棵线段树描述的区间只相差一个元素,它们可以共用大部分节点,只有ceil(log(n))个节点有差异。

  • static: O((n+m)*log(n)), space complexity: O(n*log(n))
  • dynamic: O((n+m)*log(n)^2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// persistent segment tree
#include <algorithm>
#include <cstdio>
using namespace std;

#define REP(i, n) for (int i = 0; i < (n); i++)

int ri() {
int m = 0, s = 0; unsigned c;
while ((c = getchar())-'0' >= 10u) m = c == '-';
for (; c-'0' < 10u; c = getchar()) s = s*10+c-'0';
return m ? -s : s;
}

const int N = 200000, M = 200000, LOG2N = 32-__builtin_clz(N-1);
int a[N], b[N], roots[N+1], allo;
struct Segment { int ch[2], cnt; } seg[N*2+M*LOG2N];

void build(int &t, int l, int r) {
t = ++allo;
if (l < r-1) {
int m = l+r >> 1;
build(seg[t].ch[0], l, m);
build(seg[t].ch[1], m, r);
}
}

void add(int *t, int u, int l, int r, int v) {
while (l < r-1) {
*t = ++allo;
seg[*t].cnt = seg[u].cnt+1;
int m = l+r >> 1, d = v >= m;
if (d) l = m;
else r = m;
seg[*t].ch[d^1] = seg[u].ch[d^1];
t = &seg[*t].ch[d];
u = seg[u].ch[d];
}
*t = ++allo;
seg[*t].cnt = seg[u].cnt+1;
}

int kth(int t, int u, int l, int r, int k) {
while (l < r-1) {
int m = l+r >> 1, lcnt = seg[seg[t].ch[0]].cnt-seg[seg[u].ch[0]].cnt, d = k >= lcnt;
if (d) l = m, k -= lcnt;
else r = m;
t = seg[t].ch[d];
u = seg[u].ch[d];
}
return l;
}

int main() {
int n = ri(), m = ri();
REP(i, n) {
a[i] = ri();
b[i] = a[i];
}
sort(b, b+n);
int nn = unique(b, b+n) - b;
build(roots[0], 0, nn);
REP(i, n) {
int v = lower_bound(b, b+nn, a[i]) - b;
add(&roots[i+1], roots[i], 0, nn, v);
}
while (m--) {
int l = ri(), r = ri(), k = ri();
printf("%d\n", b[kth(roots[r], roots[l-1], 0, nn, k-1)]);
}
}

莫涛算法(Mo's algorithm)

  • static: O(n*log(n)+m*sqrt(n)+m*log(m))
  • dynamic (binary search on the value, 二分答案): O(n*log(n)+m*sqrt(n)*log(n)*log(n+m))
  • dynamic (区间 [l,r] 内所有的 x 变成 y, P4119):

静态情形:维护两个频度数组,一个表示元素x的频度,另一个表示元素区间(如[i,i+block_size))的频度。区间长度加减一时,O(1)修改频度。

1
2
c1[a[i]] += d;
c2[block[a[i]]] += d;

询问时O(sqrt(n))扫描频度数组得到答案。

1
2
3
4
5
6
7
int x = 0, k = qs[i].k;
while (c2[x] < k) k -= c2[x++];
for (int j = x*block_size; ; j++)
if ((k -= c1[j]) <= 0) {
qs[i].ans = j;
break;
}

要点在于不要用有序数据结构维护区间内的元素,会不必要增大修改的时间复杂度。

要支持修改元素,可在每个分块里里维护一个有序数组。 修改时重建有序数组。 询问时二分答案ans。在包含的分块里二分搜索小于ans的元素数。在分块外线性遍历至多2*block_size个元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Mo's algorithm
#include <algorithm>
#include <cmath>
#include <cstdio>
using namespace std;

#define REP(i, n) for (int i = 0; i < (n); i++)

int ri() {
int m = 0, s = 0; unsigned c;
while ((c = getchar())-'0' >= 10u) m = c == '-';
for (; c-'0' < 10u; c = getchar()) s = s*10+c-'0';
return m ? -s : s;
}

const int N = 200000, M = 200000;
int a[N], b[N], block[N], c1[N], c2[N], block_size;
struct Query {
int l, r, k, id, ans;
bool operator<(const Query &o) const {
int i = block[l], j = block[o.l];
if (i != j) return i < j;
return i & 1 ? r < o.r : r > o.r;
}
} qs[M];

static void add(int i, int d) {
c1[a[i]] += d;
c2[block[a[i]]] += d;
}

int main() {
int n = ri(), m = ri();
REP(i, n) {
a[i] = ri();
b[i] = a[i];
}
sort(b, b+n);
int nn = unique(b, b+n) - b;
REP(i, n)
a[i] = lower_bound(b, b+nn, a[i]) - b;
REP(i, m) {
qs[i].l = ri()-1;
qs[i].r = ri();
qs[i].k = ri();
qs[i].id = i;
}
block_size = sqrt(n);
REP(i, n)
block[i] = i/block_size;
sort(qs, qs+m);

int l = 0, r = 0;
REP(i, m) {
while (qs[i].l < l) add(--l, 1);
while (r < qs[i].r) add(r++, 1);
while (l < qs[i].l) add(l++, -1);
while (qs[i].r < r) add(--r, -1);
int x = 0, k = qs[i].k;
while (c2[x] < k) k -= c2[x++];
for (int j = x*block_size; ; j++)
if ((k -= c1[j]) <= 0) {
qs[i].ans = j;
break;
}
}
REP(i, m)
a[qs[i].id] = b[qs[i].ans];
REP(i, m)
printf("%d\n", a[i]);
}