본문 바로가기
Quality control (Univ. Study)/Algorithm Design

알고리즘 설계 실습 - 연쇄 행렬 곱셈

by 생각하는 이상훈 2024. 5. 16.
728x90

문제

i × j 행렬과 j × k행렬을 곱하기 위해서는 i × j × k번 만큼의 곱셈이 필요합니다. 연쇄적으로 행렬 을 곱할 때, 어떤 행렬 곱셈을 먼저 수행하는지에 따라서 필요한 곱셈의 횟수는 달라지게 됩니다.

예를 들어, 크기가 1×9인 행렬 A, 크기가 9×9인 행렬 B, 크기가 9×3인 행렬 C가 주어졌을 때 행 렬의 곱 ABC를 구하는 경우, 다음과 같이 여러 방법이 존재합니다.

 

AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 1×9×9 + 1×9×3 = 81 + 27 = 108번입니다.

BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 9×9×3 + 1×9×3 = 243 + 27 = 270번입니다.

 

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하 고, 이때의 행렬 곱셈의 순서를 수식으로 표현하는 프로그램을 작성하세요. 입력으로 주어진 행렬 의 순서를 바꾸면 안 됩니다.

 

입력

첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어집니다.

둘째 줄부터 N개 줄에는 행렬의 크기 i과 j가 주어집니다. (1 ≤ i, j ≤ 500)

항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어집니다.

 

출력

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력하고, 둘째 줄에 곱셈 연산의 최솟값을 만족하는 행렬 곱셈의 수식을 출력합니다.

 

입력 예시 1

3

1 9

9 9

9 3

출력 예시 1

108

((M1*M2)*M3)

 

입력 예시 2

5

4 9

9 10

10 6

6 8

8 10

출력 예시 2

1112

((((M1*M2)*M3)*M4)*M5)


풀이 코드

def matrix_chain_multiplication(dims):
    n = len(dims) - 1
    M = [[0] * (n + 1) for _ in range(n + 1)]
    P = [[0] * (n + 1) for _ in range(n + 1)]

    # 최소 곱셈 횟수 계산
    for length in range(2, n + 1):  # 연쇄 길이
        for i in range(1, n - length + 2):  # 연쇄 시작점
            j = i + length - 1  # 연쇄 끝점
            M[i][j] = float('inf')  # 무한대로 초기화
            for k in range(i, j):
                # A_i부터 A_k까지와 A_(k+1)부터 A_j까지 곱의 비용
                q = M[i][k] + M[k + 1][j] + dims[i - 1] * dims[k] * dims[j]
                if q < M[i][j]:
                    M[i][j] = q
                    P[i][j] = k

    return M, P

def construct_optimal_order(P, i, j):
    if i == j:
        # i와 j가 같다면 하나의 행렬만 나타냄
        return f"M{i}"
    else:
        # P[i][j]에서 최적의 분할 위치 k를 얻음
        k = P[i][j]
        # i부터 k까지 최적의 순서로 문자열을 구성
        left = construct_optimal_order(P, i, k)
        # k+1부터 j까지 최적의 순서로 문자열을 구성
        right = construct_optimal_order(P, k + 1, j)
        # 왼쪽과 오른쪽 결과를 괄호로 묶어서 반환
        return f"({left}*{right})"

# 입력 받기
N = int(input())  # 행렬의 개수 N을 입력받음
dims = [0] * (N + 1)  # 행렬의 차원을 저장할 배열 초기화
for i in range(1, N + 1):
    r, c = map(int, input().split())  # 행렬의 행과 열 크기를 입력받음
    if i == 1:
        dims[0] = r  # 첫 번째 행렬의 행 크기를 dims[0]에 저장
    dims[i] = c  # 각 행렬의 열 크기를 dims 배열에 저장

# 최소 곱셈 횟수와 분할 위치를 계산
M, P = matrix_chain_multiplication(dims)

# 결과 출력
print(M[1][N])  # 최소 곱셈 횟수 출력
print(construct_optimal_order(P, 1, N))  # 최적의 곱셈 순서를 문자열 형태로 출력

위 코드는 연속된 행렬 곱셈의 최적 순서를 결정하여 최소 곱셈 횟수를 계산하는 DP 방식을 구현한다. 행렬의 개수와 각 행렬의 크기를 입력받은 후, 행렬의 크기 정보를 바탕으로 최소 곱셈 횟수를 계산하는 `matrix_chain_multiplication` 함수와 최적의 곱셈 순서를 문자열로 반환하는 `construct_optimal_order` 함수로 구성되어 있다.

1. 사용자로부터 행렬의 개수 N과 각 행렬의 차원 정보를 입력받는다.
2. `matrix_chain_multiplication` 함수는 각 행렬 조합에 대해 최소 곱셈 횟수를 계산하고, 각 단계의 최적 분할 위치를 기록한다.
3. 계산된 최소 곱셈 횟수는 `M` 배열에 저장되며, 어떤 분할이 최적인지는 `P` 배열에 저장된다.
4. `construct_optimal_order` 함수는 `P` 배열을 사용하여 재귀적으로 최적의 곱셈 순서를 문자열로 구성하고 반환한다.
5. 최종적으로 계산된 최소 곱셈 횟수와 최적의 곱셈 순서가 출력된다. 

위와 같은 구성의 순서를 찾는 매트릭스를 만들고

행렬 곱셈 수도 계산하는 것이다.


728x90