1171. 距离

摘要
Title: 1171. 距离
Tag: tarjan、lca
Memory Limit: 64 MB
Time Limit: 1000 ms

Powered by:NEFU AB-IN

Link

1171. 距离

  • 题意

    给出 n 个点的一棵树,多次询问两点之间的最短距离。
    注意:
    边是无向的。
    所有节点的编号是 1,2,…,n。

  • 思路

    离线tarjan求lca的算法:需要先将所有的查询存起来,边tarjan边求答案

    dfsdfs时,将所有点分成三大类

    • 2号点:代表已经访问并结束回溯

    • 1号点:代表正在访问

    • 0号点:代表还没有访问过

    其中所有2号点和正在搜索的1号点路径中已经通过并查集合并成一个集合

    1171

    • 1、首先dfsdfs, 先求出所有点到根结点的距离dist[],设x号点和y号点的最近公共祖先是p,则x和y的最近距离等于

      dist[x]+dist[y]2dist[p]dist[x] + dist[y] - 2 * dist[p]

    • 2、在深度优先遍历1号点中的u点的时候,需要把u的查询的另外一个点的最短距离进行计算并存储,最后把u点合并到上一结点的集合

  • 代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    '''
    Author: NEFU AB-IN
    Date: 2022-04-06 22:47:25
    FilePath: \ACM\Acwing\1171.py
    LastEditTime: 2022-04-06 23:02:31
    '''
    import sys

    sys.setrecursionlimit(int(2e9))


    class Query():
    def __init__(self, index, v):
    self.index = index
    self.v = v


    N = int(1e4 + 10)
    g = [[] for _ in range(N)]
    p = [i for i in range(N)]
    st, dist = [0] * N, [0] * N
    res = [[] for _ in range(N)]

    ans = [0] * N * 2 # 查询的编号 -> 答案


    def find(x):
    if x != p[x]:
    p[x] = find(p[x])
    return p[x]


    def dfs(fa, u):
    for v, w in g[u]:
    if fa == v:
    continue
    dist[v] = dist[u] + w
    dfs(u, v)


    def tarjan(u):
    st[u] = 1
    for v, w in g[u]:
    if not st[v]: #如果未被遍历过,才遍历
    tarjan(v)
    p[v] = u

    for q in res[u]: #找和u相关的查询
    if st[q.v] == 2: #如果查询中的另一个点已经被遍历且回溯过,说明已经合并完集合了,直接求即可
    lca = find(q.v)
    ans[q.index] = dist[u] + dist[q.v] - 2 * dist[lca]

    st[u] = 2 #最后别忘了将其更新为2状态


    n, m = map(int, input().split())
    for i in range(n - 1):
    x, y, k = map(int, input().split())
    g[x].append([y, k])
    g[y].append([x, k])
    for i in range(m):
    x, y = map(int, input().split())
    res[x].append(Query(i, y))
    res[y].append(Query(i, x))

    dfs(-1, 1)
    tarjan(1)

    for i in range(m):
    print(ans[i])
使用搜索:谷歌必应百度