Printing brackets in Matrix Chain Multiplication Problem

Difficulty Level: Hard
Frequently asked in: Amazon, Avalara, Citadel, Databricks, Directi, JP Morgan, Paytm, Twilio
Tags: Array, Dynamic Programming, Matrix

Problem Statement

We need to find the order in which to multiply a chain of matrices so that the total number of operations is minimized. Then we need to print this order, that is, print the brackets for the matrix chain multiplication problem.

Suppose you have 3 matrices A, B, C of sizes a x b, b x c, c x d respectively. There are two possible orders in which to multiply them: either we first multiply A and B and then multiply the result by C, or we first multiply B and C and then multiply A by the result. The first choice costs (a x b x c + a x c x d) operations, while the second takes a total of (b x c x d + a x b x d) operations. The number of operations therefore depends on the order in which the matrices are multiplied, so we need to find the order that results in the minimum number of operations.
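The two costs above can be compared directly in code. The following is a minimal sketch; the function names costLeftFirst and costRightFirst are hypothetical and only mirror the two formulas from the paragraph above.

```cpp
// Hypothetical helpers for the three-matrix case described above.
// A is a x b, B is b x c, C is c x d.

// Cost of (AB)C: a*b*c operations for AB, then a*c*d for (AB) times C.
long long costLeftFirst(long long a, long long b, long long c, long long d) {
    return a * b * c + a * c * d;
}

// Cost of A(BC): b*c*d operations for BC, then a*b*d for A times (BC).
long long costRightFirst(long long a, long long b, long long c, long long d) {
    return b * c * d + a * b * d;
}
```

For example, with a = 10, b = 100, c = 5, d = 50, costLeftFirst gives 7500 while costRightFirst gives 75000, so the order of multiplication changes the cost by a factor of ten.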

Example

Input:
number of matrices = 3
Dimensions of matrices = {1, 2, 3, 4}

Output:
((AB)C)

Explanation: Here A is 1 x 2, B is 2 x 3 and C is 3 x 4. First we multiply A and B, which costs 1 x 2 x 3 = 6 operations. Then we multiply the resulting 1 x 3 matrix by C, which costs 1 x 3 x 4 = 12 operations, making a total of 18 operations. The other order, (A(BC)), would cost 2 x 3 x 4 + 1 x 2 x 4 = 32 operations.

Input:
number of matrices = 2
Dimensions of matrices = {10, 10, 20}

Output:
(AB)

Explanation: Since there are only two matrices, there is only one way to multiply them, which takes a total of 10 x 10 x 20 = 2000 operations.
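The totals in these examples can be double-checked with a brute-force search that tries every possible split. This is only a sketch for verification on small inputs: the name bruteForceMinCost is hypothetical, and the function runs in exponential time.

```cpp
#include <algorithm>
#include <climits>
#include <vector>

// Hypothetical checker: matrix i has dimensions dims[i] x dims[i+1].
// Returns the cheapest cost of multiplying matrices i..j by trying
// every split point k recursively (exponential time, small inputs only).
long long bruteForceMinCost(const std::vector<int> &dims, int i, int j) {
    if (i == j) return 0;  // a single matrix needs no multiplication
    long long best = LLONG_MAX;
    for (int k = i; k < j; k++) {
        long long cost = bruteForceMinCost(dims, i, k)
                       + bruteForceMinCost(dims, k + 1, j)
                       + 1LL * dims[i] * dims[k + 1] * dims[j + 1];
        best = std::min(best, cost);
    }
    return best;
}
```

Calling bruteForceMinCost({1, 2, 3, 4}, 0, 2) returns 18 and bruteForceMinCost({10, 10, 20}, 0, 1) returns 2000, matching the two explanations.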

Approach for printing brackets in matrix chain multiplication problem

We have already solved the Matrix Chain Multiplication problem, where we needed to find the minimum number of operations involved in the multiplication of all the matrices. Matrix Chain Multiplication using dynamic programming is a prerequisite for this problem. With a small modification, that solution can also print the brackets. We maintain a brackets matrix, in which brackets[i][j] stores the optimal split index for the chain of matrices from i to j. An index k is optimal for (i, j) if multiplying matrices i to k together, multiplying matrices k + 1 to j together, and then multiplying the two results gives the minimum number of operations for the whole range [i, j]. In other words, we first find the result for matrices i to the optimal index and for matrices from the optimal index + 1 to j, and then combine the two results.
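The definition above can be written as a recurrence. This is a sketch in which p_i is assumed notation for matrixSize[i], so that matrix i has dimensions p_i x p_{i+1} (0-indexed):

```latex
\[
dp[i][j] =
\begin{cases}
0, & i = j,\\[4pt]
\min\limits_{i \le k < j} \bigl( dp[i][k] + dp[k+1][j] + p_i \, p_{k+1} \, p_{j+1} \bigr), & i < j,
\end{cases}
\]
\[
brackets[i][j] = \operatorname*{arg\,min}_{i \le k < j} \bigl( dp[i][k] + dp[k+1][j] + p_i \, p_{k+1} \, p_{j+1} \bigr).
\]
```

For the dimensions {1, 2, 3, 4}, dp[0][1] = 6 and dp[1][2] = 24, so dp[0][2] = min(0 + 24 + 1 x 2 x 4, 6 + 0 + 1 x 3 x 4) = min(32, 18) = 18 and brackets[0][2] = 1, which corresponds to ((AB)C).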

We then use the brackets matrix to print the brackets around the matrices. We keep dividing the problem at the stored split index until a sub-problem contains a single matrix, and then print that matrix's name.
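The recursive division can be sketched as a function that builds the bracketed expression as a string. This is an illustrative variant, not the article's implementation: the name buildBrackets is hypothetical, and it names matrix i as 'A' + i instead of using a running counter, which assumes at most 26 matrices.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: returns the bracketed expression for matrices i..j,
// given brackets[i][j] = optimal split index for that range.
// Assumes at most 26 matrices so that 'A' + i stays within 'A'..'Z'.
std::string buildBrackets(int i, int j,
                          const std::vector<std::vector<int>> &brackets) {
    // Base case: a single matrix cannot be split further.
    if (i == j) return std::string(1, static_cast<char>('A' + i));

    int k = brackets[i][j];
    // Wrap the left part (i..k) and the right part (k+1..j) in one pair of brackets.
    return "(" + buildBrackets(i, k, brackets)
               + buildBrackets(k + 1, j, brackets) + ")";
}
```

With three matrices and brackets[0][1] = 0, brackets[0][2] = 1, the call buildBrackets(0, 2, brackets) returns "((AB)C)".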


Code for printing brackets in Matrix Chain Multiplication problem

C++ Code

#include <climits>
#include <iostream>
#include <vector>
using namespace std;

// Recursively print the arrangement for minimum cost of multiplication
void printBracketsMatrixChain(int i, int j, vector<vector<int>> &brackets, char &cur_name){
    
    // you have a single matrix ( you cannot further reduce the problem, so print the matrix )
    if(i == j){
        cout<<cur_name;
        cur_name++;
    } else {
        cout<<"(";
        
        // Reduce the problem into left sub-problem ( left of optimal arrangement )
        printBracketsMatrixChain(i, brackets[i][j], brackets, cur_name);
        
        // Reduce the problem into right sub-problem ( right of optimal arrangement )
        printBracketsMatrixChain(brackets[i][j]+1, j, brackets, cur_name);
        cout<<")";
    }
}

void matrixMultiplicationProblem(vector<int> matrixSize) {
    int numberOfMatrices = matrixSize.size()-1;

    // dp[i][j] = minimum number of operations required to multiply matrices i to j
    // initialised with INT_MAX ( maximum number of operations )
    // a vector is used because variable-length arrays are not standard C++
    vector<vector<int>> dp(numberOfMatrices, vector<int>(numberOfMatrices, INT_MAX));

    // for a single matrix from i to i, cost = 0
    for(int i=0;i<numberOfMatrices;i++)
        dp[i][i] = 0;

    vector<vector<int>> brackets(numberOfMatrices, vector<int>(numberOfMatrices, 0));
    for(int len=2;len<=numberOfMatrices;len++){
        for(int i=0;i<numberOfMatrices-len+1;i++){
            int j = i+len-1;
            for(int k=i;k<j;k++) {
                int val = dp[i][k]+dp[k+1][j]+(matrixSize[i]*matrixSize[k+1]*matrixSize[j+1]);
                if(val < dp[i][j]) {
                    dp[i][j] = val;
                    brackets[i][j] = k;
                }
            }
        }
    }

    // naming the first matrix as A
    char cur_name = 'A';
    
    // calling function to print brackets
    printBracketsMatrixChain(0, numberOfMatrices-1, brackets, cur_name);
    cout<<endl;
}

int main() {
    int t;cin>>t;
    while(t--) {
        int n; cin>>n;
        vector<int> matrixSize(n);
        for(int i=0;i<n;i++)cin>>matrixSize[i];
        matrixMultiplicationProblem(matrixSize);
    }
}

Input:
2 // number of test cases
5 // n = number of dimension values, describing n-1 matrices
5 6 89 49 10 // the ith matrix has dimensions matrixSize[i] x matrixSize[i+1]
7
1 5 2 3 4 1000 64

Output:
(((AB)C)D)
(((((AB)C)D)E)F)

Java Code

import java.util.*;
import java.lang.*;
import java.io.*;

class Main {
    
    static char cur_name;
    
    // Recursively print the arrangement for minimum cost of multiplication
    static void printBracketsMatrixChain(int i, int j, int brackets[][]){
        
        // you have a single matrix ( you cannot further reduce the problem, so print the matrix )
        if(i == j){
            System.out.print(cur_name);
            cur_name++;
        } else {
            System.out.print("(");
            
            // Reduce the problem into left sub-problem ( left of optimal arrangement )
            printBracketsMatrixChain(i, brackets[i][j], brackets);
            
            // Reduce the problem into right sub-problem ( right of optimal arrangement )
            printBracketsMatrixChain(brackets[i][j]+1, j, brackets);
            System.out.print(")");
        }
    }

    static void matrixMultiplicationProblem(int matrixSize[]) {
        int numberOfMatrices = matrixSize.length-1;
    
        // dp[i][j] = minimum number of operations required to multiply matrices i to j
        int dp[][] = new int[numberOfMatrices][numberOfMatrices];
    
        // initialising dp array with Integer.MAX_VALUE ( maximum number of operations )
        for(int i=0;i<numberOfMatrices;i++){
            for(int j=0;j<numberOfMatrices;j++){
                dp[i][j] = Integer.MAX_VALUE;
                if(i == j) // for a single matrix from i to i, cost = 0
                    dp[i][j] = 0;
            }
        }
    
        int brackets[][] = new int[numberOfMatrices][numberOfMatrices];
        for(int len=2;len<=numberOfMatrices;len++){
            for(int i=0;i<numberOfMatrices-len+1;i++){
                int j = i+len-1;
                for(int k=i;k<j;k++) {
                    int val = dp[i][k]+dp[k+1][j]+(matrixSize[i]*matrixSize[k+1]*matrixSize[j+1]);
                    if(val < dp[i][j]) {
                        dp[i][j] = val;
                        brackets[i][j] = k;
                    }
                }
            }
        }
    
        // naming the first matrix as A
        cur_name = 'A';
        
        // calling function to print brackets
        printBracketsMatrixChain(0, numberOfMatrices-1, brackets);
        System.out.println();
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while(t-- > 0) {
            int n = sc.nextInt();
            int matrixSize[] = new int[n];
            for(int i=0;i<n;i++) matrixSize[i] = sc.nextInt();
            matrixMultiplicationProblem(matrixSize);
        }
    }

}

 

Input:
2 // number of test cases
5 // n = number of dimension values, describing n-1 matrices
5 6 89 49 10 // the ith matrix has dimensions matrixSize[i] x matrixSize[i+1]
7
1 5 2 3 4 1000 64

Output:
(((AB)C)D)
(((((AB)C)D)E)F)

Complexity Analysis

Time Complexity: O(N^3)

Here, the two indices i and j act as the bounds of the range of matrices being considered, and there are O(N^2) such pairs. For each pair, the inner loop over the split index k takes linear time O(N). So the algorithm runs in O(N^3) time in total. Printing the brackets afterwards takes only O(N) additional time.
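To make the cubic bound concrete, a small sketch can count how many times the innermost statement runs. The function name countInnerIterations is hypothetical; the loop bounds are the same as in the solution's three nested loops.

```cpp
// Hypothetical helper: counts executions of the innermost loop body
// for a chain of n matrices, using the same loop bounds as the DP.
long long countInnerIterations(int n) {
    long long count = 0;
    for (int len = 2; len <= n; len++) {          // length of the sub-chain
        for (int i = 0; i < n - len + 1; i++) {   // start of the sub-chain
            int j = i + len - 1;                  // end of the sub-chain
            for (int k = i; k < j; k++) {         // candidate split index
                count++;
            }
        }
    }
    return count;
}
```

For N = 3 this returns 4 and for N = 4 it returns 10, which matches (N^3 - N) / 6 and therefore grows as O(N^3).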

Space Complexity: O(N^2)

The space complexity is O(N^2) because both the dp array and the brackets array are N x N 2D arrays. The recursion used for printing adds at most O(N) stack depth.
