2977. 转换字符串的最小成本 II

摘要
Title: 2977. 转换字符串的最小成本 II
Categories: floyd、trie、记忆化搜索

Powered by:NEFU AB-IN

Link

2977. 转换字符串的最小成本 II

题意

你有两个字符串 source 和 target,它们长度相同并且都由小写字母组成。

还有两个字符串数组 original 和 changed,以及一个整数数组 cost。cost[i] 表示将 original[i] 替换成 changed[i] 的成本。

你需要通过一系列操作将 source 转换成 target,每次操作可以选择 source 中的一个子串 original[j] 并以 cost[j] 的成本将其替换为 changed[j]。但有以下两个条件:

任意两次操作选择的子串在 source 中的下标区间,要么互不相交(不能部分重叠);
要么完全相同(即同一位置可以被多次操作,但每次选择的下标区间必须一致)。
你需要找到将 source 转换为 target 的最小成本。如果无法完成转换,返回 -1。

思路

https://leetcode.cn/problems/minimum-cost-to-convert-string-ii/solutions/2577877/zi-dian-shu-floyddp-by-endlesscheng-oi2r/

  1. Floyd(带剪枝:dist[i][k] 为 INF 时直接跳过)计算任意两个字符串之间的最小转换成本
  2. trie 用来转换字符串为整数下标,方便最短路;用来search当前的字符串的所有前缀
    1. 注意:这里的 search 方法被我修改过,会返回沿途每个前缀节点的 [长度, cost, sid],其中也包含并非完整单词的节点(sid 为 -1),需要调用方自行判断。这样做是为了把所有长度都求出来,使 source 和 target 的结果可以按长度用 zip 一一对应
  3. 记忆化搜索,感觉比 dp 好写一些:从下标 0 开始往后判断,当前位置能替换就对替换后的下一个位置继续 dfs;source 与 target 在该位置字符相同时,相当于走了一条边权为 0 的最短路,同样符合条件,也继续 dfs

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, comb, fabs, floor, gcd, log, perm, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Dict, List, Tuple, TypeVar, Union

# Shared module-level constants used by the helper classes below.
TYPE = TypeVar('TYPE')  # generic placeholder type variable
N = 200_010  # default 1-D array length (2e5 + 10)
M = 20  # default column count for 2-D arrays
INF = 10 ** 12  # sentinel meaning "unreachable" / "no cost recorded"
OFFSET = 100
MOD = 10 ** 9 + 7

# Raise the interpreter recursion limit so deep recursive searches are not cut short.
setrecursionlimit(2 * 10 ** 9)


class Arr:
    """Factory helpers that build pre-filled lists."""

    @staticmethod
    def array(x=0, size=N):
        """Return a list of ``size`` entries initialised from ``x``.

        If ``x`` is callable it is invoked once per slot, so every entry gets
        its own fresh object; otherwise the same value ``x`` fills every slot.
        """
        return [x() if callable(x) else x for _ in range(size)]

    @staticmethod
    def array2d(x=0, rows=N, cols=M):
        """Return ``rows`` independent row lists, each built by ``Arr.array(x, cols)``."""
        return [Arr.array(x, cols) for _ in range(rows)]

    @staticmethod
    def graph(size=N):
        """Return ``size`` distinct empty lists, usable as an adjacency list."""
        return [[] for _ in range(size)]


class Math:
    """Two-argument ``max`` / ``min`` helpers based on a single comparison."""

    @staticmethod
    def max(a, b):
        """Return ``a`` if ``a > b``, otherwise ``b`` (``b`` is returned on ties)."""
        return a if a > b else b

    @staticmethod
    def min(a, b):
        """Return ``a`` if ``a < b``, otherwise ``b`` (``b`` is returned on ties)."""
        return a if a < b else b


class IO:
    """Line-oriented readers on top of ``sys.stdin``."""

    @staticmethod
    def input():
        """Read one line from stdin with trailing CR/LF characters removed."""
        return stdin.readline().rstrip("\r\n")

    @staticmethod
    def read():
        """Read one line and return a lazy ``map`` of its whitespace-separated ints."""
        return map(int, IO.input().split())

    @staticmethod
    def read_list():
        """Read one line and return its whitespace-separated ints as a list."""
        return list(IO.read())


