2741. 特别的排列

摘要
Title: 2741. 特别的排列
Tag: 状态压缩dp
Memory Limit: 64 MB
Time Limit: 1000 ms

Powered by:NEFU AB-IN

Link

2741. 特别的排列

  • 题意

给你一个下标从 0 开始的整数数组 nums ,它包含 n 个 互不相同 的正整数。如果 nums 的一个排列满足以下条件,我们称它是一个特别的排列:

对于 0 <= i < n - 1 的下标 i ,要么 nums[i] % nums[i+1] == 0 ,要么 nums[i+1] % nums[i] == 0 。
请你返回特别排列的总数目,由于答案可能很大,请将它对 10^9 + 7 取余 后返回。

2 <= nums.length <= 14
1 <= nums[i] <= 10^9

  • 思路

  1. 状态压缩+dfs+记忆化搜索

    看到nums的长度并不长,考虑将状态压缩

    1. 类似全排列的思路(拿一个空数组,从左往右开始填数),一个包含 n 个不同整数的数组有 n! 种排列,如果直接dfs并判断是否是特别的排列,可能会超时。遂考虑将状态压缩为01串,0表示这个数并未选过,1表示这个数已经选过
    2. 举例子
      1. 例如数组 [2, 3, 6],如果状态为101,说明2和6被选过
      2. 考虑对dfs进行优化
        1. 在从左往右填数的过程中,维护在原数组的坐标,只需要考虑下一个数是否和前一个数构成因数关系即可
        2. 根据第一条的结论,我们就可以从一层一层的状态中过滤很多,保证下一层继承的上一层是正确的
        3. 考虑记忆化搜索,Python可以用@cache优化
      3. 状态转移:
        1. 维护两个值
          1. 一个是mask的01串,代表哪个数被选了,初始为0
          2. 一个是prev_index,代表前一个选的坐标是什么,初始为-1,代表是dfs时的第一个数,必选
        2. 类似全排列,查找 mask 中哪个没别选,如果这个数满足要求,那么下一个dp状态就是 dp(mask | (1 << i), i),即让这个mask的这一位置1,并且我们维护的前一个数的下标更改为i
      4. 最后当mask全为1时,则代表全选完了,而且是正确结果
  2. 状压dp

    • dp[i][j] 的含义为,当最后选择的下标j的数字后,状态为i
    • 我们只需要考虑选了i的情况,和选i之前的情况,可以设为j
    • 所以枚举状态的同时,枚举最后一次和倒数第二次的i和j,得到 f[state][i] = (f[state][i] + f[state ^ (1 << i)][j]) % MOD
      • state ^ (1 << i) 就是将i处的0和1交换,相当于1变为0,转变为到i没选,最后是j被选的状态
    • 最后状态全为1,且最后一个数为i,为结果,求和即可
  • 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128

'''
Author: NEFU AB-IN
Date: 2024-06-26 15:20:32
FilePath: \LeetCode\2741\2741.py
LastEditTime: 2024-06-26 20:34:10
'''
# import
from functools import cache
from sys import setrecursionlimit, stdin, stdout, exit
from collections import Counter, deque, defaultdict
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from bisect import bisect_left, bisect_right
from datetime import datetime, timedelta
from string import ascii_lowercase, ascii_uppercase
from math import log, gcd, sqrt, fabs, ceil, floor
from types import GeneratorType
from typing import TypeVar, List, Dict, Any, Callable


# Data structure
class SA:

def __init__(self, x, y):
self.x = x
self.y = y

def __lt__(self, other):
return self.x < other.x


# Constants
N = int(2e5 + 10) # If using AR, modify accordingly
M = int(20) # If using AR, modify accordingly
INF = int(2e9)
E = int(100)

# Set recursion limit
setrecursionlimit(INF)

# Read
input = lambda: stdin.readline().rstrip("\r\n") # Remove when Mutiple data
read = lambda: map(int, input().split())
read_list = lambda: list(map(int, input().split()))


# Func
class std:

# Recursion
@staticmethod
def bootstrap(f, stack=None):
if stack is None:
stack = []

def wrappedfunc(*args, **kwargs):
if stack:
return f(*args, **kwargs)
else:
to = f(*args, **kwargs)
while True:
if isinstance(to, GeneratorType):
stack.append(to)
to = next(to)
else:
stack.pop()
if not stack:
break
to = stack[-1].send(to)
return to

return wrappedfunc

letter_to_num = staticmethod(lambda x: ord(x.upper()) - 65) # A -> 0
num_to_letter = staticmethod(lambda x: ascii_uppercase[x]) # 0 -> A
array = staticmethod(lambda x=0, size=N: [x] * size)
array2d = staticmethod(
lambda x=0, rows=N, cols=M: [std.array(x, cols) for _ in range(rows)])
max = staticmethod(lambda a, b: a if a > b else b)
min = staticmethod(lambda a, b: a if a < b else b)
filter = staticmethod(lambda func, iterable: list(filter(func, iterable)))


# —————————————————————Division line ——————————————————————


class Solution:

def specialPerm(self, nums: List[int]) -> int:
n = len(nums)
all_mask = (1 << n) - 1
MOD = int(1e9 + 7)

@cache
def dp(mask, prev_index):
if mask == all_mask:
return 1

total_perms = 0
for i in range(n):
if mask & (1 << i) == 0:
if prev_index == -1 or nums[prev_index] % nums[i] == 0 or nums[i] % nums[prev_index] == 0:
total_perms = (total_perms + dp(mask | (1 << i), i)) % MOD

return total_perms

return dp(0, -1)

def specialPerm(self, nums: List[int]) -> int:
MOD = int(1e9 + 7)
n = len(nums)
f = std.array2d(0, 1 << n, n)

for i in range(n):
f[1 << i][i] = 1

for state in range(1, 1 << n):
for i, x in enumerate(nums):
if not state >> i & 1:
continue
for j, y in enumerate(nums):
if i == j or not state >> j & 1:
continue
if x % y != 0 and y % x != 0:
continue
f[state][i] = (f[state][i] + f[state ^ (1 << i)][j]) % MOD

return sum(f[(1 << n) - 1][i] for i in range(n)) % MOD
使用搜索:谷歌必应百度