698. 划分为k个相等的子集

摘要
Title: 698. 划分为k个相等的子集
Categories: dfs、记忆化搜索

Powered by:NEFU AB-IN

Link

698. 划分为k个相等的子集

题意

给定一个整数数组 nums 和一个正整数 k,找出是否有可能把这个数组分成 k 个非空子集,其总和都相等。

思路

  1. 首先肯定是判断数据是否可行,总和需要是k的倍数,且数组长度应该不小于k
  2. 记忆化搜索,设的参数一定把一个状态定死,也就是能把这个状态完整的表述出来,且别的参数排列无法描述这个状态
    1. 比如这里的 dfs(cur_sum: int, cur_num: int, used_cnt: int)
      1. cur_sum 代表当前集合中的数字和
      2. cur_num 代表当前多少个集合被使用了
      3. used_cnt 代表nums多少个数被使用了
  3. 然后进行dfs即可,记忆化搜索是为了除掉那些已经跑过的状态,注意进行剪枝,当当前的元素个数,比剩余坑位都还有小,说明不能dfs

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
'''
Author: NEFU AB-IN
Date: 2024-08-25 21:17:11
FilePath: \LeetCode\698\698.py
LastEditTime: 2024-08-25 22:06:24
'''
# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache, reduce
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, comb, fabs, floor, gcd, hypot, log, perm, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from tokenize import group
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

# Constants
TYPE = TypeVar('TYPE')
N = int(2e5 + 10)
M = int(20)
INF = int(1e12)
OFFSET = int(100)
MOD = int(1e9 + 7)

# Set recursion limit
setrecursionlimit(int(2e9))


class Arr:
array = staticmethod(lambda x=0, size=N: [x() if callable(x) else x for _ in range(size)])
array2d = staticmethod(lambda x=0, rows=N, cols=M: [Arr.array(x, cols) for _ in range(rows)])
graph = staticmethod(lambda size=N: [[] for _ in range(size)])


class Math:
max = staticmethod(lambda a, b: a if a > b else b)
min = staticmethod(lambda a, b: a if a < b else b)


class Std:
pass

# ————————————————————— Division line ——————————————————————


class Solution:
def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
if sum(nums) % k != 0 or len(nums) < k:
return False

group_sum = sum(nums) // k
nums.sort(reverse=True)
vis_ = Arr.array(0, len(nums))
flag = False

@cache
def dfs(cur_sum: int, cur_num: int, used_cnt: int):
nonlocal flag
if flag:
return

if cur_sum == group_sum:
cur_num += 1
if cur_num == k:
flag = True
return
if len(nums) - used_cnt >= k - cur_num:
dfs(0, cur_num, used_cnt)
return

for i, num in enumerate(nums):
if not vis_[i] and cur_sum + num <= group_sum:
vis_[i] = 1
used_cnt += 1
dfs(cur_sum + num, cur_num, used_cnt)
vis_[i] = 0
used_cnt -= 1
return

dfs(0, 0, 0)
return flag

使用搜索:谷歌必应百度