3241. 标记所有节点需要的时间

摘要
Title: 3241. 标记所有节点需要的时间
Categories: 树形dp、换根dp

Powered by:NEFU AB-IN

Link

3241. 标记所有节点需要的时间

题意

给你一棵 无向 树,树中节点从 0 到 n - 1 编号。同时给你一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ui, vi] 表示节点 ui 和 vi 在树中有一条边。

一开始,所有 节点都 未标记 。对于节点 i :

当 i 是奇数时,如果时刻 x - 1 该节点有 至少 一个相邻节点已经被标记了,那么节点 i 会在时刻 x 被标记。
当 i 是偶数时,如果时刻 x - 2 该节点有 至少 一个相邻节点已经被标记了,那么节点 i 会在时刻 x 被标记。
请你返回一个数组 times ,表示如果你在时刻 t = 0 标记节点 i ,那么时刻 times[i] 时,树中所有节点都会被标记。

请注意,每个 times[i] 的答案都是独立的,即当你标记节点 i 时,所有其他节点都未标记。

思路

https://www.bilibili.com/video/BV1F4421S7XU/?vd_source=c2be79bc3abc8c9584470d3fed5d994e
https://leetcode.cn/problems/time-taken-to-mark-all-nodes/solutions/2868276/di-er-lei-huan-gen-dppythonjavacgo-by-en-411w/

问题抽象为:指向奇数索引点的边权值为1,指向偶数索引点的边权值为2,现在在图中从u到v的最短距离就变成了从u到v传播所用的最短时间。
核心问题则变为:求从树上一个点出发到其他所有点最短距离的最大值。

维护子树 x 的最大深度 max_d,次大深度 max_d2,以及最大深度要往儿子 my 走

  • 子树 x 的最大深度。
  • x 往上走到某个节点再往下拐弯的路径最大长度。(即from_up,这个就可以通过父节点的信息求出,然后带到子节点)
    • up1 = from_up + w 往上走不拐弯
    • up2 = (max_d2 if y == my else max_d) + w 在父节点拐弯,如果自己在树的最大深度的子树,那么就取次打深

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
'''
Author: NEFU AB-IN
Date: 2024-08-07 15:26:32
FilePath: \LeetCode\3241\3241.py
LastEditTime: 2024-08-07 15:41:28
'''
# 3.8.19 import
import random
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache, reduce
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, comb, fabs, floor, gcd, hypot, log, perm, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

# Constants
TYPE = TypeVar('TYPE')
N = int(2e5 + 10)
M = int(20)
INF = int(1e12)
OFFSET = int(100)
MOD = int(1e9 + 7)

# Set recursion limit
setrecursionlimit(int(2e9))


class Arr:
array = staticmethod(lambda x=0, size=N: [x() if callable(x) else x for _ in range(size)])
array2d = staticmethod(lambda x=0, rows=N, cols=M: [Arr.array(x, cols) for _ in range(rows)])
graph = staticmethod(lambda size=N: [[] for _ in range(size)])


class Math:
max = staticmethod(lambda a, b: a if a > b else b)
min = staticmethod(lambda a, b: a if a < b else b)


class IO:
input = staticmethod(lambda: stdin.readline().rstrip("\r\n"))
read = staticmethod(lambda: map(int, IO.input().split()))
read_list = staticmethod(lambda: list(IO.read()))


class Std:
pass

# ————————————————————— Division line ——————————————————————


class Solution:
def timeTaken(self, edges: List[List[int]]) -> List[int]:
n = len(edges) + 1
g = Arr.graph(n)
for x, y in edges:
g[x].append(y)
g[y].append(x)

def weight(x): return 2 if x % 2 == 0 else 1

# nodes[x] 保存子树 x 的最大深度 max_d,次大深度 max_d2,以及最大深度要往儿子 my 走
nodes = Arr.array(None, n)

def dfs(x: int, fa: int) -> int:
max_d = max_d2 = my = 0
for y in g[x]:
if y == fa:
continue
depth = dfs(y, x) + weight(y) # 从 x 出发,往 my 方向的最大深度
if depth > max_d:
max_d2 = max_d
max_d = depth
my = y
elif depth > max_d2:
max_d2 = depth
nodes[x] = (max_d, max_d2, my)
return max_d
dfs(0, -1)

ans = [0] * len(g)

def reroot(x: int, fa: int, from_up: int) -> None:
max_d, max_d2, my = nodes[x]
ans[x] = max(from_up, max_d)
w = weight(x) # 从 y 到 x 的边权
for y in g[x]:
if y != fa:
up1 = from_up + w
up2 = (max_d2 if y == my else max_d) + w
reroot(y, x, Math.max(up1, up2))
reroot(0, -1, 0)
return ans

使用搜索:谷歌必应百度