1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
| import random from collections import Counter, defaultdict, deque from datetime import datetime, timedelta from functools import lru_cache from heapq import heapify, heappop, heappush, nlargest, nsmallest from itertools import combinations, compress, permutations, starmap, tee from math import ceil, comb, fabs, floor, gcd, hypot, log, perm, sqrt from string import ascii_lowercase, ascii_uppercase from sys import exit, setrecursionlimit, stdin from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
# Generic type variable for helper signatures.
TYPE = TypeVar('TYPE')
# Default sizes used by the Arr factories: N rows/length, M columns.
N = 2 * 10**5 + 10
M = 20
INF = 10**12
OFFSET = 100
MOD = 10**9 + 7  # the usual 1e9 + 7 prime modulus

# NOTE: a limit this large effectively disables Python's recursion guard;
# runaway recursion may then crash the interpreter instead of raising
# RecursionError.
setrecursionlimit(2 * 10**9)
class Arr:
    """Factories for pre-sized lists; sizes default to the module's N and M."""

    @staticmethod
    def array(x=0, size=N):
        """Return a list of length ``size``.

        If ``x`` is callable it is invoked once per slot, so every slot gets
        its own object; otherwise every slot refers to ``x`` itself.
        """
        if callable(x):
            return [x() for _ in range(size)]
        return [x] * size

    @staticmethod
    def array2d(x=0, rows=N, cols=M):
        """Return ``rows`` independent lists, each built by ``Arr.array(x, cols)``."""
        return [Arr.array(x, cols) for _ in range(rows)]

    @staticmethod
    def graph(size=N):
        """Return ``size`` independent empty lists (an adjacency-list skeleton)."""
        return [[] for _ in range(size)]
class Math:
    """Two-argument max/min helpers; on ties the second argument wins."""

    @staticmethod
    def max(a, b):
        """Return ``a`` if it compares strictly greater than ``b``, else ``b``."""
        if a > b:
            return a
        return b

    @staticmethod
    def min(a, b):
        """Return ``a`` if it compares strictly less than ``b``, else ``b``."""
        if a < b:
            return a
        return b
class IO:
    """Thin line-oriented readers over ``sys.stdin``."""

    @staticmethod
    def input():
        """Read one line from stdin with trailing CR/LF characters removed."""
        raw = stdin.readline()
        return raw.rstrip("\r\n")

    @staticmethod
    def read():
        """Return a lazy ``map`` of ints parsed from the next input line."""
        tokens = IO.input().split()
        return map(int, tokens)

    @staticmethod
    def read_list():
        """Return the ints on the next input line as a list."""
        return list(IO.read())
