3213. 最小代价构造字符串

摘要
Title: 3213. 最小代价构造字符串
Categories: dp、trie、hash

Powered by:NEFU AB-IN

Link

3213. 最小代价构造字符串

题意

给你一个字符串 target、一个字符串数组 words 以及一个整数数组 costs,这两个数组长度相同。

设想一个空字符串 s。

你可以执行以下操作任意次数(包括零次):

选择一个在范围 [0, words.length - 1] 的索引 i。
将 words[i] 追加到 s。
该操作的成本是 costs[i]。
返回使 s 等于 target 的 最小 成本。如果不可能,返回 -1。

思路

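字典树/字符串哈希 + dp。下面以字典树为例说明;字符串哈希做法的 dp 完全相同,只是把"在 Trie 中查找前缀单词"换成:对每种出现过的单词长度 L,用前缀哈希在 O(1) 时间内求出 target[i:i+L] 的哈希值,再到预先建好的"单词哈希 → 最小成本"表中查找。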

  1. 使用 Trie 树存储单词和成本:
    我们将所有的单词和对应的成本插入到一个 Trie 树中(同一单词出现多次时只保留最小成本)。Trie 树是一种多叉树:从根出发沿着一个字符串逐字符向下走,途经的每个"单词结尾"节点都对应一个恰好是该字符串前缀的单词,因此可以快速找出所有是该字符串前缀的单词。
    这样我们就能在 Trie 树中快速查找到以 target 中某个位置开始的所有前缀单词及其成本。
  2. 动态规划(Dynamic Programming):
    使用一个动态规划数组 dp,其中 dp[i] 表示构造 target 的前 i 个字符的最小成本。
    初始化 dp[0] = 0,表示构造空字符串的成本为 0,其他位置初始化为无穷大,表示尚未计算到该位置。
  3. 遍历目标字符串:
    对于目标字符串 target 的每一个位置 i,如果 dp[i] 是无穷大,说明 target 的前 i 个字符无法被构造出来,也就无法从位置 i 继续拼接,直接跳过。
    否则,使用 Trie 树的 search 方法,从当前位置 i 开始查找所有可能的前缀及其成本。
    对于每一个找到的前缀,更新 dp 数组:dp[i + length] = min(dp[i + length], dp[i] + cost),表示从当前位置 i 开始构造到 i + length 的最小成本。

代码

# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, fabs, floor, gcd, log, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Dict, List, Tuple, TypeVar, Union

# Constants
TYPE = TypeVar('TYPE')  # generic element type used in the Std.find annotations
N = 200_010  # default 1-D size for Arr helpers (2e5 + 10); if using AR, modify accordingly
M = 20  # default column count for Arr.array2d; if using AR, modify accordingly
INF = 2 * 10 ** 9  # "unreachable" sentinel for DP tables
OFFSET = 100  # not referenced in the visible code

# Set recursion limit.
# NOTE(review): this reuses the INF sentinel as the limit. A limit this large
# effectively disables Python's recursion guard: very deep recursion can crash
# the interpreter instead of raising RecursionError.
setrecursionlimit(INF)

class Arr:
    """Factory helpers for pre-sized lists and 1-indexed copies of data."""

    # `size` copies of x. NOTE: the same object x fills every slot, so a
    # mutable x (e.g. a list) is shared by all of them.
    array = staticmethod(lambda x=0, size=N: [x] * size)
    # rows x cols grid; each row is built separately, so rows are independent.
    array2d = staticmethod(lambda x=0, rows=N, cols=M: [Arr.array(x, cols) for _ in range(rows)])
    # `size` independent empty lists (e.g. adjacency lists).
    graph = staticmethod(lambda size=N: [[] for _ in range(size)])

    @staticmethod
    def to_1_indexed(data: Union[List, str, List[List]]):
        """Prepend a zero sentinel so that ``data`` can be indexed from 1.

        - 2-D list: a zero row is added on top and a zero column on the left
          of every row; returns ``(new_data, rows, cols)``, where ``cols`` is
          taken from the first row.
        - 1-D list: returns ``([0] + data, len(data))``.
        - str: returns ``('0' + data, len(data))``.

        An empty list is treated as 1-D (it has no row to size a 2-D header
        from) and yields ``([0], 0)``.

        Raises:
            TypeError: if ``data`` is neither a list nor a string.
        """
        if isinstance(data, str):
            return '0' + data, len(data)
        if not isinstance(data, list):
            raise TypeError("Input must be a list, a 2D list, or a string")
        # all() is vacuously True for [], so require at least one row before
        # treating the input as 2-D; otherwise data[0] raises IndexError.
        if data and all(isinstance(row, list) for row in data):
            cols = len(data[0])
            new_data = [[0] * (cols + 1)] + [[0] + row for row in data]
            return new_data, len(data), cols
        return [0] + data, len(data)

class Str:
    """String helpers: letter <-> index conversion and prefix/suffix removal."""

    # Letters map case-insensitively: 'A'/'a' -> 0, 'B'/'b' -> 1, ...
    letter_to_num = staticmethod(lambda x: ord(x.upper()) - 65)  # A -> 0
    # 0 -> 'A', 1 -> 'B', ... (always uppercase)
    num_to_letter = staticmethod(lambda x: ascii_uppercase[x])  # 0 -> A
    # Equivalents of str.removeprefix / str.removesuffix (Python 3.9+).
    removeprefix = staticmethod(lambda s, prefix: s[len(prefix):] if s.startswith(prefix) else s)
    # The `suffix and` guard matters: for an empty suffix, s[:-0] is s[:0] == "",
    # which would wrongly erase the whole string.
    removesuffix = staticmethod(
        lambda s, suffix: s[:len(s) - len(suffix)] if suffix and s.endswith(suffix) else s
    )

class Math:
    """Two-argument max/min helpers; on a tie the second argument is returned."""

    @staticmethod
    def max(a, b):
        """Return a if a > b, otherwise b."""
        if a > b:
            return a
        return b

    @staticmethod
    def min(a, b):
        """Return a if a < b, otherwise b."""
        if a < b:
            return a
        return b

class IO:
    """Line-oriented helpers for reading from standard input."""

    @staticmethod
    def input():
        """Return the next stdin line with trailing CR/LF characters removed."""
        line = stdin.readline()
        return line.rstrip("\r\n")

    @staticmethod
    def read():
        """Return a lazy map of ints parsed from the next line's whitespace-separated tokens."""
        tokens = IO.input().split()
        return map(int, tokens)

    @staticmethod
    def read_list():
        """Return the ints on the next line as a list."""
        return list(IO.read())


class Std:
    """Namespace of small algorithm and data-structure helpers."""

    @staticmethod
    def find(container: Union[List[TYPE], str], value: TYPE):
        """Return the index of ``value`` in ``container``, or -1 if it is absent.

        Lists are searched for an equal element (``list.index``); strings
        for a substring (``str.find``). Any other container type falls
        through and returns None.
        """
        if isinstance(container, list):
            try:
                return container.index(value)
            except ValueError:
                return -1
        elif isinstance(container, str):
            return container.find(value)
        # NOTE(review): tuples, deques, etc. get None here, not -1.
        return None

    @staticmethod
    def pairwise(iterable):
        """Yield successive overlapping pairs (s0, s1), (s1, s2), ... of ``iterable``."""
        a, b = tee(iterable)
        next(b, None)  # advance the second copy by one element
        return zip(a, b)

    @staticmethod
    def bisect_left(a, x, key=lambda y: y):
        """Return the first index i such that ``key(a[i]) >= x`` (len(a) if none).

        ``a`` must be sorted by ``key``; ``x`` is compared against key values.
        """
        left, right = 0, len(a)
        while left < right:
            mid = (left + right) >> 1
            if key(a[mid]) < x:
                left = mid + 1
            else:
                right = mid
        return left

    @staticmethod
    def bisect_right(a, x, key=lambda y: y):
        """Return the first index i such that ``key(a[i]) > x`` (len(a) if none).

        ``a`` must be sorted by ``key``; ``x`` is compared against key values.
        """
        left, right = 0, len(a)
        while left < right:
            mid = (left + right) >> 1
            if key(a[mid]) <= x:
                left = mid + 1
            else:
                right = mid
        return left

    class SparseTable:
        """Static range queries: O(n log n) build, O(1) calls to ``func`` per query.

        ``st[k][j]`` holds ``func`` folded over ``data[j : j + 2**k]``.
        ``query`` combines two windows that may overlap, so ``func`` must be
        associative and idempotent (e.g. ``|``, ``&``, ``max``, ``min``,
        ``gcd``). A sum, for example, would double-count the overlap.
        """

        def __init__(self, data: list, func=lambda x, y: x | y):
            """Build every level of the table from ``data`` with ``func`` (default: bitwise OR)."""
            self.func = func
            self.st = [list(data)]
            i, n = 1, len(self.st[0])
            # Level k+1 merges pairs of level-k windows that start 2**k apart.
            while 2 * i <= n:
                pre = self.st[-1]
                self.st.append([func(pre[j], pre[j + i]) for j in range(n - 2 * i + 1)])
                i <<= 1

        def query(self, begin: int, end: int):
            """Return ``func`` folded over the inclusive index range ``[begin, end]``.

            Assumes ``0 <= begin <= end < len(data)``; this is not checked.
            """
            lg = (end - begin + 1).bit_length() - 1  # floor(log2(range length))
            return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])

    class TrieNode:
        def __init__(self):
            """Create an empty trie node.

            ``children`` maps one character to the child node. It is a dict,
            so any alphabet works, not just 26 letters. ``cost`` is the
            minimum cost of a word ending at this node, or INF if no word
            ends here.
            """
            self.children = {}
            self.cost = INF

        def add(self, word, cost):
            """Insert ``word``; if it is already present, keep the smaller cost."""
            node = self
            for c in word:
                if c not in node.children:
                    node.children[c] = Std.TrieNode()
                node = node.children[c]
            node.cost = min(node.cost, cost)

        def search(self, word, start=0):
            """Find every stored word that is a prefix of ``word[start:]``.

            Returns a list of ``[length, cost]`` pairs in increasing length
            order. The optional ``start`` (a non-negative offset) lets callers
            scan from a position inside ``word`` without slicing, which would
            copy the tail on every call.
            """
            node = self
            ans = []
            for i in range(start, len(word)):
                node = node.children.get(word[i])
                if node is None:
                    break  # no stored word continues along this path
                if node.cost != INF:
                    ans.append([i - start + 1, node.cost])  # length measured from `start`
            return ans

    class StringHash:
        """Polynomial prefix hashes of ``s`` modulo ``mod``, giving O(1) substring hashes.

        ``base`` is drawn at random for each instance, so hash values are only
        comparable when produced by the same instance. Use ``compute_hash``
        to hash other strings consistently. Only one modulus is used, so two
        different strings can occasionally share a hash.
        """

        def __init__(self, s: str, mod: int = 1_070_777_777):
            """Store ``s`` and ``mod``, pick a random base, and precompute the prefix tables."""
            self.s = s
            self.mod = mod
            self.base = random.randint(8 * 10 ** 8, 9 * 10 ** 8)
            self.n = len(s)
            self.pow_base = [1] + Arr.array(0, self.n)  # pow_base[i] = base**i % mod
            self.pre_hash = Arr.array(0, self.n + 1)  # pre_hash[i] = hash(s[:i])
            self._compute_hash()

        def _compute_hash(self):
            """Fill ``pow_base`` and ``pre_hash`` in a single left-to-right pass."""
            for i, b in enumerate(self.s):
                self.pow_base[i + 1] = self.pow_base[i] * self.base % self.mod
                self.pre_hash[i + 1] = (self.pre_hash[i] * self.base + ord(b)) % self.mod

        def get_sub_hash(self, l: int, r: int) -> int:
            """Return the hash of ``s[l:r+1]`` (``r`` is inclusive)."""
            return (self.pre_hash[r + 1] - self.pre_hash[l] * self.pow_base[r - l + 1] % self.mod + self.mod) % self.mod

        def get_full_hash(self) -> int:
            """Return the hash of the whole string ``s``."""
            return self.pre_hash[self.n]

        def compute_hash(self, word: str) -> int:
            """Hash ``word`` with this instance's base and mod, comparable to ``get_sub_hash``."""
            h = 0
            for b in word:
                h = (h * self.base + ord(b)) % self.mod
            return h

# ————————————————————— Division line ——————————————————————

class Solution:
    """Trie + DP solution.

    NOTE: a second ``class Solution`` later in this file rebinds the name,
    so when the whole module is loaded this version is shadowed.
    """

    def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
        """Return the minimum total cost to build ``target`` by appending
        ``words`` (each any number of times, ``words[i]`` costing ``costs[i]``),
        or -1 if ``target`` cannot be built.

        dp[i] is the cheapest cost of building ``target[:i]``. From every
        reachable i, each word that is a prefix of ``target[i:]`` extends it.
        """
        # Build the trie; duplicate words keep their cheapest cost.
        trie = Std.TrieNode()
        for word, cost in zip(words, costs):
            trie.add(word, cost)

        n = len(target)
        dp = Arr.array(INF, n + 1)
        dp[0] = 0

        # No match can be longer than the longest word, so only that many
        # characters need to be handed to the trie. Slicing the full tail
        # target[i:] at every i copied O(n) characters each time (O(n^2) total).
        max_len = max(map(len, words), default=0)

        for i in range(n):
            if dp[i] == INF:
                continue  # target[:i] cannot be built, so nothing extends from i
            for length, cost in trie.search(target[i:i + max_len]):
                dp[i + length] = min(dp[i + length], dp[i] + cost)

        return dp[n] if dp[n] != INF else -1

class Solution:
    """Rolling-hash + DP solution.

    NOTE: this definition rebinds ``Solution``, replacing the trie-based
    class defined earlier in this file.
    """

    def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
        """Return the minimum total cost to build ``target`` by appending
        ``words`` (each any number of times, ``words[i]`` costing ``costs[i]``),
        or -1 if ``target`` cannot be built.

        dp[i] is the cheapest cost of building ``target[:i]``. For each
        reachable i and each distinct word length L, the hash of
        ``target[i:i+L]`` is looked up among the word hashes.
        """
        n = len(target)

        target_hash = Std.StringHash(target)

        # (word length, word hash) -> cheapest cost among words with that key.
        # Including the length means a substring is only ever compared with
        # words of the same length, so strings of different lengths whose
        # hashes happen to coincide can no longer be confused. Same-length
        # collisions remain possible with a single-modulus hash.
        min_cost = {}
        for w, c in zip(words, costs):
            key = (len(w), target_hash.compute_hash(w))
            min_cost[key] = min(min_cost.get(key, INF), c)

        # Distinct word lengths in ascending order, so the scan below can stop early.
        sorted_lens = sorted({len(w) for w in words})

        dp = Arr.array(INF, n + 1)
        dp[0] = 0

        for i in range(n):
            if dp[i] == INF:
                continue  # target[:i] cannot be built, so nothing extends from i
            for sz in sorted_lens:
                if i + sz > n:
                    break  # every remaining (longer) length also overruns target
                # Hash of target[i:i+sz]. Using .get() (instead of indexing a
                # defaultdict) keeps failed probes from inserting INF entries.
                cost = min_cost.get((sz, target_hash.get_sub_hash(i, i + sz - 1)))
                if cost is not None:
                    dp[i + sz] = min(dp[i + sz], dp[i] + cost)

        return -1 if dp[n] == INF else dp[n]
使用搜索:谷歌必应百度