1236. 递增三元组

摘要
Title: 1236. 递增三元组
Tag: 前缀和、二分、双指针
Memory Limit: 64 MB
Time Limit: 1000 ms

Powered by:NEFU AB-IN

Link

1236. 递增三元组

  • 题意

    给定三个整数数组
    A=[A1,A2,…AN],
    B=[B1,B2,…BN],
    C=[C1,C2,…CN],
    请你统计有多少个三元组 (i,j,k) 满足:
    1≤i,j,k≤N
    Ai<Bj<Ck

  • 思路

    核心思路: 枚举B的每个元素,求出A中小于这个元素的个数 乘 B中大于这个元素的个数

    • 前缀和
      • 以A[i]为下标,1为值(即桶排),放入前缀和数组,并做前缀和。那么SA[B[j]1]SA[B[j] - 1]即为答案
      • 同理,SC[N1]SC[B[j]]SC[N - 1] - SC[B[j]]
      • 仅适合于元素的值比较小的情况,不然数组开不下
      • 下标都+1, 因为会有0,0这个位置需要空出来,所以所有数整体偏移,不影响大小关系
    • 二分
      • 求最后一个小于B[i]的数 等价于 求第一个大于等于B[i]的数 - 1
      • 求第一个大于B[i]的数 等价于 求最后一个小于等于B[i]的数 + 1
      • 所以二分数组时,通常可以在数组两边加上边界值 -INF 和 INF
      • 如果实在记不住,最推荐用库中自带的二分的轮子,会帮你自动填上上下界
      • 比如 在1 1 1中找最后一个小于2的数,可以直接bisect_left, 找第一个大于等于2的数,没有,则返回3
    • 双指针
      • 原理相同,不再赘述
  • 代码

    前缀和

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    N = int(1e5 + 10)
    sa, sc, cnt = [0] * N, [0] * N, [0] * N


    n = int(input())
    INF = int(1e9)
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    c = list(map(int, input().split()))

    a = [i + 1 for i in a] #下标都+1, 因为会有0,0这个位置需要空出来,所以所有数整体偏移
    b = [i + 1 for i in b]
    c = [i + 1 for i in c]

    for i in a:
    sa[i] += 1
    for i in range(1, N):
    sa[i] += sa[i - 1]
    for i in range(n):
    cnt[i] = sa[b[i] - 1]

    for i in c:
    sc[i] += 1
    for i in range(1, N):
    sc[i] += sc[i - 1]
    for i in range(n):
    cnt[i] *= (sc[N - 1] - sc[b[i]])

    print(sum(cnt))

    手写二分

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    n = int(input())
    INF = int(1e9)
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    c = list(map(int, input().split()))

    a.sort()
    b.sort()
    c.sort()
    a = [-INF, *a, INF]
    b = [-INF, *b, INF]
    c = [-INF, *c, INF]

    ans = 0
    for i in range(1, n + 1): # [1个元素, n个元素, 1个元素] 故枚举还是从[1, n + 1]
    l, r = 0, n + 1 #将边界都放上去
    while l < r:
    mid = l + r >> 1
    if a[mid] >= b[i]:
    r = mid
    else:
    l = mid + 1
    tmp1 = r - 1

    l, r = 0, n + 1
    while l < r:
    mid = l + r + 1 >> 1
    if c[mid] <= b[i]:
    l = mid
    else:
    r = mid - 1
    tmp2 = n - (r + 1) + 1
    ans += tmp1 * tmp2

    print(ans)

    轮子二分

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    '''
    Author: NEFU AB-IN
    Date: 2022-03-24 17:57:24
    FilePath: \ACM\Acwing\1236.1.py
    LastEditTime: 2022-03-24 18:54:55
    '''
    import bisect

    n = int(input())
    l1 = list(map(int, input().split()))
    l2 = list(map(int, input().split()))
    l3 = list(map(int, input().split()))
    l1.sort()
    l3.sort()
    ret = 0
    for i in range(n):
    # print()
    ret += (bisect.bisect_left(l1,
    l2[i])) * (n - bisect.bisect_right(l3, l2[i]))
    print(ret)

    双指针

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    '''
    Author: NEFU AB-IN
    Date: 2022-03-24 19:02:00
    FilePath: \ACM\Acwing\1236.3.py
    LastEditTime: 2022-03-24 19:20:24
    '''
    n = int(input())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    c = list(map(int, input().split()))

    a.sort()
    b.sort()
    c.sort()
    ja, jc, ans = 0, 0, 0
    for i in range(n):
    while ja < n and a[ja] < b[i]:
    ja += 1
    while jc < n and c[jc] <= b[i]:
    jc += 1
    ans += ja * (n - jc)

    print(ans)
使用搜索:谷歌必应百度