834. 树中距离之和

摘要
Title: 834. 树中距离之和
Categories: 换根dp

Powered by:NEFU AB-IN

Link

834. 树中距离之和

题意

给定一个无向、连通的树。树中有 n 个标记为 0…n-1 的节点以及 n-1 条边 。

给定整数 n 和数组 edges , edges[i] = [ai, bi]表示树中的节点 ai 和 bi 之间有一条边。

返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和。

思路

第一类换根dp(树形 DP 中的换根 DP 问题又被称为二次扫描,通常不会指定根结点,并且根结点的变化会对一些值,例如子结点深度和、点权和等产生影响。通常需要两次 DFS,第一次 DFS 预处理诸如深度,点权和之类的信息,在第二次 DFS 开始运行换根动态规划。)

形象的解释,参考:https://leetcode.cn/problems/sum-of-distances-in-tree/solutions/2345592/tu-jie-yi-zhang-tu-miao-dong-huan-gen-dp-6bgb/

具体的参考:https://oi-wiki.org/dp/tree/

代码:

  1. 还是两遍dfs,思路都是一样,但是做了一点小调整
    1. 第一遍dfs,前序遍历算出深度(也就是算出根到每个点的距离),后序遍历算出子树的大小。这里就要注意了,如果是能在不进行dfs的情况下算出来的,放在前序遍历,比如我一开始知道u的深度,我就可以算出v的深度;不能的,就放在后序遍历,因为我不知道u的子树大小,需要用v来推
    2. 第二遍dfs,思路是一样的,一步步换根往下推即可,v结点的距离和,可以根据u结点的距离和得出

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
'''
Author: NEFU AB-IN
Date: 2024-08-06 20:52:10
FilePath: \LeetCode\834\834.py
LastEditTime: 2024-08-06 22:42:21
'''
# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache, reduce
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, comb, fabs, floor, gcd, hypot, log, perm, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

# Constants
TYPE = TypeVar('TYPE')
N = int(2e5 + 10)
M = int(20)
INF = int(1e12)
OFFSET = int(100)
MOD = int(1e9 + 7)

# Set recursion limit
setrecursionlimit(int(2e9))


class Arr:
array = staticmethod(lambda x=0, size=N: [x() if callable(x) else x for _ in range(size)])
array2d = staticmethod(lambda x=0, rows=N, cols=M: [Arr.array(x, cols) for _ in range(rows)])
graph = staticmethod(lambda size=N: [[] for _ in range(size)])


class Math:
max = staticmethod(lambda a, b: a if a > b else b)
min = staticmethod(lambda a, b: a if a < b else b)


class IO:
input = staticmethod(lambda: stdin.readline().rstrip("\r\n"))
read = staticmethod(lambda: map(int, IO.input().split()))
read_list = staticmethod(lambda: list(IO.read()))


class Std:
class Graph:
def __init__(self, n: int):
self.n = n
self.g = Arr.graph(n)
self.depth = Arr.array(0, n)
self.size = Arr.array(1, n)
self.dist = Arr.array(0, n)

def add_edge(self, u: int, v: int, w: int):
"""Add an edge to the graph."""
self.g[u].append((v, w))

def dfs1(self, u, fa):
for v, w in self.g[u]:
if fa == v:
continue
self.depth[v] = self.depth[u] + w # w = 1, 其实就是深度
self.dfs1(v, u)
self.size[u] += self.size[v]

def dfs2(self, u, fa):
for v, w in self.g[u]:
if fa == v:
continue
self.dist[v] = self.dist[u] + self.n - 2 * self.size[v]
self.dfs2(v, u)
# ————————————————————— Division line ——————————————————————


class Solution:
def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
g = Std.Graph(n)
for u, v in edges:
g.add_edge(u, v, 1)
g.add_edge(v, u, 1)
g.dfs1(0, -1)
for i in range(n): # 计算0结点到所有点的距离
g.dist[0] += g.depth[i]
g.dfs2(0, -1)
return g.dist

使用搜索:谷歌必应百度