코딩 테스트/백준

[Baekjoon/백준] 11049번: 행렬 곱셈 순서(C/C++)

JongHoon 2022. 7. 19. 23:30

단계별로 풀어보기 24단계(동적 계획법 2) 2번 문제

https://www.acmicpc.net/step/17

 

동적 계획법 2 단계

더 이상 사용되지 않는 값을 버림으로써 공간 복잡도를 향상시키는 문제. 메모리 제한에 주목하세요.

www.acmicpc.net


백준 11049번: 행렬 곱셈 순서

https://www.acmicpc.net/problem/11049

 

11049번: 행렬 곱셈 순서

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같

www.acmicpc.net


문제 설명

크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

  • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
  • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.


입력과 출력

입력: 첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.

         둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)

         항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.

출력: 첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 2^31 - 1보다 작거나 같다.


접근 방법

이번 문제는 행렬 개수 N을 입력한 뒤 N개의 행렬 크기 r, c를 입력해서 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력하는 문제다.

행렬 곱셈은 문제 설명에 있듯 N * M 행렬과 M * K 행렬을 곱하면 N * M * K 개의 곱셈 연산을 필요로 한다. 3개 이상의 행렬을 곱할때는 그 순서에 따라서 결과가 달라진다.

그러면 문제 설명을 더 자세히 봐보자. 설명에서 행렬 A는 5 * 3, 행렬 B는 3 * 2, 행렬 C는 2 * 6이다. (AB)C를 할 경우 각 행렬의 곱이 5 * 3 * 2, 5 * 2 * 6인데, 여기서 5 * 3 * 2가 중간 값이 생략되어 두번째 곱셈에서 5 * 2가 되어 계산되었다. 그리고 A(BC)의 경우 3 * 2 * 6, 5 * 3 * 6인데, 여기서도 중간 값이 생략되어 두번째 곱셈에서 3 * 2 * 6이 3 * 6이 되어 계산되었다.

이를 보면 결국 N * M과 M * K 행렬을 곱하면 N * K 값이 된다. 그러니까 여러개의 행렬이 있다면, 처음 행렬의 행과 마지막 행렬의 열을 곱한 행렬이 나오게 된다. 정리하면 행렬 곱셈 연산 횟수는 N * M * K, 각 행렬을 곱한 후에는 처음 행렬 행 * 마지막 행렬 열한 값이 나온다.

이 점을 이용해서 이중 배열과 삼중 반복문을 이용해 행렬들을 돌면서 곱해서 최솟값을 구하면 문제를 해결할 수 있다.


코드

#include <iostream>
#include <algorithm>
using namespace std;
#define MAX 501

int r[MAX];
int c[MAX];
int arr[MAX][MAX];

int main(int argc, char * argv[]) {
	ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
	int N;
	cin >> N;
	
	for (int i = 1; i <= N; i++)
		cin >> r[i] >> c[i];

	for (int i = 1; i < N; i++)
		for (int j = 1; i + j <= N; j++) {
			arr[j][i + j] = 125000000;

			for (int h = j; h <= i + j; h++)
				arr[j][i + j] = min(arr[j][i + j], arr[j][h] + arr[h + 1][i + j] + r[j] * c[h] * c[i + j]);
		}

	cout << arr[1][N] << endl;

	return 0;
}

결과

백준 제출 결과