diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py index f529a255d2ef..418ff0f4ae2a 100644 --- a/divide_and_conquer/strassen_matrix_multiplication.py +++ b/divide_and_conquer/strassen_matrix_multiplication.py @@ -74,7 +74,21 @@ def print_matrix(matrix: list) -> None: def actual_strassen(matrix_a: list, matrix_b: list) -> list: """ Recursive function to calculate the product of two matrices, using the Strassen - Algorithm. It only supports square matrices of any size that is a power of 2. + Algorithm. + + Time complexity: + The recurrence is T(n) = 7 T(n/2) + \u0398(n^2), which solves to + T(n) = \u0398(n^{log_2 7}) \u2248 \u0398(n^{2.8074}). This is asymptotically + faster than the naive \u0398(n^3) algorithm for sufficiently large n. + + Space complexity: + Uses additional memory for temporary submatrices and padding; overall + space complexity is O(n^2). + + Notes: + This function expects square matrices whose size is a power of two. + Matrices of other sizes are handled by `strassen` which pads to the + next power of two. """ if matrix_dimensions(matrix_a) == (2, 2): return default_matrix_multiplication(matrix_a, matrix_b) @@ -106,6 +120,16 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list: def strassen(matrix1: list, matrix2: list) -> list: """ + Multiply two matrices using Strassen's divide-and-conquer algorithm. + + Time complexity: + \u0398(n^{log_2 7}) \u2248 \u0398(n^{2.8074}) + (recurrence T(n) = 7 T(n/2) + \u0398(n^2)). + + Space complexity: + O(n^2) due to padding and temporary matrices used during recursion. + + Examples: >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]]) [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]] >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])