Problem description:

Design a max stack that supports push, pop, top, peekMax and popMax.

push(x) – Push element x onto stack.
pop() – Remove the element on top of the stack and return it.
top() – Get the element on the top.
peekMax() – Retrieve the maximum element in the stack.
popMax() – Retrieve the maximum element in the stack and remove it. If there is more than one maximum element, remove only the top-most one.
Example 1:

MaxStack stack = new MaxStack();
stack.push(5);
stack.push(1);
stack.push(5);
stack.top(); -> 5
stack.popMax(); -> 5
stack.top(); -> 1
stack.peekMax(); -> 5
stack.pop(); -> 1
stack.top(); -> 5

Note:
-1e7 <= x <= 1e7
Number of operations won’t exceed 10000.
The last four operations won’t be called when the stack is empty.

Solution:

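Use two stacks: `stack` stores every element in push order, and `maxStack` stores the running maxima (each element that was greater than or equal to everything below it when it was pushed), so `maxStack[-1]` is always the current maximum.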

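For a normal push, also push the value onto maxStack if maxStack is empty or the value is greater than or equal to maxStack[-1] (using >= keeps duplicate maxima, so removing one copy does not lose the other).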

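Note that in popMax() the maximum element might not be at the top of the stack (stack[-1]). So every element pushed after the current maximum (maxStack[-1]) is popped into a temporary buffer (used as a stack), the maximum is removed from both stacks, and the buffered elements are pushed back in their original order, re-applying the push rule so that maxStack is rebuilt correctly.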

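For example, in the test case below, 1 is pushed after 5 and is smaller, so it is not pushed onto maxStack at first. Once popMax() removes 5, 1 must be pushed back through the push rule so that it becomes the new maximum (peekMax() then returns 1):

["MaxStack","push","push","popMax","peekMax"]
[[],[5],[1],[],[]]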
class MaxStack:
    """Max stack built from two stacks.

    ``stack`` holds every element in push order. ``maxStack`` holds, in push
    order, each element that was greater than or equal to everything below it
    when it was pushed, so ``maxStack[-1]`` is always the current maximum.
    """

    def __init__(self):
        self.stack = deque()
        self.maxStack = deque()

    def push(self, x: int) -> None:
        """Push ``x``; also record it on ``maxStack`` when it ties or beats the max."""
        self.stack.append(x)
        if not self.maxStack or self.maxStack[-1] <= x:
            self.maxStack.append(x)

    def pop(self) -> int:
        """Remove and return the top element, keeping ``maxStack`` in sync."""
        value = self.stack.pop()
        # If the removed element is the current maximum, it is also the top
        # of maxStack (it was >= everything below it when pushed).
        if value == self.maxStack[-1]:
            self.maxStack.pop()
        return value

    def top(self) -> int:
        """Return the top element without removing it."""
        return self.stack[-1]

    def peekMax(self) -> int:
        """Return the current maximum without removing it."""
        return self.maxStack[-1]

    def popMax(self) -> int:
        """Remove and return the top-most maximum element.

        Elements above the maximum are moved to a temporary buffer, the
        maximum is removed from both stacks, and the buffered elements are
        pushed back in their original order through ``push`` so that
        ``maxStack`` reflects the new maximum. O(n) in the worst case.
        """
        target = self.maxStack[-1]

        held = []  # elements above the maximum, top-most first
        while self.stack[-1] != target:
            held.append(self.stack.pop())

        self.stack.pop()
        self.maxStack.pop()

        # Restore original order: the element closest to the removed max first.
        for value in reversed(held):
            self.push(value)
        return target
class MaxStack {
public:
    // `total` holds every element in push order. `maxstk` holds, in push
    // order, each element that was >= everything below it when pushed, so
    // maxstk.top() is always the current maximum.
    std::stack<int> total, maxstk;

    /** Push x; also record it on maxstk when it ties or beats the current maximum. */
    void push(int x) {
        total.push(x);
        const bool isNewMax = maxstk.empty() || maxstk.top() <= x;
        if (isNewMax) {
            maxstk.push(x);
        }
    }

    /** Remove and return the top element, keeping maxstk in sync. */
    int pop() {
        const int value = total.top();
        total.pop();
        // A removed element equal to the maximum is also maxstk's top.
        if (value == maxstk.top()) {
            maxstk.pop();
        }
        return value;
    }

    /** Return the top element without removing it. */
    int top() {
        return total.top();
    }

    /** Return the current maximum without removing it. */
    int peekMax() {
        return maxstk.top();
    }

    /**
     * Remove and return the top-most maximum element. Elements above it are
     * parked in a temporary stack and re-pushed through push() afterwards so
     * that maxstk is rebuilt. O(n) in the worst case.
     */
    int popMax() {
        const int target = maxstk.top();

        std::stack<int> held;  // elements above the maximum
        while (total.top() != target) {
            held.push(total.top());
            total.pop();
        }

        total.pop();
        maxstk.pop();

        for (; !held.empty(); held.pop()) {
            push(held.top());
        }
        return target;
    }
};

/**
 * Your MaxStack object will be instantiated and called as such:
 * MaxStack* obj = new MaxStack();
 * obj->push(x);
 * int param_2 = obj->pop();
 * int param_3 = obj->top();
 * int param_4 = obj->peekMax();
 * int param_5 = obj->popMax();
 */

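time complexity: $O(1)$ for push, pop, top and peekMax; $O(n)$ for popMax, which may move up to $n$ elements through the temporary buffer
space complexity: $O(n)$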
reference:

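Alternative approach: use a max-heap with lazy deletion to find the maximum in amortized O(log N). Pair it with a doubly linked list that keeps the stack order, and a dict that maps each value to its list nodes. Together these let any node, including the top-most maximum, be unlinked in O(1).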

from collections import defaultdict
from heapq import heappop, heappush


class DoubleLinkedList:
    """One node of the doubly linked list that records stack order.

    (The class name is kept for compatibility; each instance is a node.)
    """

    def __init__(self, val=None):
        self.val = val
        self.next = None  # node pushed right after this one
        self.pre = None  # node pushed right before this one


class MaxStack:
    """Max stack backed by a doubly linked list, a lazy max-heap and a dict.

    * ``stack`` is a dummy head node; ``last`` is the tail (top of stack).
    * ``heap`` holds ``-x`` for pushed values (heapq is a min-heap). ``pop``
      leaves its heap entry behind; stale entries are discarded lazily in
      ``peekMax`` once their value is no longer on the stack.
    * ``hmap`` maps each value on the stack to its nodes in push order, so the
      last node of each list is the top-most occurrence of that value.
    """

    def __init__(self):
        """Initialize an empty stack."""
        self.stack = DoubleLinkedList(float('-inf'))  # dummy head node
        self.last = self.stack  # tail of the list == top of the stack
        self.heap = []
        self.hmap = defaultdict(list)

    def _take_node(self, num):
        """Remove and return the top-most node holding ``num`` from ``hmap``.

        Deletes the dict entry once no node with that value remains, so that
        ``num in self.hmap`` means "num is currently on the stack".
        """
        nodes = self.hmap[num]
        node = nodes.pop()
        if not nodes:
            del self.hmap[num]
        return node

    def push(self, x: int) -> None:
        """Push ``x`` onto the stack. O(log n)."""
        node = DoubleLinkedList(x)

        # Append the node after the current tail.
        self.last.next = node
        node.pre = self.last
        self.last = node

        heappush(self.heap, -x)  # negate: heapq is a min-heap
        self.hmap[x].append(node)

    def pop(self) -> int:
        """Remove and return the top element. O(1).

        The value's heap entry is left in place and cleaned up lazily by
        ``peekMax``.
        """
        node = self.last
        self.last = node.pre
        self.last.next = None
        # The tail is the most recently pushed node still present, hence the
        # last node recorded for its value.
        self._take_node(node.val)
        return node.val

    def top(self) -> int:
        """Return the top element without removing it. O(1)."""
        return self.last.val

    def peekMax(self) -> int:
        """Return the current maximum without removing it.

        Amortized O(log n): heap entries whose value is no longer on the
        stack are discarded here, each at most once.
        """
        while -self.heap[0] not in self.hmap:
            heappop(self.heap)
        return -self.heap[0]

    def popMax(self) -> int:
        """Remove and return the maximum; among ties, the top-most one. O(log n)."""
        num = self.peekMax()
        # heap[0] is now a live entry for num. Dropping one heap entry while
        # removing one node keeps (#heap entries for v) >= (#nodes for v) for
        # every v, so each value on the stack still has a heap entry.
        heappop(self.heap)
        node = self._take_node(num)

        # Unlink the node. It is never the dummy head (the head is not in
        # hmap), so node.pre is always a real node.
        node.pre.next = node.next
        if node.next is not None:
            node.next.pre = node.pre
        if node is self.last:
            self.last = node.pre
        return num