A1087 All Roads Lead to Rome (30)

摘要
Title: A1087 All Roads Lead to Rome (30)
Categories: 最短路

Powered by:NEFU AB-IN

Link

A1087 All Roads Lead to Rome (30)

题意

你是一个导游,要带顾客从当前城市出发到达目的地“ROM”,途中可以经过其他城市。每个城市都有一个能让旅客快乐的值(起点城市没有),从一个城市到另一个城市要花费一定的路费。你要求出从起点到终点花费最少的路线;如果有多条,则选择沿途快乐值之和最大的路线;如果仍有多条,则选择平均快乐值(起点不计)最大的路线——由于此时快乐值之和相同,这等价于经过城市个数最少的路线。
数据保证按上述规则选出的路线有且仅有一条。一共有 n 个城市,城市之间有 k 条双向道路。每个城市都有一个名字,用三个大写字母组成的字符串表示。

输出最短路的条数、最短路长度、所选路线的最大快乐值之和、最大平均快乐值(起点不计,向下取整),并输出这条路线。

思路

https://blog.csdn.net/qq_40531479/article/details/104188442
正常最短路即可,然后维护这些值

  1. 当最短路距离相等时,说明存在另一个最短路,这时候 cnt[v] += cnt[u]
  2. 距离相等时继续比较:若 self.sum_[u] > self.sum_[self.pre_[v]],说明经过 u 到达 v 的路线快乐值之和比原路线更大,于是用 u 更新 v 的前驱、快乐值之和与经过城市数;若快乐值之和也相等,则在 self.num_[u] + 1 < self.num_[v] 时用 u 更新,以选出经过城市更少的路线。

代码

'''
Author: NEFU AB-IN
Date: 2024-08-15 23:57:49
FilePath: \GPLT\A1087\A1087.py
LastEditTime: 2024-08-17 21:06:27
'''
# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache, reduce
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, comb, fabs, floor, gcd, hypot, log, perm, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

# Constants
TYPE = TypeVar('TYPE')  # generic type-variable placeholder
N = 200_010  # default length for Arr.array / Arr.graph (2e5 + 10)
M = 20  # default column count for Arr.array2d
INF = 10 ** 12  # initial "unreached" distance used by Std.Dijkstra
OFFSET = 100
MOD = 1_000_000_007  # 1e9 + 7

# Set recursion limit.
# NOTE(review): a limit of 2e9 effectively disables Python's recursion guard;
# nothing visible in this file recurses, and very deep recursion may overflow
# the C stack and crash instead of raising RecursionError.
setrecursionlimit(2_000_000_000)


class Arr:
    """Factories for pre-sized Python lists used as fixed-size arrays."""

    @staticmethod
    def array(x=0, size=N):
        """Return a list of ``size`` items.

        If ``x`` is callable, each item is a fresh ``x()`` result; otherwise
        every slot holds the same ``x`` object.
        """
        if callable(x):
            return [x() for _ in range(size)]
        return [x] * size

    @staticmethod
    def array2d(x=0, rows=N, cols=M):
        """Return ``rows`` independent rows, each built by ``Arr.array(x, cols)``."""
        return [Arr.array(x, cols) for _ in range(rows)]

    @staticmethod
    def graph(size=N):
        """Return ``size`` distinct empty lists, one adjacency list per node."""
        return [[] for _ in range(size)]


class Math:
    """Two-argument max/min helpers; on a tie both return the second argument."""

    @staticmethod
    def max(a, b):
        """Return ``a`` when ``a > b``, otherwise ``b``."""
        if a > b:
            return a
        return b

    @staticmethod
    def min(a, b):
        """Return ``a`` when ``a < b``, otherwise ``b``."""
        if a < b:
            return a
        return b


class IO:
    """Line-oriented helpers reading from ``sys.stdin``."""

    @staticmethod
    def input():
        """Read the next stdin line with surrounding whitespace stripped."""
        return stdin.readline().strip()

    @staticmethod
    def read():
        """Return a lazy ``map`` of ints parsed from the next line's tokens."""
        return map(int, IO.input().split())

    @staticmethod
    def read_list():
        """Return the next line's tokens parsed as a list of ints."""
        return list(IO.read())

    @staticmethod
    def read_mixed(*types):
        """Convert the next line's tokens pairwise with the given callables.

        Pairing uses ``zip``, so surplus tokens or surplus types are dropped.
        """
        tokens = IO.input().split()
        return [convert(token) for convert, token in zip(types, tokens)]