class Std:
    """Namespace class grouping reusable data structures (currently LCA)."""

    class LCA:
        """Lowest-common-ancestor queries on a weighted tree via binary lifting.

        Alongside the ancestor table, each 2**i-edge jump stores a ``Counter``
        of the edge weights it crosses. A query can then also report how
        often the most frequent weight occurs on the path between two nodes.

        The tree is rooted at node 0. Nodes are assumed to be labelled
        ``0 .. len(edges)`` (``n = len(edges) + 1``).
        """

        def __init__(self, edges: List[List[int]]):
            """Build the depth, ancestor and weight-count tables.

            Args:
                edges: ``[u, v, w]`` triples, each an undirected tree edge
                    between ``u`` and ``v`` with weight ``w``.
            """
            n = len(edges) + 1
            # 2**m > n - 1 >= any depth, so m lifting levels suffice.
            m = n.bit_length()

            g = [[] for _ in range(n)]
            for x, y, w in edges:
                g[x].append((y, w))
                g[y].append((x, w))

            depth = [0] * n
            # pa[x][i]: the 2**i-th ancestor of x, or -1 past the root.
            pa = [[-1] * m for _ in range(n)]
            # cnt_[x][i]: weight counts of the 2**i edges from x up to pa[x][i].
            cnt_ = [[Counter() for _ in range(m)] for _ in range(n)]

            # Iterative DFS from the root. A recursive walk hits
            # RecursionError (or overflows the C stack) on deep, path-like trees.
            stack = [0]
            while stack:
                x = stack.pop()
                fa = pa[x][0]
                for y, w in g[x]:
                    if y != fa:
                        pa[y][0] = x
                        depth[y] = depth[x] + 1
                        cnt_[y][0][w] = 1
                        stack.append(y)

            # Level i + 1 is composed from two level-i jumps.
            for i in range(m - 1):
                for x in range(n):
                    p = pa[x][i]
                    if p == -1:
                        continue
                    pp = pa[p][i]
                    pa[x][i + 1] = pp
                    # Jumps that overshoot the root are never taken by the
                    # query methods, so their weight counts are not built.
                    if pp != -1:
                        cnt_[x][i + 1] = cnt_[x][i] + cnt_[p][i]

            self.depth = depth
            self.pa = pa
            self.cnt_ = cnt_

        def get_kth_ancestor(self, node: int, k: int, cnt_: Optional[Counter] = None) -> int:
            """Return the k-th ancestor of ``node`` (climb ``k`` edges).

            Args:
                node: starting node.
                k: number of edges to climb; expected to be non-negative.
                cnt_: optional Counter that is updated in place with the
                    weights of the climbed edges.

            Returns:
                The ancestor node, or -1 when ``k`` exceeds the depth of
                ``node``. In that case ``cnt_`` is left unchanged.
            """
            if k > self.depth[node]:
                return -1
            if cnt_ is None:
                cnt_ = Counter()
            for i in range(k.bit_length()):
                if (k >> i) & 1:
                    # Counter's += mutates the caller's object in place.
                    cnt_ += self.cnt_[node][i]
                    node = self.pa[node][i]
            return node

        def get_lca(self, x: int, y: int) -> Tuple[int, int]:
            """Return ``(lca, max_count)`` for nodes ``x`` and ``y``.

            ``lca`` is their lowest common ancestor. ``max_count`` is the
            number of occurrences of the most frequent edge weight on the
            x-y path (0 when ``x == y``).
            """
            cnt_ = Counter()
            if self.depth[x] > self.depth[y]:
                x, y = y, x
            # Lift the deeper node to the same depth, collecting weights.
            y = self.get_kth_ancestor(y, self.depth[y] - self.depth[x], cnt_)
            if y != x:
                # Jump both nodes while their ancestors differ; afterwards
                # they are distinct children of the LCA.
                for i in range(len(self.pa[x]) - 1, -1, -1):
                    px, py = self.pa[x][i], self.pa[y][i]
                    if px != py:
                        cnt_ += self.cnt_[x][i]
                        cnt_ += self.cnt_[y][i]
                        x, y = px, py
                cnt_ += self.cnt_[x][0]
                cnt_ += self.cnt_[y][0]
                x = self.pa[x][0]
            return x, max(cnt_.values(), default=0)
class Solution:
    def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        """For each ``(a, b)`` query, return the path length between a and b
        minus the count of the most frequent edge weight on that path, i.e.
        the number of path edges whose weight differs from the majority one.

        ``n`` is not read; the node count is derived from ``edges``.
        """
        tree = Std.LCA(edges)
        depth = tree.depth
        result = []
        for a, b in queries:
            span = depth[a] + depth[b]
            ancestor, most_frequent = tree.get_lca(a, b)
            edges_on_path = span - 2 * depth[ancestor]
            result.append(edges_on_path - most_frequent)
        return result
if __name__ == "__main__":
    # Demo run; guarded so importing this module has no side effects.
    # Expected output: [0, 0, 1, 3].
    print(
        Solution().minOperationsQueries(
            7,
            [[0, 1, 1], [1, 2, 1], [2, 3, 1], [3, 4, 2], [4, 5, 2], [5, 6, 2]],
            [[0, 3], [3, 6], [2, 6], [0, 6]],
        )
    )
|