Problem description:

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:
You may assume k is always valid, 1 ≤ k ≤ BST’s total elements.

Example 1:

Input: root = [3,1,4,null,2], k = 1

1
2
3
4
5
   3
  / \
 1   4
  \
   2

Output: 1
Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3

1
2
3
4
5
6
7
       5
      / \
     3   6
    / \
   2   4
  /
 1

Output: 3
Follow up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

Solution:

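  1. Traverse the tree in-order (left, node, right), which visits the values in ascending order, and keep a counter initialized to k that is decremented at every visited node.
  2. When the counter reaches zero, the current node is the kth smallest element: record its value as the result and stop.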
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Recursive in-order traversal that stops at the k-th visited node."""

    def kthSmallest(self, root: "TreeNode", k: int) -> int:
        """Return the k-th smallest value (1-indexed) stored in a BST.

        An in-order walk of a binary search tree visits values in ascending
        order, so the k-th node visited holds the answer.

        Args:
            root: Root node of the tree (objects exposing ``val``, ``left``
                and ``right``); may be ``None`` for an empty tree.
            k: 1-based rank of the wanted element.

        Returns:
            The value of the k-th visited node, or 0 if the tree holds
            fewer than ``k`` nodes (the problem statement guarantees a
            valid ``k``, so this fallback is not expected to be hit).
        """
        remaining = k
        answer = 0

        def visit(node) -> bool:
            """Walk ``node`` in-order; return True once the answer is set."""
            nonlocal remaining, answer
            if node is None:
                return False
            # Propagate "found" upwards so ancestors stop immediately
            # instead of continuing into their right subtrees.
            if visit(node.left):
                return True
            remaining -= 1
            if remaining == 0:
                answer = node.val
                return True
            return visit(node.right)

        visit(root)
        return answer
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
/**
 * Recursive in-order traversal that counts down from k.
 *
 * An in-order walk of a binary search tree visits values in ascending
 * order, so the node at which the counter reaches zero is the k-th smallest.
 */
class Solution {
public:
    int res = 0;    // value of the k-th smallest node; 0 until it is found
    int count = 0;  // nodes still to visit before reaching the k-th one

    /**
     * Returns the k-th smallest value (1-indexed) in the tree at `root`.
     * Returns 0 when the tree is empty or holds fewer than k nodes.
     */
    int kthSmallest(TreeNode* root, int k) {
        res = 0;  // reset so a previous call's answer cannot leak through
        count = k;
        helper(root);
        return res;
    }

    /**
     * Visits the subtree at `root` in-order, decrementing `count` per node
     * and storing the node's value in `res` when `count` hits zero.
     */
    void helper(TreeNode* root) {
        // Null guard protects against an empty tree; count <= 0 means the
        // answer is already known (or k was not positive), so stop early.
        if (!root || count <= 0) return;

        helper(root->left);
        if (count <= 0) return;  // found inside the left subtree

        if (--count == 0) {
            res = root->val;
            return;
        }
        helper(root->right);
    }
};

Second time:

Iterative solution: in-order traversal with an explicit stack (this is a depth-first traversal, not BFS)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
    """Iterative in-order traversal driven by an explicit stack."""

    def kthSmallest(self, root: "TreeNode", k: int) -> int:
        """Return the k-th smallest value (1-indexed) stored in a BST.

        Args:
            root: Root node of the tree (objects exposing ``val``, ``left``
                and ``right``); may be ``None`` for an empty tree.
            k: 1-based rank of the wanted element.

        Returns:
            The value of the k-th node in in-order sequence. If the tree
            holds fewer than ``k`` nodes the loop runs out and ``None`` is
            returned (the problem statement guarantees a valid ``k``).
        """
        # A plain list is a sufficient LIFO stack; this is a depth-first
        # in-order walk, not a breadth-first search.
        stack = []
        node = root
        while node is not None or stack:
            # Descend to the leftmost unvisited node, remembering the path.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            k -= 1
            if k == 0:
                return node.val
            node = node.right
        return None
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
int kthSmallest(TreeNode* root, int k) {
stack<TreeNode*> q;
TreeNode* cur= root;
while(cur || !q.empty()){
while(cur){
q.push(cur);
cur= cur->left;
}

cur= q.top(); q.pop();
k--;
if(k == 0)
return cur->val;
cur= cur->right;
}
}
};

reference:
https://goo.gl/Ed7PGv