import java.lang.*;
import java.util.*;
import node.*;

/** Demonstrates mirroring ("inverting") a binary tree of {@link TreeNode}s. */
class InvertTree {

  /**
   * Inverts the tree rooted at {@code root} in place, so that every node's left and right
   * children are swapped.
   *
   * <p>Nodes with zero, one, or two children are all handled. A node with a single child has
   * that child moved to the opposite side. The traversal is iterative, so very deep (skewed)
   * trees do not overflow the call stack.
   *
   * @param root the root of the tree to invert; may be {@code null}
   * @return the same {@code root} reference, now the root of the mirrored tree, or {@code null}
   *     if {@code root} was {@code null}
   */
  public static TreeNode invertTree(TreeNode root) {
    if (root == null) {
      return null;
    }
    // ArrayDeque rejects null elements, so only non-null children are pushed.
    Deque<TreeNode> pending = new ArrayDeque<>();
    pending.push(root);
    while (!pending.isEmpty()) {
      TreeNode node = pending.pop();
      TreeNode swapped = node.left;
      node.left = node.right;
      node.right = swapped;
      if (node.left != null) {
        pending.push(node.left);
      }
      if (node.right != null) {
        pending.push(node.right);
      }
    }
    return root;
  }

  /**
   * Returns the in-order traversal of the tree rooted at {@code root} as space-separated values.
   *
   * @param root the tree root; may be {@code null} (yields an empty string)
   * @return the node values in left-root-right order, separated by single spaces
   */
  private static String inorder(TreeNode root) {
    StringJoiner out = new StringJoiner(" ");
    appendInorder(root, out);
    return out.toString();
  }

  /** Appends the in-order values of the subtree rooted at {@code node} to {@code out}. */
  private static void appendInorder(TreeNode node, StringJoiner out) {
    if (node == null) {
      return;
    }
    appendInorder(node.left, out);
    out.add(String.valueOf(node.val));
    appendInorder(node.right, out);
  }

  /**
   * Builds the sample tree below, inverts it, and prints the new root value followed by the
   * in-order traversal of the inverted tree.
   *
   * <pre>
   *        4                 4
   *      /   \             /   \
   *     2     7    -->    7     2
   *    / \   / \         / \   / \
   *   1   3 6   9       9   6 3   1
   * </pre>
   */
  public static void main(String[] args) {
    TreeNode leftLeaf = new TreeNode(1);
    TreeNode lMiddleLeaf = new TreeNode(3);
    TreeNode rMiddleLeaf = new TreeNode(6);
    TreeNode rightLeaf = new TreeNode(9);
    TreeNode leftParent = new TreeNode(2, leftLeaf, lMiddleLeaf);
    TreeNode rightParent = new TreeNode(7, rMiddleLeaf, rightLeaf);
    TreeNode root = new TreeNode(4, leftParent, rightParent);

    TreeNode inverted = invertTree(root);
    System.out.println(inverted.val);
    // In-order of a correctly inverted tree is the reverse of the original's: 9 7 6 4 3 2 1.
    System.out.println(inorder(inverted));
  }
}