import java.lang.*;
import java.util.*;
import node.*;

/**
 * Computes the maximum depth of a binary tree made of {@code node.TreeNode} instances.
 */
class MaxDepthOfTree {

    /**
     * Returns the maximum depth of the tree rooted at {@code root}: the number of nodes on the
     * longest path from {@code root} down to a leaf.
     *
     * @param root the root of the tree; may be {@code null}, which denotes an empty tree
     * @return the maximum depth, or {@code 0} for an empty tree
     */
    public static int maxDepth(TreeNode root) {
        if (root == null) {
            // An empty subtree contributes no levels.
            return 0;
        }
        // Recurse into both children unconditionally: a node with only one child still has its
        // depth determined by that child, and a missing child simply yields 0 above.
        return 1 + Math.max(maxDepth(root.left), maxDepth(root.right));
    }

    /**
     * Builds the sample tree 3 -> (9, 20 -> (15, 7)) and prints its maximum depth (3).
     */
    public static void main(String[] args) {
        TreeNode leftLeaf = new TreeNode(9);
        TreeNode middleLeaf = new TreeNode(15);
        TreeNode rightLeaf = new TreeNode(7);
        TreeNode parent = new TreeNode(20, middleLeaf, rightLeaf);
        TreeNode root = new TreeNode(3, leftLeaf, parent);
        System.out.println(maxDepth(root));
    }
}