Medium Binary Search Tree

Amazon

Apple

Bloomberg

Facebook

Oracle

TripleByte

2020-05-20

230. Kth Smallest Element in a BST

Question:

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Example 1:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Example 2:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Constraints:

  • The number of elements of the BST is between 1 and 10^4.
  • You may assume k is always valid, 1 ≤ k ≤ BST's total elements.

Solution:

Use an in-order traversal, which visits the nodes of a BST in ascending order, and count the visited nodes until the kth one is reached.

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
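class Solution {

    // Number of nodes still to visit before the kth smallest is reached.
    private int target;
    // Value of the kth smallest node, set once it has been found.
    private int result;

    public int kthSmallest(TreeNode root, int k) {
        target = k;
        inOrder(root);
        return result;
    }

    // Visits nodes in ascending order and stops as soon as the kth is found.
    private void inOrder(TreeNode root) {
        if (root == null || target == 0) {
            return;
        }

        inOrder(root.left);

        // The answer may already have been found in the left subtree.
        if (target == 0) {
            return;
        }

        // Visit the current node: one fewer node remains.
        target--;
        if (target == 0) {
            result = root.val;
            return;
        }

        inOrder(root.right);
    }

}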
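
As a cross-check of the idea above, here is a brute-force baseline: collect the whole in-order sequence into a list and read index k - 1. It is only a sketch and uses O(n) extra space. The class name InorderBaseline and the helpers collect, example1 and example2 are made up for this illustration, and the nested TreeNode simply mirrors the definition in the comment above.

```java
import java.util.ArrayList;
import java.util.List;

class InorderBaseline {

    // Minimal node type mirroring the TreeNode definition used in this note.
    static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Appends the values of the subtree rooted at node in ascending order.
    static void collect(TreeNode node, List<Integer> out) {
        if (node == null) {
            return;
        }
        collect(node.left, out);
        out.add(node.val);
        collect(node.right, out);
    }

    // Assumes 1 <= k <= number of nodes, as the constraints state.
    static int kthSmallest(TreeNode root, int k) {
        List<Integer> sorted = new ArrayList<>();
        collect(root, sorted);
        return sorted.get(k - 1);
    }

    // Tree from Example 1: [5,3,6,2,4,null,null,1]
    static TreeNode example1() {
        return new TreeNode(5,
                new TreeNode(3,
                        new TreeNode(2, new TreeNode(1), null),
                        new TreeNode(4)),
                new TreeNode(6));
    }

    // Tree from Example 2: [3,1,4,null,2]
    static TreeNode example2() {
        return new TreeNode(3,
                new TreeNode(1, null, new TreeNode(2)),
                new TreeNode(4));
    }
}
```

Calling InorderBaseline.kthSmallest(InorderBaseline.example1(), 3) returns 3, and InorderBaseline.kthSmallest(InorderBaseline.example2(), 1) returns 1, matching the two examples.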

The same in-order traversal can also be done iteratively with an explicit stack that holds the path from the root to the current node.

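import java.util.Stack;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // The stack holds the path from the root to the current node.
        Stack<TreeNode> stack = new Stack<>();

        // Descend the left spine; the top of the stack is the smallest node.
        while (root != null) {
            stack.push(root);
            root = root.left;
        }

        // Advance to the in-order successor k - 1 times.
        for (int i = 0; i < k - 1; i++) {
            TreeNode curr = stack.peek();

            if (curr.right == null) {
                stack.pop();
                // Pop every ancestor whose right subtree was just finished;
                // the new top is the nearest ancestor reached from its left child.
                while (!stack.isEmpty() && stack.peek().right == curr) {
                    curr = stack.pop();
                }
            } else {
                // The successor is the leftmost node of the right subtree.
                curr = curr.right;
                while (curr != null) {
                    stack.push(curr);
                    curr = curr.left;
                }
            }
        }
        return stack.peek().val;
    }

}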
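
For the follow-up question, one common approach is to store in every node the size of its subtree, so that the kth smallest can be found by walking a single root-to-leaf path instead of traversing k nodes. The sketch below is an illustration under assumptions, not part of the original solution: the class SizedBst and its Node type are hypothetical, duplicates are ignored, deletion is omitted, and the tree is not rebalanced, so a query costs O(h) where h is the tree height.

```java
class SizedBst {

    // Node that also stores the number of nodes in the subtree rooted at it.
    static class Node {
        int val;
        int size = 1;
        Node left;
        Node right;
        Node(int val) { this.val = val; }
    }

    private Node root;

    private static int size(Node n) {
        return n == null ? 0 : n.size;
    }

    public void insert(int val) {
        root = insert(root, val);
    }

    // Standard BST insert that refreshes subtree sizes on the way back up.
    // Duplicate values are ignored (an assumption made for this sketch).
    private static Node insert(Node n, int val) {
        if (n == null) {
            return new Node(val);
        }
        if (val < n.val) {
            n.left = insert(n.left, val);
        } else if (val > n.val) {
            n.right = insert(n.right, val);
        }
        n.size = 1 + size(n.left) + size(n.right);
        return n;
    }

    // Returns the kth smallest value (1-based) by following one path.
    public int kthSmallest(int k) {
        Node n = root;
        while (n != null) {
            int leftSize = size(n.left);
            if (k == leftSize + 1) {
                return n.val;
            }
            if (k <= leftSize) {
                n = n.left;
            } else {
                // Skip the left subtree and the current node.
                k -= leftSize + 1;
                n = n.right;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }
}
```

A delete operation would need to update the size fields along the affected path in the same way. Combining the size field with a self-balancing tree would keep the height, and therefore each query and update, at O(log n).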