class Std:
    """Namespace for reusable algorithm building blocks."""

    class Dijkstra:
        """Heap-based Dijkstra that also records shortest-path tie-break data.

        For every node reached from the source ``s`` it stores:

        * ``dist_[v]`` -- shortest distance from ``s`` (``INF`` if unreached);
        * ``cnt_[v]``  -- number of distinct shortest paths from ``s``;
        * ``sum_[v]``  -- largest sum of ``val_`` over the nodes of a shortest
          path (``s`` and ``v`` included);
        * ``num_[v]``  -- number of nodes on that path, minimised among the
          shortest paths that reach the maximal ``sum_``;
        * ``pre_[v]``  -- predecessor of ``v`` on that chosen path.

        NOTE(review): the bookkeeping assumes strictly positive edge weights; a
        zero-weight edge can reach an already-finalised node and leave the
        ``cnt_``/``sum_`` of its successors stale.
        """

        def __init__(self, n: int, val_: List):
            """Create an empty graph with nodes ``0..n-1``.

            Args:
                n: number of nodes.
                val_: per-node value indexed by node id; stored by reference,
                    so callers may fill it in after construction.
            """
            self.n = n
            self.val_ = val_
            self.g_ = Arr.graph(n)  # adjacency list: g_[u] = [(v, w), ...]
            self.dist_ = Arr.array(INF, n)  # shortest distance from the source
            self.sum_ = Arr.array(0, n)  # best value sum on a shortest path
            self.cnt_ = Arr.array(0, n)  # number of distinct shortest paths
            self.pre_ = Arr.array(0, n)  # predecessor on the chosen path
            self.num_ = Arr.array(0, n)  # nodes on the chosen path (source included)

        def add_edge(self, u: int, v: int, w: int) -> None:
            """Add a directed edge ``u -> v`` of weight ``w``; call twice for undirected."""
            self.g_[u].append((v, w))

        def dijkstra(self, s: int) -> None:
            """Run Dijkstra from ``s`` and fill dist_/cnt_/sum_/num_/pre_.

            ``pre_[s]`` is set to ``s`` so that following ``pre_`` links from
            any reached node stops at the source regardless of its id.
            """
            done = [False] * self.n
            q = []

            self.dist_[s] = 0
            self.sum_[s] = self.val_[s]
            self.cnt_[s] = self.num_[s] = 1
            self.pre_[s] = s  # sentinel: the predecessor walk ends here
            heappush(q, (0, s))

            while q:
                _, u = heappop(q)
                if done[u]:
                    continue  # stale heap entry
                done[u] = True
                du = self.dist_[u]
                for v, w in self.g_[u]:
                    nd = du + w
                    if nd < self.dist_[v]:
                        # Strictly shorter route: v inherits all data from u.
                        self.dist_[v] = nd
                        self.pre_[v] = u
                        self.sum_[v] = self.sum_[u] + self.val_[v]
                        self.num_[v] = self.num_[u] + 1
                        self.cnt_[v] = self.cnt_[u]
                        heappush(q, (nd, v))
                    elif nd == self.dist_[v]:
                        # Another shortest route. v already has a heap entry at
                        # this distance, so no extra push is needed.
                        self.cnt_[v] += self.cnt_[u]
                        cand_sum = self.sum_[u] + self.val_[v]
                        cand_num = self.num_[u] + 1
                        better_sum = cand_sum > self.sum_[v]
                        fewer_nodes = cand_sum == self.sum_[v] and cand_num < self.num_[v]
                        if better_sum or fewer_nodes:
                            self.pre_[v] = u
                            self.sum_[v] = cand_sum
                            self.num_[v] = cand_num

    class TrieNodeGraph:
        """Trie that assigns each distinct inserted string a dense integer id.

        Ids start at 0 and follow insertion order, which makes the trie handy
        for mapping string node names onto graph indices. Each root trie keeps
        its own counter and id -> string map.
        """

        # Class-level defaults only. The root shadows both with instance
        # attributes on its first ``add``, so separate tries never share state.
        _sid_cnt = 0  # next id to hand out
        _sid_to_word_: Optional[Dict[int, str]] = None  # id -> original string

        def __init__(self):
            """Create an empty node; children are keyed by single characters."""
            self._children_ = {}  # char -> child TrieNodeGraph
            self._is_end_of_word = False  # True if an added word ends here
            self._sid = -1  # id of the word ending here, -1 if none

        def add(self, word: str) -> int:
            """Insert ``word`` if new and return its id; re-adding returns the same id."""
            node = self
            for c in word:
                if c not in node._children_:
                    node._children_[c] = Std.TrieNodeGraph()
                node = node._children_[c]
            node._is_end_of_word = True
            if node._sid < 0:
                if self._sid_to_word_ is None:
                    self._sid_to_word_ = {}  # per-root map, not shared across tries
                node._sid = self._sid_cnt
                self._sid_cnt += 1  # creates/updates an instance attribute on the root
                self._sid_to_word_[node._sid] = word
            return node._sid

        def _search(self, word: str) -> int:
            """Return the id of exactly ``word``, or -1 if it was never added."""
            node = self
            for c in word:
                if c not in node._children_:
                    return -1
                node = node._children_[c]
            return node._sid if node._is_end_of_word else -1

        def get_id(self, word: str) -> int:
            """Return the id assigned to ``word``, or -1 if it was never added."""
            return self._search(word)

        def get_str(self, sid: int) -> str:
            """Return the string with id ``sid``, or ``"-1"`` if no such id exists."""
            word = (self._sid_to_word_ or {}).get(sid)
            return "-1" if word is None else word
# ————————————————————— Division line ——————————————————————


n, k, st = IO.read_mixed(int, int, str)

trie = Std.TrieNodeGraph()
dj = Std.Dijkstra(n, Arr.array(0, n))
src_id = trie.add(st)  # the start city has no happiness value; its val_ stays 0

for _ in range(n - 1):
    city, happiness = IO.read_mixed(str, int)
    dj.val_[trie.add(city)] = happiness

for _ in range(k):
    a, b, cost = IO.read_mixed(str, str, int)
    ia, ib = trie.get_id(a), trie.get_id(b)
    dj.add_edge(ia, ib, cost)  # roads are two-way
    dj.add_edge(ib, ia, cost)

dj.dijkstra(src_id)

ed_id = trie.get_id("ROM")
happy = dj.sum_[ed_id]
visited = dj.num_[ed_id] - 1  # cities on the route, excluding the start
avg = happy // visited if visited > 0 else 0
print(dj.cnt_[ed_id], dj.dist_[ed_id], happy, avg)

# Walk predecessor links back to the start. Stop on the source id explicitly
# rather than relying on pre_[src] == src, which the default pre_ value only
# guarantees when the source happens to be id 0.
path = []
node = ed_id
while node != src_id:
    path.append(trie.get_str(node))
    node = dj.pre_[node]
path.append(st)
print("->".join(reversed(path)))
