3209. 子数组按位与值为 K 的数目

摘要
Title: 3209. 子数组按位与值为 K 的数目
Categories: dp、st、二分

Powered by:NEFU AB-IN

Link

3209. 子数组按位与值为 K 的数目

题意

给你一个整数数组 nums 和一个整数 k ,请你返回 nums 中有多少个
子数组
满足:子数组中所有元素按位 AND 的结果为 k 。

思路

  1. st+二分

    https://leetcode.cn/problems/number-of-subarrays-with-and-value-of-k/solutions/2833382/stbiao-er-fen-by-time-v5-4qtm/
    细节在于:

    1. 由于一直and操作是非递减的,所以取个负号,这样就是非递增了,能配合bisect_left函数
    2. bisect_left 函数,可以直接这么用 l = bisect_left(range(i, n), -k, key=lambda r: -st.query(i, r)),直接在 range(i, n) 上进行二分查找,通过 key 参数动态计算按位与结果,会快很多(不过这是 3.10 引进的特性),当然我也会自己实现
      1. range(i, n) 生成从 i 到 n-1 的索引序列。
      2. -k 是要插入的元素,它是 k 的相反数。
      3. key=lambda r: -st.query(i, r) 是自定义的比较函数,用于对 range(i, n) 中的每个元素 r 进行比较。它将 st.query(i, r) 的结果取反,实际上是在对 -st.query(i, r) 进行比较。
        相当于
    1
    2
    3
    4
    and_results = [-st.query(i, r) for r in range(i, n)]
    l = bisect_left(and_results, -k)
    r = bisect_right(and_results, -k)
    ans += r - l
  2. 滚动数组 + 哈希 + dp

    1. 我们使用一个二维数组 dp 来记录以每个位置结尾的所有可能的按位 AND 结果及其出现次数。定义 dp[i][j] 为以 nums[i] 结尾的子数组中,按位 AND 结果为 j 的子数组数量。
    2. 初始化,dp[0][nums[0]] = 1:表示第一个元素单独形成一个子数组,且按位 AND 结果为 nums[0]。
    3. 对于每个元素 nums[i],我们需要遍历之前所有的状态来更新当前状态:dp[i][nums[i]&key]+=dp[i1][key]dp[i][nums[i]\&key]+=dp[i−1][key]
    4. 由于每一层的状态仅依赖于前一层的状态,因此我们可以使用滚动数组来优化空间复杂度。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
'''
Author: NEFU AB-IN
Date: 2024-07-06 21:52:11
FilePath: \LeetCode\CP134_2\d\d.py
LastEditTime: 2024-07-08 16:31:44
'''
# 3.8.19 import
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, fabs, floor, gcd, log, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Dict, List, Tuple, TypeVar, Union

# Constants
TYPE = TypeVar('TYPE')
N = int(2e5 + 10) # If using AR, modify accordingly
M = int(20) # If using AR, modify accordingly
INF = int(2e9)
OFFSET = int(100)

# Set recursion limit
setrecursionlimit(INF)

class Arr:
array = staticmethod(lambda x=0, size=N: [x] * size)
array2d = staticmethod(lambda x=0, rows=N, cols=M: [Arr.array(x, cols) for _ in range(rows)])
graph = staticmethod(lambda size=N: [[] for _ in range(size)])
@staticmethod
def to_1_indexed(data: Union[List, str, List[List]]):
"""Adds a zero prefix to the data and returns the modified data and its length."""
if isinstance(data, list):
if all(isinstance(item, list) for item in data): # Check if it's a 2D array
new_data = [[0] * (len(data[0]) + 1)] + [[0] + row for row in data]
return new_data, len(new_data) - 1, len(new_data[0]) - 1
else:
new_data = [0] + data
return new_data, len(new_data) - 1
elif isinstance(data, str):
new_data = '0' + data
return new_data, len(new_data) - 1
else:
raise TypeError("Input must be a list, a 2D list, or a string")

class Str:
letter_to_num = staticmethod(lambda x: ord(x.upper()) - 65) # A -> 0
num_to_letter = staticmethod(lambda x: ascii_uppercase[x]) # 0 -> A
removeprefix = staticmethod(lambda s, prefix: s[len(prefix):] if s.startswith(prefix) else s)
removesuffix = staticmethod(lambda s, suffix: s[:-len(suffix)] if s.endswith(suffix) else s)

class Math:
max = staticmethod(lambda a, b: a if a > b else b)
min = staticmethod(lambda a, b: a if a < b else b)

class IO:
input = staticmethod(lambda: stdin.readline().rstrip("\r\n"))
read = staticmethod(lambda: map(int, IO.input().split()))
read_list = staticmethod(lambda: list(IO.read()))


class Std:
@staticmethod
def find(container: Union[List[TYPE], str], value: TYPE):
"""Returns the index of value in container or -1 if value is not found."""
if isinstance(container, list):
try:
return container.index(value)
except ValueError:
return -1
elif isinstance(container, str):
return container.find(value)

@staticmethod
def pairwise(iterable):
"""Return successive overlapping pairs taken from the input iterable."""
a, b = tee(iterable)
next(b, None)
return zip(a, b)

@staticmethod
def bisect_left(a, x, key=lambda y: y):
"""The insertion point is the first position where the element is not less than x."""
left, right = 0, len(a)
while left < right:
mid = (left + right) // 2
if key(a[mid]) < x:
left = mid + 1
else:
right = mid
return left

@staticmethod
def bisect_right(a, x, key=lambda y: y):
"""The insertion point is the first position where the element is greater than x."""
left, right = 0, len(a)
while left < right:
mid = (left + right) // 2
if key(a[mid]) <= x:
left = mid + 1
else:
right = mid
return left

class SparseTable:
def __init__(self, data: list, func=lambda x, y: x | y):
"""Initialize the Sparse Table with the given data and function."""
self.func = func
self.st = [list(data)]
i, n = 1, len(self.st[0])
while 2 * i <= n:
pre = self.st[-1]
self.st.append([func(pre[j], pre[j + i]) for j in range(n - 2 * i + 1)])
i <<= 1

def query(self, begin: int, end: int):
"""Query the combined result over the interval [begin, end]."""
lg = (end - begin + 1).bit_length() - 1
return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])



# ————————————————————— Division line ——————————————————————
class Solution:
def countSubarrays(self, nums: List[int], k: int) -> int:
st = Std.SparseTable(nums, func=lambda x, y: x & y)

ans = 0
n = len(nums)
for i in range(n):
l = Std.bisect_left(range(i, n), -k, key=lambda r: -st.query(i, r))
r = Std.bisect_right(range(i, n), -k, key=lambda r: -st.query(i, r))
ans += r - l

return ans

class Solution:
def countSubarrays(self, nums: List[int], k: int) -> int:
nums, n = Arr.to_1_indexed(nums)
dp = Counter()
res = 0

for i in range(1, n + 1):
cur_dp = Counter()
cur_dp[nums[i]] += 1

for num, val in dp.items():
# 转移方程
cur_dp[nums[i] & num] += val

res += cur_dp[k]
dp = cur_dp
return res
使用搜索:谷歌必应百度