Skip to main content

nodesInSubTreeWithLabel

# problem 1519
from typing import List
from collections import defaultdict, Counter

class Solution:
def countSubTrees(self, n: int, edges: List[List[int]], labels: str) -> List[int]:
tree = defaultdict(list)
for s,e in edges:
tree[s].append(e)
# tree[e].append(s)

res = [0] * n

def dfs(cur_node):
# keep res as a nonlocal variable - we want it to be te same
# across recursive calls
nonlocal res

# keep a counter to count everything "below" this in the tree
count = Counter()

for neighbor in tree[cur_node]:
count += dfs(neighbor)

count[labels[cur_node]] += 1
res[cur_node] = count[labels[cur_node]]

return count

dfs(0)
return res