问题描述
给定一个 \(k\) 阶常系数齐次线性递推数列的前 \(k\) 项 \(h_1, h_2, h_3..h_k\) 和线性递推式 \(h_n = \sum_{i = 1} ^ k a_i h_{n - i}\), 求这个数列的第 \(n\) 项。
复杂度要求: \(O(k^2logn)\)
加强: \(O(klogklogn)\)
前置知识
矩阵的特征值
设 \(A\) 是 \(n\) 阶方阵,如果存在数 \(\lambda\) 和非零 \(n\) 维列向量 \(x\),使得 \(Ax=\lambda x\) 成立,则称 \(\lambda\) 是矩阵 \(A\) 的一个特征值。
矩阵的特征多项式
设 \(E\) 为单位矩阵,\(n\) 阶方阵 \(A\) 的特征多项式为 \(|\lambda E - A|\)
解释一下:\(\lambda\) 为一个变量。 \(\lambda E - A\) 的行列式是一个关于 \(\lambda\) 的 \(n\) 次多项式,即 \(A\) 的特征多项式。
特征值是 \(|\lambda E - A| = 0\) 的根。
Cayley-Hamilton定理
设 \(A\) 的特征多项式为 \(p_A\),\(O\) 为零矩阵。用矩阵 \(A\) 代替 \(\lambda\) 带入 特征多项式,有 \(p_A(A) = O\)
简要证明一下:把 \(A\) 直接带进特征多项式, \(P_A(A) = |AE - A| = 0\)
其实我并不会证这个东西。自学的线代等于没学
算法流程
现在我们已经熟背了Cayley-Hamilton定理。
考虑计算一下矩阵快速幂的转移矩阵 \(M\) 的特征多项式。
设 \(M_{x,y}\) 为去掉位置 \((x,y)\) 的代数余子式,\(m_{x,y}\) 为 \((x,y)\) 位置上的元素,直接对第 \(1\) 行拉普拉斯展开,\(M = \sum_{i=1}^k m_{1,i} M_{1,i}\) ,可以发现去掉第一行第 \(i\) 列之后留下的是一个下三角矩阵,\(M_{1,i} = (-1) ^ {i + 1} (-1)^{i-1}\lambda^{k - i}\),\(m_{1,1} =\lambda-a_1,m_{1,i} = -a_i\),整理一下就得到:
\[ |E\lambda - M| = \lambda ^ k - a_1 \lambda ^ {k - 1} - a_2 \lambda ^ {k - 2} - ... - a_n \]
由 Cayley-Hamilton定理,我们得到 \(p_M(M) = O\)
现在我们用 \(n\) 代替 \(n - k\),我们要求出 \(M^n\)
\(M^n= p_M(M) A(M) + r(M)\),其中\(r(M)\) 的次数不高于\(k - 1\)
因为 \(p_M(M) = O\), 得到 \(M^n = r(M)\)
我们现在只需要求 \(M^n \mod p_M(M)\)
因为 \(AB \mod C = (A \mod C)(B \mod C) \mod C\),所以可以快速幂计算 \(M^n \mod p_M(M)\)
直接暴力复杂度 \(O(k^2logn)\), 用多项式乘法和取模可以做到 \(O(klogklogn)\)
现在得到了 \(M^n = r(M) = \sum_{i = 0} ^ {k - 1} c_i M ^ i\)
再分析一下就可以得到 \(ans = \sum_{i = 0} ^ {k - 1} c_i h_{i + k}\)
再暴力或者 NTT 处理一下前 \(h\) 的前 \(2k\) 项就好了。
模板
BZOJ4161: Shlw loves matrixI
这道题里面 \(h\) 的下标是从 \(0\) 开始的。
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7;
template <typename T> T read(T &x) {
int f = 0;
register char c = getchar();
while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
for (x = 0; c >= '0' && c <= '9'; c = getchar())
x = (x << 3) + (x << 1) + (c ^ 48);
if (f) x = -x;
return x;
}
namespace Comb {
const int Maxn = 1e6 + 10;
int fac[Maxn], fav[Maxn], inv[Maxn];
void comb_init() {
fac[0] = fav[0] = 1;
inv[1] = fac[1] = fav[1] = 1;
for (int i = 2; i < Maxn; ++i) {
fac[i] = 1LL * fac[i - 1] * i % mod;
inv[i] = 1LL * -mod / i * inv[mod % i] % mod + mod;
fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
}
}
inline int C(int x, int y) {
if (x < y || y < 0) return 0;
return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
}
inline int Qpow(int x, int p) {
int ans = 1;
for (; p; p >>= 1) {
if (p & 1) ans = 1LL * ans * x % mod;
x = 1LL * x * x % mod;
}
return ans;
}
inline int Inv(int x) {
return Qpow(x, mod - 2);
}
inline void upd(int &x, int y) {
(x += y) >= mod ? x -= mod : 0;
}
inline int add(int x, int y) {
return (x += y) >= mod ? x - mod : x;
}
inline int dec(int x, int y) {
return (x -= y) < 0 ? x + mod : x;
}
}
using namespace Comb;
namespace Linear {
static const int Maxn = 5005;
int n, k;
int a[Maxn], h[Maxn];
int b[Maxn], c[Maxn], p[Maxn];
void mul(int *x, int *y, int *z) {
static int res[Maxn];
for (int i = 0; i <= k * 2; ++i) res[i] = 0;
for (int i = 0; i < k; ++i) {
for (int j = 0; j < k; ++j) {
upd(res[i + j], 1LL * x[i] * y[j] % mod);
}
}
for (int i = k * 2; i >= k; --i) {
int tmp = 1LL * Inv(p[k]) * res[i] % mod;
for (int j = 0; j <= k; ++j) {
res[i - j] = dec(res[i - j], 1LL * p[k - j] * tmp % mod);
}
}
for (int i = 0; i < k; ++i) z[i] = res[i];
}
void poly_pow(int p) {
while (p) {
if (p & 1) mul(b, c, c);
p >>= 1;
if (!p) break;
mul(b, b, b);
}
}
int solve() {
if (n <= k) return h[n];
p[k] = 1;
for (int i = 0; i < k; ++i)
p[i] = dec(0, a[k - i]);
b[1] = 1; c[0] = 1;
poly_pow(n - k);
for (int i = k + 1; i <= k * 2; ++i) {
for (int j = 1; j <= k; ++j) {
upd(h[i], 1LL * h[i - j] * a[j] % mod);
}
}
int ans = 0;
for (int i = 0; i < k; ++i)
upd(ans, 1LL * c[i] * h[i + k] % mod);
return ans;
}
}
using namespace Linear;
int main() {
read(n); read(k); ++n;
for (int i = 1; i <= k; ++i) {
read(a[i]);
if (a[i] < 0) a[i] += mod;
}
for (int i = 1; i <= k; ++i) {
read(h[i]);
if (h[i] < 0) h[i] += mod;
}
cout << solve() << endl;
return 0;
}