class Std:
    """Namespace grouping reusable algorithm helpers: shortest paths and a trie."""

    class GraphShortestPath:
        """Directed weighted graph over nodes ``0 .. n-1`` with cached shortest-path queries."""

        def __init__(self, n: int):
            """Create a graph with ``n`` nodes and no edges."""
            self.n = n
            self.g = Arr.graph(n)  # adjacency list: g[u] holds (v, w) pairs
            self.spfa_cache = {}  # source node -> distance list from spfa()
            self.dijkstra_cache = {}  # source node -> distance list from dijkstra()
            self.floyd_cache = None  # all-pairs matrix from floyd(), or None if not computed

        def add_edge(self, u: int, v: int, w: int):
            """Add a directed edge ``u -> v`` with weight ``w``.

            Cached query results are discarded, because a new edge can shorten
            paths that were computed before it existed.
            """
            self.g[u].append((v, w))
            self.spfa_cache.clear()
            self.dijkstra_cache.clear()
            self.floyd_cache = None

        def spfa(self, s: int) -> List[int]:
            """Return the distance list from ``s`` computed with SPFA (queue-based relaxation).

            Unreached nodes keep the value ``INF``. The result is cached per source.
            """
            if s in self.spfa_cache:
                return self.spfa_cache[s]

            dist = Arr.array(INF, self.n)
            st = Arr.array(0, self.n)  # 1 while the node is waiting in the queue
            q = deque()

            dist[s] = 0
            q.appendleft(s)
            st[s] = 1

            while q:
                u = q.pop()  # appendleft + pop gives FIFO order
                st[u] = 0
                for v, w in self.g[u]:
                    if dist[v] > dist[u] + w:
                        dist[v] = dist[u] + w
                        if st[v] == 0:
                            q.appendleft(v)
                            st[v] = 1

            self.spfa_cache[s] = dist
            return dist

        def dijkstra(self, s: int) -> List[int]:
            """Return the distance list from ``s`` computed with Dijkstra's algorithm.

            Unreached nodes keep the value ``INF``. The result is cached per source.
            """
            if s in self.dijkstra_cache:
                return self.dijkstra_cache[s]

            dist = Arr.array(INF, self.n)
            st = Arr.array(0, self.n)  # 1 once the node has been finalised
            q = []

            dist[s] = 0
            heappush(q, (0, s))

            while q:
                _, u = heappop(q)
                if st[u]:
                    # Stale heap entry: u was already finalised with a shorter distance.
                    continue
                st[u] = 1
                for v, w in self.g[u]:
                    if dist[v] > dist[u] + w:
                        dist[v] = dist[u] + w
                        heappush(q, (dist[v], v))

            self.dijkstra_cache[s] = dist
            return dist

        def floyd(self) -> List[List[int]]:
            """Return the all-pairs distance matrix computed with Floyd-Warshall.

            ``dist[i][j]`` is ``INF`` (or larger than ``INF // 2``-style sentinels
            are never produced here) when ``j`` is unreachable from ``i``.
            The matrix is cached until the next ``add_edge`` call.
            """
            if self.floyd_cache is not None:
                return self.floyd_cache

            dist = Arr.array2d(INF, self.n, self.n)
            # Seed with the direct edges, keeping the cheapest of parallel edges.
            for u in range(self.n):
                for v, w in self.g[u]:
                    dist[u][v] = Math.min(dist[u][v], w)

            # A node is at distance zero from itself.
            for i in range(self.n):
                dist[i][i] = 0

            # Floyd-Warshall relaxation through every intermediate node k.
            for k in range(self.n):
                for i in range(self.n):
                    if dist[i][k] == INF:  # no path i -> k, so k cannot improve row i
                        continue
                    for j in range(self.n):
                        if dist[i][j] > dist[i][k] + dist[k][j]:
                            dist[i][j] = dist[i][k] + dist[k][j]

            self.floyd_cache = dist
            return dist

        def shortest_path(self, x: int, y: int, method: str = 'dijkstra') -> int:
            """Return the shortest distance from ``x`` to ``y``.

            Args:
                x: Source node.
                y: Destination node.
                method: One of ``'spfa'``, ``'dijkstra'`` or ``'floyd'``.

            Returns:
                The distance, or ``INF`` when ``y`` is unreachable.

            Raises:
                ValueError: If ``method`` is not one of the supported names.
            """
            if method == 'spfa':
                dist = self.spfa(x)
            elif method == 'dijkstra':
                dist = self.dijkstra(x)
            elif method == 'floyd':
                dist_matrix = self.floyd()
                # Anything at or above INF // 2 is reported as plain INF.
                return dist_matrix[x][y] if dist_matrix[x][y] < INF // 2 else INF
            else:
                raise ValueError("Unsupported method. Use 'spfa', 'dijkstra', or 'floyd'.")

            return dist[y]

    class TrieNode:
        """Trie node that maps each inserted string to an integer id.

        The ids make strings usable as graph nodes, and the trie structure
        allows walking all stored prefixes of a query string.
        """
        sid_cnt = 0  # next string id to hand out; ids start at 0

        def __init__(self):
            """Create an empty node with no children, cost ``INF`` and no id."""
            self.children = {}  # character -> child TrieNode
            self.cost = INF
            self.is_end_of_word = False  # True when an inserted word ends here
            self.sid = -1  # string id of the word ending here, -1 if none

        def add(self, word, cost):
            """Insert ``word`` with ``cost`` and return its string id.

            Re-inserting a word keeps its id and stores the smaller cost.
            The id counter is kept on the node ``add`` is called on, so all
            insertions should go through the same (root) node.
            """
            node = self
            for c in word:
                if c not in node.children:
                    node.children[c] = Std.TrieNode()
                node = node.children[c]
            node.cost = Math.min(node.cost, cost)
            node.is_end_of_word = True
            if node.sid < 0:
                node.sid = self.sid_cnt
                # First increment shadows the class-level default with a per-root counter.
                self.sid_cnt += 1
            return node.sid

        def search(self, word: str):
            """Walk ``word`` through the trie and report every node reached.

            Returns a list of ``[length, cost, sid]`` entries, one per matched
            prefix length starting at 1. Entries for nodes where no inserted
            word ends are included too (cost ``INF``, sid ``-1``); callers must
            filter them out.
            """
            node = self
            ans = []
            for i, c in enumerate(word):
                if c not in node.children:
                    break
                node = node.children[c]
                ans.append([i + 1, node.cost, node.sid])  # i + 1 is the prefix length
            return ans

        def search_exact(self, word: str, return_type: str = 'cost'):
            """Look up exactly ``word`` and return its cost or its string id.

            Args:
                word (str): The word to search for.
                return_type (str): ``'cost'`` to return the cost; any other
                    value returns the string id.

            Returns:
                int: The cost or string id of the word; ``INF`` (for cost) or
                ``-1`` (for id) when the word was never inserted.
            """
            node = self
            for c in word:
                if c not in node.children:
                    return INF if return_type == 'cost' else -1
                node = node.children[c]
            if node.is_end_of_word:
                return node.cost if return_type == 'cost' else node.sid
            else:
                return INF if return_type == 'cost' else -1
