1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
'''
Author: NEFU AB-IN
Date: 2024-07-06 21:52:11
FilePath: \LeetCode\CP134_2\d\d.py
LastEditTime: 2024-07-08 16:31:44
'''
from collections import Counter, defaultdict, deque from datetime import datetime, timedelta from functools import lru_cache from heapq import heapify, heappop, heappush, nlargest, nsmallest from itertools import combinations, compress, permutations, starmap, tee from math import ceil, fabs, floor, gcd, log, sqrt from string import ascii_lowercase, ascii_uppercase from sys import exit, setrecursionlimit, stdin from typing import Any, Dict, List, Tuple, TypeVar, Union
# Generic type variable available for annotations.
TYPE = TypeVar('TYPE')

# Module-wide integer constants.
N = 2 * 10**5 + 10   # 200010
M = 20
INF = 2 * 10**9      # 2000000000
OFFSET = 100

# Raise the interpreter's recursion limit to INF.
setrecursionlimit(INF)
class Arr:
    """Factories for pre-sized lists and a converter to 1-indexed data."""

    @staticmethod
    def array(x=0, size=N):
        """Return a list holding ``size`` references to ``x``."""
        return [x] * size

    @staticmethod
    def array2d(x=0, rows=N, cols=M):
        """Return ``rows`` separate row lists, each built by ``Arr.array(x, cols)``."""
        return [Arr.array(x, cols) for _ in range(rows)]

    @staticmethod
    def graph(size=N):
        """Return ``size`` distinct empty lists."""
        return [[] for _ in range(size)]

    @staticmethod
    def to_1_indexed(data: Union[List, str, List[List]]):
        """Prefix ``data`` with a zero so real elements start at index 1.

        Args:
            data: A flat list, a non-empty list of lists, or a string.

        Returns:
            * string -> ``('0' + data, len(data))``
            * list of lists -> ``(grid, rows, cols)`` where ``grid`` has an
              extra leading zero row and a leading ``0`` in every row;
              ``cols`` is taken from the first input row.
            * any other list (including the empty list) ->
              ``([0] + data, len(data))``

        Raises:
            TypeError: If ``data`` is neither a list nor a string.
        """
        if isinstance(data, str):
            padded = '0' + data
            return padded, len(padded) - 1

        if not isinstance(data, list):
            raise TypeError("Input must be a list, a 2D list, or a string")

        # all() is vacuously True for an empty list, so require at least one
        # row before taking the 2-D branch; otherwise data[0] would raise
        # IndexError for [].
        if data and all(isinstance(item, list) for item in data):
            grid = [[0] * (len(data[0]) + 1)] + [[0] + row for row in data]
            return grid, len(grid) - 1, len(grid[0]) - 1

        padded = [0] + data
        return padded, len(padded) - 1
class Str:
    """String helpers: letter/index conversion and prefix/suffix removal."""

    @staticmethod
    def letter_to_num(x):
        """Return ``ord(x.upper()) - 65``, i.e. 0 for 'a'/'A', 25 for 'z'/'Z'."""
        return ord(x.upper()) - 65

    @staticmethod
    def num_to_letter(x):
        """Return the uppercase ASCII letter at index ``x`` (0 -> 'A')."""
        return ascii_uppercase[x]

    @staticmethod
    def removeprefix(s, prefix):
        """Return ``s`` without a leading ``prefix``; ``s`` unchanged if absent."""
        if s.startswith(prefix):
            return s[len(prefix):]
        return s

    @staticmethod
    def removesuffix(s, suffix):
        """Return ``s`` without a trailing ``suffix``; ``s`` unchanged if absent.

        An empty ``suffix`` leaves ``s`` untouched.
        """
        # Every string "ends with" the empty string, and s[:-0] is s[:0] == '',
        # so the empty suffix must be excluded before slicing.
        if suffix and s.endswith(suffix):
            return s[:-len(suffix)]
        return s
class Math:
    """Two-argument comparison helpers."""

    @staticmethod
    def max(a, b):
        """Return ``a`` when ``a > b``; otherwise return ``b`` (ties yield ``b``)."""
        if a > b:
            return a
        return b

    @staticmethod
    def min(a, b):
        """Return ``a`` when ``a < b``; otherwise return ``b`` (ties yield ``b``)."""
        if a < b:
            return a
        return b
class IO:
    """Line-oriented readers over standard input."""

    @staticmethod
    def input():
        """Read one line from stdin and strip trailing CR/LF characters."""
        line = stdin.readline()
        return line.rstrip("\r\n")

    @staticmethod
    def read():
        """Read one line and return a ``map`` of its whitespace-separated ints."""
        tokens = IO.input().split()
        return map(int, tokens)

    @staticmethod
    def read_list():
        """Read one line and return its whitespace-separated ints as a list."""
        return list(IO.read())
