In the K-th Smallest Element in a Sorted Matrix problem, we are given an n x n matrix in which every row and every column is sorted in non-decreasing order. Find the k-th smallest element in the given 2D array.
Table of Contents
Example
Input 1:
k = 3 and
matrix =
11, 21, 31, 41
16, 26, 36, 46
25, 30, 38, 49
33, 34, 40, 50
Output: 21
Explanation: The 3rd smallest element = 21
Input 2:
k = 7 and
matrix =
12, 23, 31
17, 27, 36
24, 29, 38
Output: 31
Explanation: The 7th smallest element is 31 (sorted: 12, 17, 23, 24, 27, 29, 31, 36, 38)
Types of Solution for K-th Smallest Element in a Sorted Matrix
- Naive/Brute Force
- using min Heap data structure
- Binary Search
Naive/Brute Force
Approach
This approach involves simply adding all the elements of the matrix into an array, sorting the array and returning the k-th smallest number.
Algorithm for K-th Smallest Element in a Sorted Matrix
- given a matrix mat[][] of dimensions n x n, create a linear array arr[] of size n*n.
- Add all the elements of mat into the arr.
- Sort arr and return arr[k-1] (the k-th smallest element, since arr is 0-indexed).
Implementation for K-th Smallest Element in a Sorted Matrix
C++ Program
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
// function to find the k-th smallest element (k is 1-based) by flattening
// the matrix into one array and sorting it.
// Returns -1 when k is out of range (k < 1 or k > n*n); note that -1 is then
// ambiguous if the matrix itself may contain -1.
// Time O(n^2 log n) for the sort, extra space O(n^2) for the copy.
int kthSmallest(const std::vector<std::vector<int>>& mat, int k)
{
    const int n = static_cast<int>(mat.size());
    // reject k outside [1, n*n]; without the k < 1 check, arr[k-1] would read
    // out of bounds. The product is formed in 64 bits so a large n cannot overflow.
    if (k < 1 || static_cast<long long>(k) > static_cast<long long>(n) * n)
        return -1;
    // smallest element is the first (top-left) element of the matrix
    if (k == 1)
        return mat[0][0];
    // define array and push contents of the matrix into it
    std::vector<int> arr;
    arr.reserve(static_cast<std::size_t>(n) * n);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            arr.push_back(mat[i][j]);
    // sort the array; index k-1 holds the k-th smallest because arr is 0-indexed
    std::sort(arr.begin(), arr.end());
    return arr[k - 1];
}
int main()
{
    // sample input: Example 1 from the problem statement
    const std::vector<std::vector<int>> mat{
        {11, 21, 31, 41},
        {16, 26, 36, 46},
        {25, 30, 38, 49},
        {33, 34, 40, 50}
    };
    const int k = 3;
    const int kthsmall = kthSmallest(mat, k);
    // kthSmallest signals "no such element" with -1
    if (kthsmall == -1)
        std::cout << "3rd smallest element doesn't exist.\n";
    else
        std::cout << "3rd smallest element = " << kthsmall << '\n';
    return 0;
}
3rd smallest element = 21
Java Program
import java.io.*;
import java.util.*;
class tutorialCup
{
    // function to find the k-th smallest element (k is 1-based) by copying
    // every matrix element into one list and sorting it.
    // Returns -1 when k is out of range (k < 1 or k > n*n).
    static int kthSmallest(ArrayList<ArrayList<Integer>> mat,int k)
    {
        int n = mat.size();
        // reject k outside [1, n*n]; without k < 1, arr.get(k-1) would throw.
        // The product is formed in long so a large n cannot overflow.
        if(k < 1 || (long) k > (long) n * n)
            return -1;
        // smallest element is the first (top-left) element of the matrix
        if(k == 1)
            return mat.get(0).get(0);
        // define array and push contents of the matrix into it
        ArrayList <Integer> arr = new ArrayList <Integer>(n * n);
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                arr.add(mat.get(i).get(j));
        // sort the list; index k-1 holds the k-th smallest because it is 0-indexed
        Collections.sort(arr);
        return arr.get(k-1);
    }
    public static void main (String[] args)
    {
        ArrayList<ArrayList<Integer>> mat = new ArrayList<ArrayList<Integer>>();
        mat.add(new ArrayList<Integer>(Arrays.asList(11, 21, 31, 41)));
        mat.add(new ArrayList<Integer>(Arrays.asList(16, 26, 36, 46)));
        mat.add(new ArrayList<Integer>(Arrays.asList(25, 30, 38, 49)));
        mat.add(new ArrayList<Integer>(Arrays.asList(33, 34, 40, 50)));
        int k = 3;
        int kthsmall = kthSmallest(mat,k);
        if(kthsmall == -1)
            System.out.println("3rd smallest element doesn't exist.");
        else
            System.out.println("3rd smallest element = "+kthsmall);
    }
}
3rd smallest element = 21
Complexity Analysis
- Time Complexity : T(n) = O(n^2 log n), dominated by sorting the n^2 copied elements
- Space Complexity : A(n) = O(n^2)
Using min Heap data structure
Approach
We use a heap data structure (priority queue / min-heap) and first store the elements of the first row of the matrix in it. Then we repeatedly pop the smallest element from the heap and push the matrix element lying directly below the popped element, if there is one. After k pops, the last popped element is the k-th smallest number. The algorithm is discussed below:
Algorithm for K-th Smallest Element in a Sorted Matrix
- deal with the base cases explicitly (k out of range, i.e. k < 1 or k > n*n, and k = 1).
- define a min-heap pq.
- add the elements from the first row of mat[][] (input matrix) into pq.
- Now, run a loop k-times. the loop variable is i, inside that loop perform steps below :
- pop an element from pq and store the popped element into kthsmall.
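- if kthsmall is not in the last row, add the element lying in mat[][] just below kthsmall into pq (a popped last-row element has nothing below it, so nothing is pushed).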
- After the loop ends, return the last element value stored in kthsmall (ie kthsmall.num).


Implementation for K-th Smallest Element in a Sorted Matrix
C++ Program
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
// define a structure
// to store a particular value from the matrix together with
// its row number (i) and column number (j)
struct element
{
    // default member initializers: a default-constructed element is {0, 0, 0}
    // rather than holding indeterminate values
    int num = 0;
    int i = 0;
    int j = 0;
    // define constructors
    element() = default;
    element(int value, int rownum, int colnum)
        : num(value), i(rownum), j(colnum)
    {
    }
};
// comparator for std::priority_queue: returns true when p1 should sit below
// p2 in the heap. Using '>' therefore turns the default max-heap into a
// min-heap ordered by num.
struct compare
{
    bool operator()(element const& p1, element const& p2) const
    {
        return p1.num > p2.num;
    }
};
// function to find the k-th smallest element (k is 1-based) of an n x n matrix
// whose rows and columns are sorted in non-decreasing order, using a min-heap.
// Returns -1 when k is out of range (k < 1 or k > n*n).
int kthSmallest(const std::vector<std::vector<int>>& mat, int k)
{
    const int n = static_cast<int>(mat.size());
    // reject k outside [1, n*n]; the product is formed in 64 bits to avoid overflow
    if (k < 1 || static_cast<long long>(k) > static_cast<long long>(n) * n)
        return -1;
    // smallest element is the first (top-left) element of the matrix
    if (k == 1)
        return mat[0][0];
    // define a min heap
    std::priority_queue<element, std::vector<element>, compare> pq;
    // push all the elements of the first row
    // along with their row and column numbers
    for (int col = 0; col < n; col++)
        pq.push(element(mat[0][col], 0, col));
    // holds the most recently popped element
    element kthsmall;
    // pop from the heap for k steps
    for (int step = 0; step < k; step++)
    {
        // The heap is never empty here. A cell is pushed only after the cell
        // above it was popped, so every pop removes one element and pushes at
        // most one. The heap can empty only once all n last-row cells, and
        // therefore all n*n cells, have been popped, and k <= n*n.
        kthsmall = pq.top();
        pq.pop();
        // push the element directly below the one just popped (same column);
        // an element in the last row has nothing below it, so nothing is pushed
        // (no sentinel is needed, see above)
        if (kthsmall.i + 1 < n)
            pq.push(element(mat[kthsmall.i + 1][kthsmall.j], kthsmall.i + 1, kthsmall.j));
    }
    // the element popped on the k-th step is the k-th smallest
    return kthsmall.num;
}
int main()
{
    // sample input: Example 1 from the problem statement
    const std::vector<std::vector<int>> mat{
        {11, 21, 31, 41},
        {16, 26, 36, 46},
        {25, 30, 38, 49},
        {33, 34, 40, 50}
    };
    const int k = 3;
    const int kthsmall = kthSmallest(mat, k);
    // kthSmallest signals "no such element" with -1
    if (kthsmall == -1)
        std::cout << "3rd smallest element doesn't exist.\n";
    else
        std::cout << "3rd smallest element = " << kthsmall << '\n';
    return 0;
}
3rd smallest element = 21
Java Program
import java.io.*;
import java.util.*;
/*
 A single matrix cell: its value (num) together with its row number (i)
 and column number (j).
 The class is package-private rather than public. A public top-level class
 must be declared in a file named element.java, so "public class element"
 cannot live in the same source file as tutorialCup.
*/
class element
{
    public int num;
    public int i;
    public int j;
    public element(){}
    public element(int value,int rownum,int colnum)
    {
        num = value;
        i = rownum;
        j = colnum;
    }
    // accessor for the cell value
    public int numReturn()
    {
        return num;
    }
}
class tutorialCup
{
    // function to find the k-th smallest element (k is 1-based) of an n x n
    // matrix whose rows and columns are sorted in non-decreasing order,
    // using a min-heap of cells.
    // Returns -1 when k is out of range (k < 1 or k > n*n).
    static int kthSmallest(ArrayList<ArrayList<Integer>> mat,int k)
    {
        int n = mat.size();
        // reject k outside [1, n*n]; the product is formed in long to avoid overflow
        if(k < 1 || (long) k > (long) n * n)
            return -1;
        // smallest element is the first (top-left) element of the matrix
        if(k == 1)
            return mat.get(0).get(0);
        // define a min heap ordered by the cell value (num)
        PriorityQueue<element> pq = new PriorityQueue<element>(Comparator.comparingInt(element::numReturn));
        // push all the elements of the first row
        // along with their row and column numbers
        for(int col = 0; col < n; col++)
            pq.add(new element(mat.get(0).get(col), 0, col));
        // holds the most recently popped element
        element kthsmall = new element();
        // pop from the heap for k steps
        for(int step = 0; step < k; step++)
        {
            // The heap is never empty here. A cell is added only after the cell
            // above it was polled, and the heap can empty only once all n*n
            // cells have been polled, while k <= n*n.
            kthsmall = pq.poll();
            // push the element directly below the one just popped (same column);
            // an element in the last row has nothing below it, so nothing is
            // pushed (no sentinel is needed, see above)
            if(kthsmall.i + 1 < n)
                pq.add(new element(mat.get(kthsmall.i + 1).get(kthsmall.j), kthsmall.i + 1, kthsmall.j));
        }
        // the element polled on the k-th step is the k-th smallest
        return kthsmall.num;
    }
    public static void main (String[] args)
    {
        ArrayList<ArrayList<Integer>> mat = new ArrayList<ArrayList<Integer>>();
        mat.add(new ArrayList<Integer>(Arrays.asList(11, 21, 31, 41)));
        mat.add(new ArrayList<Integer>(Arrays.asList(16, 26, 36, 46)));
        mat.add(new ArrayList<Integer>(Arrays.asList(25, 30, 38, 49)));
        mat.add(new ArrayList<Integer>(Arrays.asList(33, 34, 40, 50)));
        int k = 3;
        int kthsmall = kthSmallest(mat,k);
        if(kthsmall == -1)
            System.out.println("3rd smallest element doesn't exist.");
        else
            System.out.println("3rd smallest element = "+kthsmall);
    }
}
3rd smallest element = 21
Complexity Analysis
- Time Complexity : T(n) = O(n + klogn)
- Space Complexity : A(n) = O(n)
Binary Search
Approach
We can apply binary search on the matrix mat[][]. However, this differs from a conventional binary search because here we search over the “number range” (the matrix values themselves) instead of the “index range”. The smallest number of the matrix is at the top-left corner and the biggest number is at the bottom-right corner. These two numbers form the “range”, i.e., the lo and hi values for the binary search. Below we explain how we implement the algorithm.
Algorithm for K-th Smallest Element in a Sorted Matrix
- Start the binary search with lo = mat[0][0] and hi = mat[n-1][n-1].
- Find mid of lo and hi. This middle number is NOT necessarily an element in the matrix.
- Count all the numbers smaller than or equal to mid in the matrix. Because rows and columns are sorted, we can do this in O(n) with a staircase walk starting from the top-right corner.
- If the count is less than k, update lo = mid + 1 to search in the higher part of the value range.
- Else, if the count is greater than or equal to k, update hi = mid to search in the lower part of the value range in the next iteration.
- At some point lo and hi become equal and the loop terminates. lo then holds the k-th smallest element: it is the smallest value v for which at least k matrix elements are ≤ v, and such a value is always an element of the matrix.
Implementation for K-th Smallest Element in a Sorted Matrix
C++ Program
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
// function to find the k-th smallest element (k is 1-based) of an n x n matrix
// whose rows and columns are sorted in non-decreasing order.
// It performs a binary search over the value range [mat[0][0], mat[n-1][n-1]]
// rather than over indices.
// Returns -1 when the matrix is empty or k is out of range (k < 1 or k > n*n).
int kthSmallest(const std::vector<std::vector<int>>& mat, int k)
{
    const int n = static_cast<int>(mat.size());
    // the product is formed in 64 bits so a large n cannot overflow
    if (n == 0 || k < 1 || static_cast<long long>(k) > static_cast<long long>(n) * n)
        return -1;
    if (k == 1)
        return mat[0][0];
    // the smallest (top-left) and largest (bottom-right) elements bound the range
    int lo = mat[0][0];
    int hi = mat[n - 1][n - 1];
    // perform binary search (by value)
    while (lo < hi)
    {
        // hi - lo can exceed INT_MAX (e.g. INT_MIN..INT_MAX), so compute the
        // midpoint in 64 bits; the result lies in [lo, hi) and fits in an int
        const int mid = static_cast<int>(lo + (static_cast<long long>(hi) - lo) / 2);
        // count how many matrix elements are <= mid. j marks the last column
        // whose value is <= mid in the current row; columns are sorted, so
        // j never moves right as i increases, giving O(n) total work.
        long long count = 0;
        int j = n - 1;
        for (int i = 0; i < n; i++)
        {
            while (j >= 0 && mat[i][j] > mid)
                j--;
            count += j + 1;
        }
        // fewer than k elements are <= mid: the answer is larger than mid
        if (count < k)
            lo = mid + 1;
        // at least k elements are <= mid: the answer is at most mid
        else
            hi = mid;
    }
    // lo is now the smallest value with at least k elements <= it, which is
    // an element of the matrix: the k-th smallest
    return lo;
}
int main()
{
    // sample input: Example 1 from the problem statement
    const std::vector<std::vector<int>> mat{
        {11, 21, 31, 41},
        {16, 26, 36, 46},
        {25, 30, 38, 49},
        {33, 34, 40, 50}
    };
    const int k = 3;
    const int kthsmall = kthSmallest(mat, k);
    // kthSmallest signals "no such element" with -1
    if (kthsmall == -1)
        std::cout << "3rd smallest element doesn't exist.\n";
    else
        std::cout << "3rd smallest element = " << kthsmall << '\n';
    return 0;
}
3rd smallest element = 21
Java Program
import java.io.*;
import java.util.*;
class tutorialCup
{
    // function to find the k-th smallest element (k is 1-based) of an n x n
    // matrix whose rows and columns are sorted in non-decreasing order.
    // It performs a binary search by value over [mat[0][0], mat[n-1][n-1]].
    // Returns -1 when the matrix is empty or k is out of range (k < 1 or k > n*n).
    static int kthSmallest(ArrayList<ArrayList<Integer>> mat,int k)
    {
        int n = mat.size();
        // the product is formed in long so a large n cannot overflow
        if(n == 0 || k < 1 || (long) k > (long) n * n)
            return -1;
        if(k == 1)
            return mat.get(0).get(0);
        // the smallest (top-left) and largest (bottom-right) elements bound the range
        int lo = mat.get(0).get(0);
        int hi = mat.get(n-1).get(n-1);
        // perform binary search (by value)
        while (lo < hi)
        {
            // hi - lo can exceed Integer.MAX_VALUE and would wrap around in int
            // arithmetic, so compute the midpoint in long; it lies in [lo, hi)
            int mid = (int) (lo + ((long) hi - lo) / 2);
            // count how many matrix elements are <= mid. j marks the last column
            // whose value is <= mid in the current row; columns are sorted, so
            // j never moves right as i increases, giving O(n) total work.
            long count = 0;
            int j = n - 1;
            for (int i = 0; i < n; i++)
            {
                while (j >= 0 && mat.get(i).get(j) > mid)
                    j--;
                count += (j + 1);
            }
            // fewer than k elements are <= mid: the answer is larger than mid
            if (count < k)
                lo = mid + 1;
            // at least k elements are <= mid: the answer is at most mid
            else
                hi = mid;
        }
        // lo is now the smallest value with at least k elements <= it, which
        // is an element of the matrix: the k-th smallest
        return lo;
    }
    public static void main (String[] args)
    {
        ArrayList<ArrayList<Integer>> mat = new ArrayList<ArrayList<Integer>>();
        mat.add(new ArrayList<Integer>(Arrays.asList(11, 21, 31, 41)));
        mat.add(new ArrayList<Integer>(Arrays.asList(16, 26, 36, 46)));
        mat.add(new ArrayList<Integer>(Arrays.asList(25, 30, 38, 49)));
        mat.add(new ArrayList<Integer>(Arrays.asList(33, 34, 40, 50)));
        int k = 3;
        int kthsmall = kthSmallest(mat,k);
        if(kthsmall == -1)
            System.out.println("3rd smallest element doesn't exist.");
        else
            System.out.println("3rd smallest element = "+kthsmall);
    }
}
3rd smallest element = 21
Complexity Analysis
- Time Complexity : T(n) = O(n log(max-min)), where max-min is the difference between the largest and smallest values in the matrix. Each iteration halves the value range and counts the elements ≤ mid in O(n). For 32-bit integers max-min < 2^32, so there are at most about 32 iterations, which in practice bounds the running time by a small constant multiple of n.
- Space Complexity : A(n) = O(1)