import java.lang.*;
import java.util.*;
import node.*;

/** Removes the n-th node from the end of a singly linked list of {@code ListNode}s. */
class RemoveNthNode {

  /**
   * Removes the n-th node counted from the end of the list (n = 1 is the last node).
   *
   * <p>Uses two pointers separated by n nodes, starting from a sentinel node placed before
   * {@code head}. The sentinel lets the head itself be removed without a special case.
   *
   * @param head first node of the list (may be {@code null} only if the caller expects an
   *     exception, since an empty list has no n-th node)
   * @param n 1-based position from the end of the node to remove
   * @return the head of the modified list, which is {@code head.next} when the head was removed,
   *     or {@code null} when the only node was removed
   * @throws IllegalArgumentException if {@code n < 1} or {@code n} exceeds the list length
   */
  public static ListNode removeNthFromEnd(ListNode head, int n) {
    if (n < 1) {
      throw new IllegalArgumentException("n must be >= 1, was " + n);
    }
    // Sentinel so that "the node before the target" always exists, even when the target is head.
    ListNode dummy = new ListNode(0, head);

    // Move lead n nodes ahead of the sentinel; running out of nodes means n > length.
    ListNode lead = dummy;
    for (int i = 0; i < n; i++) {
      lead = lead.next;
      if (lead == null) {
        throw new IllegalArgumentException("n (" + n + ") exceeds list length (" + i + ")");
      }
    }

    // Advance both pointers until lead is on the last node. trail then sits just before the
    // n-th node from the end.
    ListNode trail = dummy;
    while (lead.next != null) {
      lead = lead.next;
      trail = trail.next;
    }
    trail.next = trail.next.next;
    return dummy.next;
  }

  /** Builds the list 1 -> 2 -> 3 -> 4 -> 5, removes its last node and prints the result. */
  public static void main(String[] args) {
    ListNode fourth = new ListNode(5, null);
    ListNode third = new ListNode(4, fourth);
    ListNode second = new ListNode(3, third);
    ListNode first = new ListNode(2, second);
    ListNode head = new ListNode(1, first);
    ListNode curr = removeNthFromEnd(head, 1);
    while (curr != null) {
      System.out.println(curr.val);
      curr = curr.next;
    }
  }
}