문제
풀이
우선 피보나치 문제를 행렬 곱으로 계산해야 한다. 즉, 이전 피보나치수를 나타내는 행렬에서 다음 피보나치 수를 나타내는 행렬을 만들어야 한다.
나는 다음과 같은 방법을 생각했다.
$\begin{bmatrix}f_n & f_{n+1}\\ 0 & 0 \end{bmatrix} \times SOME \ MATRIX=\begin{bmatrix}f_{n+1} & f_{n+2}\\ 0 & 0 \end{bmatrix}$
이는 다시 쓰면 아래와 같다.
$\begin{bmatrix}f_n & f_{n+1}\\ 0 & 0 \end{bmatrix} \times SOME \ MATRIX=\begin{bmatrix}f_{n+1} & f_{n}+f_{n+1}\\ 0 & 0 \end{bmatrix}$
이를 만족하는 SOME MATRIX를 구해보면 다음을 찾을 수 있다.
$\begin{bmatrix}f_n & f_{n+1}\\ 0 & 0 \end{bmatrix} \times \begin{bmatrix}0 & 1\\ 1 & 1 \end{bmatrix} =\begin{bmatrix}f_{n+1} & f_{n}+f_{n+1}\\ 0 & 0 \end{bmatrix}$
( 한번 계산해 보면 어렵지 않게 찾을 수 있다 ! )
그러면 N번째 피보나치 수는
$\begin{bmatrix}1 & 1\\ 0 & 0 \end{bmatrix} \times \begin{bmatrix}0 & 1\\ 1 & 1 \end{bmatrix} ^{N-1}=\begin{bmatrix}f_{N} & f_{N+1}\\ 0 & 0 \end{bmatrix}$
이므로 1번째 행의 1번째 열 값으로 알 수 있다. ( $f_1 = 1,\ f_2=1$ )
그러면 위의 식을 어떻게 구현해야 할까? 행렬 곱하기를 구현해서 N-1 번의 행렬 곱하기를 진행 할 수 도 있지만, 그렇게 되면 행렬 곱 연산에서 너무 많은 시간을 할애하여 시간초과가 날 것이다. 그래서 분할 정복을 이용한 거듭제곱이라 불리는 알고리즘을 사용할 것이다.
분할 정복을 이용하면 $O(N)$ 만큼 걸리던 거듭제곱을 $O(logN)$만큼으로 줄일 수 있는데 방법은 다음과 같다.
$C^n=\begin{cases} C^{n/2}\cdot C^{n/2} \ \ \ \ n\mathrm{\ is\ even} \\ C^{n/2}\cdot C^{n/2} \cdot C\ \ \ \ n\mathrm{\ is\ odd}\end {cases}$
$C^{n/2}$의 연산 결과를 알면 반복하지 않고 그 결과를 한 번 더 곱해주기만 하면 되므로, 연산량이 엄청 줄어들게 된다.
이 알고리즘을 사용해서 $\begin{bmatrix}0 & 1\\ 1 & 1 \end{bmatrix} ^{N-1}$을 구하고 후에 $\begin {bmatrix} 1 &1 \\ 0 & 0 \end{bmatrix}$와 곱하여 계산하였다. 행렬곱의 결합법칙 (combination law) 에 의해서 뒷 행렬의 거듭 제곱을 먼저 연산하여도 문제 없다.
전체 코드
## 행렬 곱 구현
def matrix_multiple(A, B):
R00 = (A[0][0] * B[0][0] + A[0][1] * B[1][0]) % 1000000007
R01 = (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % 1000000007
R10 = (A[1][0] * B[0][0] + A[1][1] * B[1][0]) % 1000000007
R11 = (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % 1000000007
result_matrix = [[R00, R01],
[R10, R11]]
return result_matrix
## 분할 정복을 이용한 거듭제곱
def n_square(n, matrix):
if n == 1:
return matrix
sqrt_matrix = n_square( n//2, matrix)
result_matrix = matrix_multiple(sqrt_matrix, sqrt_matrix)
if n % 2 == 1: ## n이 홀수인 경우
result_matrix = matrix_multiple(result_matrix, matrix)
return result_matrix
if __name__ == "__main__":
N = int(input())
initial_matrix = [[1, 1],
[0, 0]]
factor_matrix = [[0, 1],
[1, 1]]
factor_matrix_n = n_square(N, factor_matrix)
print(factor_matrix_n)
result = matrix_multiple(initial_matrix, factor_matrix_n)
print(result[0][0])