Description
给定 \(n\) 个三元组 \((A_1, B_1, C_1), (A_2, B_2, C_2),\cdots ,(A_n, B_n, C_n)\)。支持 7 种操作。
对于区间 \([l, r]\):
- 进行操作 \(A_i \leftarrow A_i + B_i\)
- 进行操作 \(B_i \leftarrow B_i + C_i\)
- 进行操作 \(C_i \leftarrow C_i + A_i\)
- 进行操作 \(A_i \leftarrow A_i + v\)
- 进行操作 \(B_i \leftarrow B_i \times v\)
- 进行操作 \(C_i \leftarrow v\)
- 分别求 \(\sum A_i\),\(\sum B_i\),\(\sum C_i\) 的值。
\(A, B, C\) 的值时刻对 \(998244353\) 取模。
Hint
- \(1\le n, m\le 2.5\times 10^5\)
- \(A_i, B_i, C_i, v \in [0, 998244353]\)
Solution
单个位置有多种值并且 互相影响,可以向 矩阵 的方向思考。
我们把这个三元组 \((A, B, C)\) 存入一个三行一列的矩阵中:\(\begin{bmatrix} A \\ B \\ C \end{bmatrix}\)。我们尝试着使用 矩阵乘法 来完成对矩阵中的元素进行更新。如果矩阵加速熟的话很快可以写出来:
操作一、二、三分别为:
\[1 :\begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \end{bmatrix} = \begin{bmatrix} A+B \\ B \\ C \end{bmatrix} \\ ----------------------------\\ 2:\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 1 \\ 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \end{bmatrix} = \begin{bmatrix} A \\ B +C \\ C \end{bmatrix}\\ ----------------------------\\ 3:\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 1 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \end{bmatrix} = \begin{bmatrix} A \\ B \\ C + A \end{bmatrix} \]
对于操作七,我们发现可以转化为矩阵的区间和,那么不难想到 线段树——一个结点维护区间矩形的和,由于矩阵乘法有结合律所以正确性显然。在修改时也可以直接转化为区间乘上一个元素,通过打标记实现(因为矩阵乘法满足分配律)。
进行了发现了剩下三个神奇的操作,并不能直接简单的通过乘法实现。当然可以结合矩阵加法做。只不过代码难度和长度会大大增加。
我们不妨将矩阵拓展一层——加入常数项:\(\begin{bmatrix}A \\ B \\ C \\ 1\end{bmatrix}\)。
于是一切都得到了解决:
\[1 :\begin{bmatrix} 1 & 1 & 0 & 0\\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \\ 1 \end{bmatrix} = \begin{bmatrix} A+B \\ B \\ C \\ 1 \end{bmatrix} \\ ----------------------------\\ 2 :\begin{bmatrix} 1 & 0 & 0 & 0\\ 0 & 1 & 1 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \\ 1 \end{bmatrix} = \begin{bmatrix} A \\ B + C \\ C \\ 1 \end{bmatrix} \\ ----------------------------\\ 3 :\begin{bmatrix} 1 & 0 & 0 & 0\\ 0 & 1 & 0 & 0 \\ 1 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \\ 1 \end{bmatrix} = \begin{bmatrix} A \\ B \\ C+A \\ 1 \end{bmatrix} \\ ----------------------------\\ 4 :\begin{bmatrix} 1 & 0 & 0 & v\\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \\ 1 \end{bmatrix} = \begin{bmatrix} A+v \\ B \\ C \\ 1 \end{bmatrix} \\ ----------------------------\\ 5 :\begin{bmatrix} 1 & 0 & 0 & 0\\ 0 & v & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \\ 1 \end{bmatrix} = \begin{bmatrix} A \\ B\times v \\ C \\ 1 \end{bmatrix} \\ ----------------------------\\ 6 :\begin{bmatrix} 1 & 0 & 0 & 0\\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & v \\ 0 & 0 & 0 & 1 \end{bmatrix} \times\begin{bmatrix} A \\ B \\ C \\ 1 \end{bmatrix} = \begin{bmatrix} A \\ B \\ v \\ 1 \end{bmatrix} \\ \]
然后这个算法复杂度为 \(O(L^3\log n)\),其中矩阵大小为 \(L=4\)。如果用上述的“比较麻烦的做法”那么 \(L=3\)。 但常数都很大。
卡常细节:
- 不要用
vector
实现的矩阵类,常数爆炸。 - 矩阵乘法中剪掉无用运算。
pushdown
中省去空标记的下传。- 使用较快的 I/O 方式。
实现细节:
- 可以通过打表的方法打出转移矩阵以减少出错。打表可以用初始化列表(C++11 语法慎用)。
- 矩阵类功能写全、人性化,如下标取值,重载运算符等等。
Code
/*
- Author : Wallace
- Source : https://www.cnblogs.com/-Wallace-/
- Problem : LOJ #2980 THUSCH 2017 大魔法师
*/
#include <cstdio>
#include <cstring>
#include <initializer_list>
#include <vector>
using namespace std;
namespace fastIO_int{
int get_int()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9')
{
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9')
x=x*10+c-'0',c=getchar();
return f*x;
}
void read(){}
template<class T1,class ...T2>
void read(T1 &i,T2&... rest)
{
i=get_int();
read(rest...);
}
void put_int(int x)
{
if(x<0)putchar('-'),x=-x;
if(x>9)put_int(x/10);
putchar(x%10+'0');
}
void write(){}
template<class T1,class ...T2>
void write(T1 i,T2... rest)
{
put_int(i),putchar(' ');
write(rest...);
}
};
typedef long long i64;
typedef vector<i64> vec;
const int N = 2.5e5 + 5;
const int mod = 998244353;
struct mat {
i64 e[4][4];
int r, c;
mat() { }
mat(int r, int c) : r(r), c(c) {
memset(e, 0, sizeof e);
}
mat(initializer_list<vec> k) {
r = 0;
for (auto i : k) {
c = 0;
for (auto j : i)
e[r][c++] = j;
++r;
}
r = k.size(), c = k.begin()->size();
}
inline i64* operator [] (int k) {
return e[k];
}
};
inline bool operator == (mat a, mat b) {
for (int i = 0; i < a.r; i++)
for (int j = 0; j < a.c; j++)
if (a[i][j] == b[i][j]) return false;
return true;
}
inline bool operator != (mat a, mat b) {
return !(a == b);
}
inline mat operator + (mat a, mat b) {
mat x(a.r, a.c);
for (int i = 0; i < a.r; i++)
for (int j = 0; j < a.c; j++)
x[i][j] = (a[i][j] + b[i][j]) % mod;
return x;
}
inline mat operator * (mat a, mat b) {
mat x(a.r, b.c);
for (int i = 0; i < a.r; i++)
for (int k = 0; k < b.r; k++) if(a[i][k])
for (int j = 0; j < b.c; j++) if(b[k][j])
(x[i][j] += a[i][k] * b[k][j]) %= mod;
return x;
}
mat matI = mat {
vec{1, 0, 0, 0},
vec{0, 1, 0, 0},
vec{0, 0, 1, 0},
vec{0, 0, 0, 1},
};
mat opt[6] = {
mat {
vec{1, 1, 0, 0},
vec{0, 1, 0, 0},
vec{0, 0, 1, 0},
vec{0, 0, 0, 1},
},
mat {
vec{1, 0, 0, 0},
vec{0, 1, 1, 0},
vec{0, 0, 1, 0},
vec{0, 0, 0, 1},
},
mat {
vec{1, 0, 0, 0},
vec{0, 1, 0, 0},
vec{1, 0, 1, 0},
vec{0, 0, 0, 1},
},
// 魔力激发
mat {
vec{1, 0, 0, 0},
vec{0, 1, 0, 0},
vec{0, 0, 1, 0},
vec{0, 0, 0, 1},
// changing position : [0][3]
},
mat {
vec{1, 0, 0, 0},
vec{0, 0, 0, 0},
vec{0, 0, 1, 0},
vec{0, 0, 0, 1},
// changing position : [1][1]
},
mat {
vec{1, 0, 0, 0},
vec{0, 1, 0, 0},
vec{0, 0, 0, 0},
vec{0, 0, 0, 1},
// changing position : [2][3]
}
// 魔力增强
};
/* basic settings */
int A[N], B[N], C[N];
int n, q;
int L[N << 2], R[N << 2];
mat sum[N << 2], tag[N << 2];
inline void pushup(int x) {
sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
inline void pushTag(int x, mat& v) {
tag[x] = v * tag[x], sum[x] = v * sum[x];
}
inline void pushdown(int x) {
if (tag[x] == matI) return;
pushTag(x << 1, tag[x]);
pushTag(x << 1 | 1, tag[x]);
tag[x] = matI;
}
#define mid ((L[x] + R[x]) >> 1)
void build(int l, int r, int x = 1) {
L[x] = l, R[x] = r, tag[x] = matI;
if (l == r) {
sum[x] = mat(4, 1);
sum[x][0][0] = A[l];
sum[x][1][0] = B[l];
sum[x][2][0] = C[l];
sum[x][3][0] = 1;
return;
}
build(l, mid, x << 1);
build(mid + 1, r, x << 1 | 1);
pushup(x);
}
void mutiply(int l, int r, mat& v, int x = 1) {
if (v == matI) return;
if (l <= L[x] && R[x] <= r) return pushTag(x, v);
pushdown(x);
if (l <= mid) mutiply(l, r, v, x << 1);
if (r > mid) mutiply(l, r, v, x << 1 | 1);
pushup(x);
}
void getSum(int l, int r, mat& v, int x = 1) {
if (l <= L[x] && R[x] <= r) return void(v = v + sum[x]);
pushdown(x);
if (l <= mid) getSum(l, r, v, x << 1);
if (r > mid) getSum(l, r, v, x << 1 | 1);
}
#undef mid
signed main() {
using fastIO_int::read;
using fastIO_int::write;
read(n);
for (int i = 1; i <= n; i++)
read(A[i], B[i], C[i]);
build(1, n);
read(q);
while (q--) {
int cmd, l, r;
read(cmd, l, r);
if (1 <= cmd && cmd <= 3) {
mutiply(l, r, opt[cmd - 1]);
} else if (cmd == 7) {
mat ans(4, 1); getSum(l, r, ans);
write(ans[0][0]), putchar(' ');
write(ans[1][0]), putchar(' ');
write(ans[2][0]), putchar('\n');
} else {
mat cur = opt[cmd - 1];
int v; read(v);
switch (cmd) {
case 4 : cur[0][3] = v; break;
case 5 : cur[1][1] = v; break;
case 6 : cur[2][3] = v; break;
}
mutiply(l, r, cur);
}
}
return 0;
}