Problem description:

Merge k sorted linked lists and return them as one sorted list. Analyze and describe the complexity of your solution.

Example:

1
2
3
4
5
6
7
Input:
[
1->4->5,
1->3->4,
2->6
]
Output: 1->1->2->3->4->4->5->6

Solution1:

Reuse the idea of merging two sorted lists. Split the n lists into pairs, merge each pair, and repeat on the roughly n/2 resulting lists until only one list is left.
example:

1
2
3
4
5
6
7
8
9
10
11
12
totally 10 lists
1 2 3 4 5 6 7 8 9 10

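In each round, merge list i with list i+k, where k = (n+1)/2 and n is the number of lists remaining; when n is odd, the middle list is carried over to the next round unmerged.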

merge(1,6) merge(2,7), merge(3,8), merge(4,9), merge(5,10)

merge(1,4), merge(2,5), 3

merge(1,3), 2

merge(1,2)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution:
    def mergeKLists(self, lists: List[ListNode]) -> ListNode:
        """Merge k sorted linked lists with a top-down divide and conquer.

        The sequence of list heads is split in half, each half is merged
        recursively, and the two resulting lists are merged together.

        Returns None when ``lists`` is empty, otherwise the head of the
        merged list. Nodes are relinked in place; no new nodes are kept.
        """
        if not lists:
            return None
        # A single list is already merged.
        if len(lists) == 1:
            return lists[0]
        half = len(lists) // 2
        left_head = self.mergeKLists(lists[:half])
        right_head = self.mergeKLists(lists[half:])
        return self.merge(left_head, right_head)

    def merge(self, l, r):
        """Merge two sorted linked lists and return the head of the result.

        When the front values are equal, the node from ``r`` is taken first.
        """
        if not (l or r):
            return l
        # Sentinel node so the first append needs no special case.
        sentinel = ListNode(-1)
        tail = sentinel
        while l and r:
            if l.val < r.val:
                tail.next = l
                l = l.next
            else:
                tail.next = r
                r = r.next
            tail = tail.next
        # At most one of the inputs still has nodes; attach that remainder.
        tail.next = l if l else r
        return sentinel.next
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/**
* Definition for singly-linked list.
* struct ListNode {
* int val;
* ListNode *next;
* ListNode(int x) : val(x), next(NULL) {}
* };
*/
class Solution {
public:
    /**
     * Merges k sorted linked lists into one sorted list.
     *
     * Works in rounds: with n lists remaining and k = (n + 1) / 2, list i is
     * merged with list i + k and the result is stored back in slot i, so the
     * number of remaining lists drops from n to k each round until only
     * lists[0] is left.
     *
     * The slots of `lists` are overwritten and the nodes are relinked in
     * place; no nodes are allocated or freed.
     *
     * @param lists heads of the sorted input lists (a head may be NULL).
     * @return head of the merged list, or NULL when `lists` is empty.
     */
    ListNode* mergeKLists(vector<ListNode*>& lists) {
        if(lists.empty()) return NULL;
        size_t n = lists.size();
        while(n > 1){
            // k is ceil(n / 2); when n is odd, slot k - 1 has no partner and
            // is carried into the next round unchanged.
            const size_t k = (n + 1) / 2;
            for(size_t i = 0; i < n / 2; i++){
                lists[i] = mergeList(lists[i], lists[i + k]);
            }
            n = k;
        }
        return lists[0];
    }

    /**
     * Merges two sorted linked lists by relinking their nodes.
     *
     * When the front values are equal, the node from `l2` is taken first.
     *
     * @param l1 head of the first sorted list (may be NULL).
     * @param l2 head of the second sorted list (may be NULL).
     * @return head of the merged list, or NULL when both inputs are NULL.
     */
    ListNode* mergeList(ListNode* l1, ListNode* l2){
        // Stack-allocated sentinel: gives the loop a uniform "append" step
        // without a heap allocation that would never be freed.
        ListNode dummy(-1);
        ListNode* cur = &dummy;
        while(l1 && l2){
            if(l1->val < l2->val){
                cur->next = l1;
                l1 = l1->next;
            }
            else{
                cur->next = l2;
                l2 = l2->next;
            }
            cur = cur->next;
        }

        // At most one list still has nodes; attach whatever remains.
        cur->next = l1 ? l1 : l2;

        return dummy.next;
    }
};

time complexity: $O(nk\log k)$, where n is the average length of a list and k is the number of lists.
In the first round the lists are merged pairwise, which touches every node once and costs $O(nk)$. In the next round the lists are on average 2n long but only k/2 of them remain, so the cost is again $O(nk)$. The number of lists halves every round, so there are $\log k$ rounds until a single list is left, giving a total of $O(nk\log k)$.
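space complexity: $O(1)$ for the iterative C++ version, which reuses the input vector and relinks nodes in place. The recursive Python version additionally uses a recursion stack of depth $O(\log k)$ and temporary list slices of total size $O(k)$.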
reference:
https://goo.gl/wUqWfW

Solution 2:

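Use a priority queue (min-heap). Push the head node of every non-empty list into the queue, which keeps the node with the smallest value on top. Repeatedly pop the top node, append it to the result, and push that node's successor (if any) back into the queue.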

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution:
    def mergeKLists(self, lists: List[ListNode]) -> ListNode:
        """Merge k sorted linked lists using a min-heap of the current heads.

        The heap holds at most one node per input list. The smallest node is
        popped and appended to the result, then its successor (if any) is
        pushed back.

        Returns None when ``lists`` is empty or contains only empty lists,
        otherwise the head of the merged list. Nodes are relinked in place.
        """
        import heapq  # local import keeps this block self-contained

        if not lists:
            return None

        # Entries are (value, list index, node). Each list has at most one
        # entry in the heap at a time, so the index breaks ties on value and
        # ListNode objects themselves are never compared.
        heap = [(head.val, idx, head) for idx, head in enumerate(lists) if head]
        heapq.heapify(heap)

        dummy = ListNode(-1)
        tail = dummy
        while heap:
            _, idx, node = heapq.heappop(heap)
            if node.next:
                heapq.heappush(heap, (node.next.val, idx, node.next))
            tail.next = node
            tail = node
        return dummy.next
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
struct comp{
bool operator() (ListNode* a, ListNode* b){
return a->val > b->val;
}
};

class Solution {
public:
    /**
     * Merges k sorted linked lists using a min-ordered priority queue.
     *
     * The queue is seeded with the head of every non-empty list. The node on
     * top is popped and appended to the result, then its successor (if any)
     * is pushed, so the queue never holds more than one node per input list.
     *
     * Nodes are relinked in place; no nodes are allocated or freed.
     *
     * @param lists heads of the sorted input lists (a head may be NULL).
     * @return head of the merged list, or NULL when there are no nodes.
     */
    ListNode* mergeKLists(vector<ListNode*>& lists) {
        priority_queue<ListNode*, vector<ListNode*>, comp> q;
        for(size_t i = 0; i < lists.size(); i++){
            if(lists[i]) q.push(lists[i]);
        }

        // Stack-allocated sentinel: avoids a heap-allocated dummy head that
        // would never be freed.
        ListNode dummy(-1);
        ListNode* cur = &dummy;
        while(!q.empty()){
            ListNode* smallest = q.top(); // node with the smallest value among the current heads
            q.pop();
            cur->next = smallest;
            cur = smallest;
            if(smallest->next) q.push(smallest->next);
        }
        return dummy.next;
    }
};

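time complexity: $O(nk\log k)$, where n is the average length of a list and k is the number of lists. Each of the nk nodes is pushed into and popped from the priority queue once; the queue holds at most k nodes, so its height is $\log k$ and each push or pop costs $O(\log k)$.
space complexity: $O(k)$, since the priority queue holds at most one node per list.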
reference:
https://goo.gl/nM8sHt