# ————————————————————— Division line ——————————————————————


class Solution:
    """LeetCode 2977: minimum cost to convert ``source`` into ``target``."""

    def minimumCost(self, source: str, target: str, original: List[str], changed: List[str], cost: List[int]) -> int:
        """Return the minimum total cost of turning ``source`` into ``target``.

        Args:
            source: The starting string.
            target: The desired string; indexed in lockstep with ``source``.
            original: Substrings that may be replaced.
            changed: ``changed[i]`` is what ``original[i]`` can be replaced with.
            cost: ``cost[i]`` is the price of replacing ``original[i]`` by ``changed[i]``.

        Returns:
            The minimum cost, or ``-1`` when the conversion is impossible.
        """
        # Give every distinct string an integer id via the trie, and record the edges.
        trie = Std.TrieNode()
        edges = []
        for u, v, w in zip(original, changed, cost):
            x, y = trie.add(u, 0), trie.add(v, 0)
            edges.append((x, y, w))

        short_path = Std.GraphShortestPath(trie.sid_cnt)
        for u, v, w in edges:
            short_path.add_edge(u, v, w)

        n = len(source)
        # No stored word is longer than this, so suffixes can be truncated to it
        # before searching without changing which prefixes are found.
        max_len = max((len(word) for words in (original, changed) for word in words), default=0)

        @lru_cache(None)
        def dfs(l: int):
            """Minimum cost to convert ``source[l:]`` into ``target[l:]`` (``INF`` if impossible)."""
            if l >= n:
                return 0
            res = INF
            if source[l] == target[l]:
                # Equal characters need no operation: move on at zero cost.
                res = dfs(l + 1)

            # Both searches list prefix lengths 1, 2, ... so zip pairs equal lengths.
            src_hits = trie.search(source[l:l + max_len])
            tgt_hits = trie.search(target[l:l + max_len])
            for (len_, _, x), (_, _, y) in zip(src_hits, tgt_hits):
                if x != -1 and y != -1:  # both prefixes are stored words
                    res = Math.min(res, short_path.shortest_path(x, y, 'floyd') + dfs(l + len_))
            return res

        ans = dfs(0)

        return ans if ans != INF else -1


# print(Solution().minimumCost("abcdefgh", "acdeeghh", ["bcd", "fgh", "thh"], ["cde", "thh", "ghh"], [1, 3, 5]))
# print(Solution().minimumCost("abcd", "acbe", ["a", "b", "c", "c", "e", "d"], ["b", "c", "b", "e", "b", "e"], [2, 5, 5, 1, 2, 20]))
使用搜索:谷歌必应百度