1236. 递增三元组
摘要
Title: 1236. 递增三元组
Tag: 前缀和、二分、双指针
Memory Limit: 64 MB
Time Limit: 1000 ms
Powered by:NEFU AB-IN
1236. 递增三元组
-
题意
给定三个整数数组
A=[A1,A2,…AN],
B=[B1,B2,…BN],
C=[C1,C2,…CN],
请你统计有多少个三元组 (i,j,k) 满足:
1≤i,j,k≤N
Ai<Bj<Ck -
思路
核心思路: 枚举B的每个元素,求出A中小于这个元素的个数 乘 B中大于这个元素的个数
- 前缀和
- 以A[i]为下标,1为值(即桶排),放入前缀和数组,并做前缀和。那么即为答案
- 同理,
- 仅适合于元素的值比较小的情况,不然数组开不下
- 下标都+1, 因为会有0,0这个位置需要空出来,所以所有数整体偏移,不影响大小关系
- 二分
- 求最后一个小于B[i]的数 等价于 求第一个大于等于B[i]的数 - 1
- 求第一个大于B[i]的数 等价于 求最后一个小于等于B[i]的数 + 1
- 所以二分数组时,通常可以在数组两边加上边界值 -INF 和 INF
- 如果实在记不住,最推荐用库中自带的二分的轮子,会帮你自动填上上下界
- 比如 在
1 1 1
中找最后一个小于2的数,可以直接bisect_left, 找第一个大于等于2的数,没有,则返回3
- 双指针
- 原理相同,不再赘述
- 前缀和
-
代码
前缀和
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29N = int(1e5 + 10)
sa, sc, cnt = [0] * N, [0] * N, [0] * N
n = int(input())
INF = int(1e9)
a = list(map(int, input().split()))
b = list(map(int, input().split()))
c = list(map(int, input().split()))
a = [i + 1 for i in a] #下标都+1, 因为会有0,0这个位置需要空出来,所以所有数整体偏移
b = [i + 1 for i in b]
c = [i + 1 for i in c]
for i in a:
sa[i] += 1
for i in range(1, N):
sa[i] += sa[i - 1]
for i in range(n):
cnt[i] = sa[b[i] - 1]
for i in c:
sc[i] += 1
for i in range(1, N):
sc[i] += sc[i - 1]
for i in range(n):
cnt[i] *= (sc[N - 1] - sc[b[i]])
print(sum(cnt))
手写二分
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35n = int(input())
INF = int(1e9)
a = list(map(int, input().split()))
b = list(map(int, input().split()))
c = list(map(int, input().split()))
a.sort()
b.sort()
c.sort()
a = [-INF, *a, INF]
b = [-INF, *b, INF]
c = [-INF, *c, INF]
ans = 0
for i in range(1, n + 1): # [1个元素, n个元素, 1个元素] 故枚举还是从[1, n + 1]
l, r = 0, n + 1 #将边界都放上去
while l < r:
mid = l + r >> 1
if a[mid] >= b[i]:
r = mid
else:
l = mid + 1
tmp1 = r - 1
l, r = 0, n + 1
while l < r:
mid = l + r + 1 >> 1
if c[mid] <= b[i]:
l = mid
else:
r = mid - 1
tmp2 = n - (r + 1) + 1
ans += tmp1 * tmp2
print(ans)
轮子二分
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20'''
Author: NEFU AB-IN
Date: 2022-03-24 17:57:24
FilePath: \ACM\Acwing\1236.1.py
LastEditTime: 2022-03-24 18:54:55
'''
import bisect
n = int(input())
l1 = list(map(int, input().split()))
l2 = list(map(int, input().split()))
l3 = list(map(int, input().split()))
l1.sort()
l3.sort()
ret = 0
for i in range(n):
# print()
ret += (bisect.bisect_left(l1,
l2[i])) * (n - bisect.bisect_right(l3, l2[i]))
print(ret)双指针
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23'''
Author: NEFU AB-IN
Date: 2022-03-24 19:02:00
FilePath: \ACM\Acwing\1236.3.py
LastEditTime: 2022-03-24 19:20:24
'''
n = int(input())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
c = list(map(int, input().split()))
a.sort()
b.sort()
c.sort()
ja, jc, ans = 0, 0, 0
for i in range(n):
while ja < n and a[ja] < b[i]:
ja += 1
while jc < n and c[jc] <= b[i]:
jc += 1
ans += ja * (n - jc)
print(ans)