# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    """Find the tree level (root = level 1) whose node values sum highest."""

    def __init__(self):
        # Maps level number (root = 1) -> sum of node values on that level.
        self.levels = dict()

    def maxLevelSum(self, root):
        """Return the smallest 1-based level whose node values have the largest sum.

        Args:
            root: Root ``TreeNode`` of the tree, or ``None``.

        Returns:
            The level number (root is level 1). Returns 1 when ``root`` is
            ``None``, matching the previous behaviour.
        """
        # Reset per call so reusing one Solution instance does not mix in
        # sums left over from previously processed trees.
        self.levels = dict()
        self.traverseWithLevel(root, 1)
        if not self.levels:
            return 1
        # Walk levels in ascending order; max() returns the first maximal
        # item, so ties resolve to the smallest level. Comparing real sums
        # (rather than against a sentinel such as -1) also handles trees
        # whose level sums are all negative.
        return max(sorted(self.levels), key=self.levels.__getitem__)

    def traverseWithLevel(self, tree, level):
        """Add each node's value in ``tree`` into ``self.levels`` keyed by depth.

        Args:
            tree: Subtree root (may be ``None``, in which case nothing is added).
            level: Level number assigned to ``tree`` itself; its children are
                counted at ``level + 1``.

        An explicit stack is used instead of recursion so that deep (e.g.
        completely skewed) trees do not exceed Python's recursion limit.
        """
        stack = [(tree, level)]
        while stack:
            node, depth = stack.pop()
            if node is None:
                continue
            self.levels[depth] = self.levels.get(depth, 0) + node.val
            # Push right before left so the left subtree is processed first,
            # mirroring the original recursive preorder traversal.
            stack.append((node.right, depth + 1))
            stack.append((node.left, depth + 1))


if __name__ == "__main__":
    s = Solution()
    print("Expected: 2")
    tree = TreeNode(1)
    tree.right = TreeNode(2)
    tree.left = TreeNode(3)
    print("Got:", s.maxLevelSum(tree))