class Std:
    """Small search/iteration utilities plus a sparse table for range queries."""

    @staticmethod
    def find(container: Union[List[Any], Tuple[Any, ...], str], value: Any) -> int:
        """Return the index of the first ``value`` in ``container``, or -1.

        Args:
            container: A string, or any object exposing ``index(value)``
                (list, tuple, range, deque, ...).
            value: The element (or substring, for strings) to look for.

        Returns:
            The position of the first match, or -1 when there is none.

        Raises:
            TypeError: If ``container`` is not a string and has no ``index``
                method.
        """
        if isinstance(container, str):
            return container.find(value)

        try:
            lookup = container.index
        except AttributeError:
            raise TypeError("container must be a string or support index()") from None

        # index() signals a miss with ValueError; translate that to -1.
        try:
            return lookup(value)
        except ValueError:
            return -1

    @staticmethod
    def pairwise(iterable):
        """Return an iterator of successive overlapping pairs from ``iterable``."""
        lead, follow = tee(iterable)
        next(follow, None)  # shift the second copy one step ahead
        return zip(lead, follow)

    @staticmethod
    def bisect_left(a, x, key=lambda y: y):
        """Return the first index ``i`` with ``key(a[i]) >= x`` (``len(a)`` if none).

        ``a`` must be indexable and support ``len``; ``key(a[i])`` is expected
        to be non-decreasing in ``i``. Note that ``key`` is applied to the
        elements only, not to ``x``.
        """
        lo, hi = 0, len(a)
        while lo < hi:
            mid = (lo + hi) // 2
            if key(a[mid]) < x:
                lo = mid + 1
            else:
                hi = mid
        return lo

    @staticmethod
    def bisect_right(a, x, key=lambda y: y):
        """Return the first index ``i`` with ``key(a[i]) > x`` (``len(a)`` if none).

        Same requirements on ``a`` and ``key`` as ``bisect_left``.
        """
        lo, hi = 0, len(a)
        while lo < hi:
            mid = (lo + hi) // 2
            if key(a[mid]) <= x:
                lo = mid + 1
            else:
                hi = mid
        return lo

    class SparseTable:
        """Precomputed table answering ``func`` over any contiguous index range.

        ``query`` combines two windows that may overlap, so ``func`` should give
        the same result when an element is counted twice (e.g. ``|``, ``&``,
        ``min``, ``max``).
        """

        def __init__(self, data: list, func=lambda x, y: x | y):
            """Build the table from ``data``; ``func`` defaults to bitwise OR."""
            self.func = func
            # Level 0 is a copy of the data; level p covers windows of 2**p items.
            self.st = [list(data)]
            i, n = 1, len(self.st[0])
            while 2 * i <= n:
                pre = self.st[-1]
                self.st.append([func(pre[j], pre[j + i]) for j in range(n - 2 * i + 1)])
                i <<= 1

        def query(self, begin: int, end: int):
            """Return ``func`` folded over the inclusive range ``[begin, end]``.

            Callers are expected to pass ``0 <= begin <= end < len(data)``; the
            range is not validated here.
            """
            # Largest power of two that fits in the range length.
            lg = (end - begin + 1).bit_length() - 1
            return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])
class Solution:
    """Counts subarrays by bitwise AND using a sparse table and binary search.

    NOTE(review): a second ``class Solution`` appears later in this file and
    rebinds the name, so this version is shadowed once the module is loaded.
    """

    def countSubarrays(self, nums: List[int], k: int) -> int:
        """Return the number of contiguous subarrays of ``nums`` whose AND is ``k``."""
        table = Std.SparseTable(nums, func=lambda x, y: x & y)
        size = len(nums)
        total = 0
        for start in range(size):
            ends = range(start, size)

            # AND over [start, stop] never grows as stop moves right, so its
            # negation is non-decreasing -- the order both searches rely on.
            def negated_and(stop):
                return -table.query(start, stop)

            first_hit = Std.bisect_left(ends, -k, key=negated_and)
            past_hit = Std.bisect_right(ends, -k, key=negated_and)
            # Positions in [first_hit, past_hit) are exactly the ends with AND == k.
            total += past_hit - first_hit
        return total
class Solution:
    """Counts subarrays by bitwise AND with a rolling counter of AND values."""

    def countSubarrays(self, nums: List[int], k: int) -> int:
        """Return the number of contiguous subarrays of ``nums`` whose AND is ``k``.

        Args:
            nums: The input integers; may be empty.
            k: The target AND value.

        Returns:
            The count of (start, end) pairs whose elements AND to ``k``;
            0 for an empty input.
        """
        # ending_here maps an AND value to the number of subarrays that end at
        # the previous element and produce that value.
        ending_here = Counter()
        total = 0
        # Iterate the values directly: no 1-indexed copy is needed, and an
        # empty list simply falls through to 0.
        for value in nums:
            extended = Counter()
            extended[value] += 1  # the single-element subarray
            for acc, freq in ending_here.items():
                extended[value & acc] += freq
            total += extended[k]  # Counter yields 0 for a missing key
            ending_here = extended
        return total
|