mathematics

분할 정복을 이용한 행렬 거듭 제곱

leesu0605 2023. 2. 27. 20:00

dp[i] = 1 * dp[i - 1] + 2 * dp[i - 2] + 3 * dp[i - 3] ... dp[i - K]
위와 같은 점화식으로 dp[N]를 구해야 한다고 생각해보자.
일단 dp를 하나하나 구한다고 하면 총 O(NK)의 시간복잡도가 나올 것이다.

그러나 이를 잘 최적화하면 O(K^3 log N) 안에 dp[N]을 구할 수 있다.

바로 행렬 거듭제곱을 이용하는 것인데, 그 전에 "행렬 곱셈을 한다"가 어떤 의미인지 한 번 해석을 해보도록 하겠다.
다음과 같은 행렬 두 개가 있다고 가정하자.

[1, 2, 3]        [10]
[1, 0, 0]        [20]
[0, 1, 0]        [30]

그리고 위 행렬을 곱하면, 다음과 같은 행렬이 나올 것이다.

[1*10+2*20+3*30 = 140]
[1*10+0*20+0*30 = 10]
[0*10+1*20+0*30 = 20]

이제 뭔가 보인다.

위의 3*3 행렬의 각 행의 숫자들이 의미하는 바를 각각 dp 점화식의 계수라 하고,
3*1 행렬의 각 행의 숫자가 의미하는 바를 각각 dp[i-1], dp[i-2], dp[i - 3]이라고 가정해보자.

그럼 두 행렬을 곱한 행렬은 다음과 같이 바꿔줄 수 있다.

[dp[i] = 1 * dp[i - 1] + 2 * dp[i - 2] + 3 * dp[i - 3]]
[dp[i - 1] = 1 * dp[i - 1] + 0 * dp[i - 2] + 0 * dp[i - 3] = dp[i - 1]]
[dp[i - 2] = 0 * dp[i - 1] + 1 * dp[i - 2] + 0 * dp[i - 3] = dp[i - 2]]

즉, dp[i - 1], dp[i - 2], dp[i - 3]를 가지고 dp[i], dp[i - 1], dp[i - 2]을 원소로 갖는 3*1 행렬을 만들어낸 것이다.
그리고 위의 3*3 행렬과 새로 만들어낸 행렬을 또 곱해주면 dp[i + 1], dp[i], dp[i - 1]을 원소로 갖는 행렬이 나올 것이고....
이렇게 쭉 가다보면 dp[N]을 구할 수 있다.

즉, 행렬을 N번 제곱한 것과 dp[K-1] ~ dp[0]을 원소로 갖는 K*1 행렬과 곱하면 dp[N]을 구할 수 있다.

그러나 이런 식으로 행렬곱으로 dp[i]를 구하는 방법은 당연히 dp[i]를 그냥 하나하나 구하는 것보다 더 긴 시간이 요구된다.

따라서 이를 최적화하기 위해 분할 정복을 이용한 거듭제곱 알고리즘을 이용할 것이다.
이 알고리즘은 N번의 거듭제곱을 O(log N) 안에 해주기 때문에 현재 상황에 쓰기 적합하다.

다음은 N번째 피보나치 수를 1000000007로 나눈 값을 O(2^3 * log N) 안에 구해주는 코드이다.

#include <iostream>
#include <vector>
using namespace std;
typedef long long int ll;
using Matrix=vector<vector<ll>>;
#define MOD 1000000007

Matrix operator * (Matrix a, Matrix b){
        Matrix res(a.size(), vector<ll>(b[0].size()));
        for(int i=0;i<a.size();i++)
                for(int j=0;j<b[0].size();j++){
                        ll cur=0LL;
                        for(int k=0;k<b.size();k++)
                                cur=(cur+a[i][k]*b[k][j])%MOD;
                        res[i][j]=cur;
                }
        return res;
}

int main(){
        ios_base::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);
        ll n;
        cin>>n;
        if(!n){
                cout<<0<<'\n';
                return 0;
        }
        Matrix fibo(2, vector<ll>(2, 1));
        fibo[1][1]=0;
        Matrix mul_mat(fibo), init(2, vector<ll>(1, 1));
        init[1][0]=0;
        n--;
        while(n){
                if(n&1)
                        fibo=mul_mat*fibo;
                mul_mat=mul_mat*mul_mat;
                n>>=1;
        }
        Matrix res=fibo*init;
        cout<<res[1][0]<<'\n';
}