Problem description:

Implement an iterator over a binary search tree (BST). Your iterator will be initialized with the root node of a BST.

Calling next() will return the next smallest number in the BST.

Note: next() and hasNext() should run in average O(1) time and use O(h) memory, where h is the height of the tree.
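To see why these bounds are achievable, here is a minimal Python sketch that instruments the stack-based iterator described in the solution. The names `Node`, `insert`, `height`, `CountingIterator` and `_push_left_spine`, and the sample keys, are hypothetical and chosen only for illustration; the traversal itself is the same push-the-left-spine idea. Each node is pushed exactly once and popped exactly once, so n calls to next() do O(n) work in total (O(1) on average). The stack also never holds more nodes than the tree's height.

```python
from typing import List, Optional


class Node:
    """Hypothetical minimal BST node used only for this demo."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def insert(root: Optional[Node], val: int) -> Node:
    """Plain (unbalanced) BST insertion."""
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def height(root: Optional[Node]) -> int:
    """Height counted in nodes: an empty tree has height 0."""
    if root is None:
        return 0
    return 1 + max(height(root.left), height(root.right))


class CountingIterator:
    """Stack-based BST iterator plus counters for pushes and peak stack size."""

    def __init__(self, root: Optional[Node]) -> None:
        self.stk: List[Node] = []
        self.pushes = 0
        self.max_stack = 0
        self._push_left_spine(root)

    def _push_left_spine(self, node: Optional[Node]) -> None:
        # Walk left from `node`, pushing every node on the way down.
        while node:
            self.stk.append(node)
            self.pushes += 1
            self.max_stack = max(self.max_stack, len(self.stk))
            node = node.left

    def next(self) -> int:
        node = self.stk.pop()
        # The in-order successor lies in the right subtree's left spine.
        self._push_left_spine(node.right)
        return node.val

    def hasNext(self) -> bool:
        return len(self.stk) > 0


keys = [50, 30, 70, 20, 40, 60, 80, 35, 45, 65]
root = None
for k in keys:
    root = insert(root, k)

it = CountingIterator(root)
out = []
while it.hasNext():
    out.append(it.next())

print(out)  # [20, 30, 35, 40, 45, 50, 60, 65, 70, 80]
print("pushes:", it.pushes, "nodes:", len(keys))  # pushes: 10 nodes: 10
print("max stack:", it.max_stack, "height:", height(root))  # max stack: 3 height: 4
```

The stack holds at most one root-to-leaf path at a time, which is where the O(h) memory bound comes from. On a balanced tree that is O(log n), and on a fully skewed tree it degrades to O(n).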

Solution:

  1. To return the smallest element in O(1) time, we traverse the tree from the root toward the left and push every node we visit onto a stack. The root ends up at the bottom and the smallest (leftmost) element on top.
  2. One thing to notice: when we call next(), we need to check whether the node we just popped has a right subtree. If it does, we walk from its right child toward the left and push every node on the way, exactly as in step 1.
    For example:

          3
         / \
        1   4
       / \
      0   2

In the tree above, 2 is smaller than 3. However, the constructor only pushes [3, 1, 0] onto the stack. If we did not check 1's right subtree after popping 1, the next call would return 3 instead of 2, which is the wrong next smallest element.
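The pitfall can be reproduced on the example tree. Below is a minimal Python sketch; the `Node` class and the helpers `left_spine`, `naive_inorder` and `stack_inorder` are hypothetical names for illustration. `naive_inorder` only ever pops the initial left spine [3, 1, 0] and ignores right subtrees. `stack_inorder` pushes the left spine of each popped node's right subtree, as the solution does.

```python
from typing import List, Optional


class Node:
    """Hypothetical minimal BST node for this demo."""

    def __init__(self, val: int, left: Optional["Node"] = None,
                 right: Optional["Node"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


# The example tree:   3
#                    / \
#                   1   4
#                  / \
#                 0   2
root = Node(3, Node(1, Node(0), Node(2)), Node(4))


def left_spine(node: Optional[Node]) -> List[Node]:
    """Nodes on the path from `node` going left, listed top-down."""
    spine = []
    while node:
        spine.append(node)
        node = node.left
    return spine


def naive_inorder(root: Optional[Node]) -> List[int]:
    # Only the initial left spine is ever pushed; right subtrees are ignored.
    stk = left_spine(root)
    out = []
    while stk:
        out.append(stk.pop().val)
    return out


def stack_inorder(root: Optional[Node]) -> List[int]:
    # After popping a node, push the left spine of its right subtree,
    # so the deepest-left node of that subtree ends up on top.
    stk = left_spine(root)
    out = []
    while stk:
        node = stk.pop()
        out.append(node.val)
        stk.extend(left_spine(node.right))
    return out


print([n.val for n in left_spine(root)])  # [3, 1, 0]
print(naive_inorder(root))                # [0, 1, 3]
print(stack_inorder(root))                # [0, 1, 2, 3, 4]
```

The naive version returns 3 right after 1, skipping 2 (and it never reaches 4 either), while the corrected traversal yields every key in sorted order.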

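```python
from typing import List, Optional


class TreeNode:
    # Binary tree node (LeetCode normally supplies this definition).
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


class BSTIterator:

    def __init__(self, root: Optional[TreeNode]):
        # Push the root and its left spine; the smallest node ends up on top.
        self.stk: List[TreeNode] = []
        while root:
            self.stk.append(root)
            root = root.left

    def next(self) -> int:
        # The top of the stack holds the smallest value not yet returned.
        node = self.stk.pop()
        # Push the left spine of its right subtree (if any) so that the
        # in-order successor is on top for the following call.
        p = node.right
        while p:
            self.stk.append(p)
            p = p.left
        return node.val

    def hasNext(self) -> bool:
        # True while at least one node remains on the stack.
        return len(self.stk) > 0
```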
```cpp
#include <stack>
using namespace std;

/**
 * Definition for binary tree (LeetCode normally supplies this).
 */
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

class BSTIterator {
private:
    stack<TreeNode*> s;
public:
    BSTIterator(TreeNode *root) {
        // Push the root and its left spine; the smallest node ends up on top.
        while (root) {
            s.push(root);
            root = root->left;
        }
    }

    /** @return whether we have a next smallest number */
    bool hasNext() {
        return !s.empty();
    }

    /** @return the next smallest number */
    int next() {
        TreeNode* n = s.top();
        s.pop();
        int res = n->val;
        if (n->right) { // n is the smallest remaining node; push the left spine of its right subtree
            n = n->right;
            while (n) {
                s.push(n);
                n = n->left;
            }
        }
        return res;
    }
};

/**
 * Your BSTIterator will be called like this:
 * BSTIterator i = BSTIterator(root);
 * while (i.hasNext()) cout << i.next();
 */
```

reference: