题意
题解
先考虑没有子序列 p x − p 1 ≤ k p_x - p_1 \leq k px−p1≤k 限制的情况。则所有的子序列与 N N N 个元素可构成的子集一一对应。设 A , B A,B A,B 为两个不相交区间的元素构成的非空集合, a , b a,b a,b 代表对应的非空子集,则有
∑ a ∈ A , b ∈ B ( ∑ a + b s i × ∏ a + b s i ) = ∑ a ∈ A , b ∈ B [ ( ∑ a s i + ∑ b s i ) × ( ∏ a s i ∏ b s i ) ] = ∑ a ∈ A ( ∑ a s i × ∏ a s i ) × ∑ b ∈ B ( ∏ b s i ) + ∑ b ∈ B ( ∑ b s i × ∏ b s i ) × ∑ a ∈ A ( ∏ a s i ) \begin{aligned} \sum\limits_{a\in A,b\in B}\Big(\sum\limits_{a+b}s_i\times\prod\limits_{a+b}s_i\Big) &= \sum\limits_{a\in A,b\in B}\Big[\big(\sum\limits_{a}s_i+\sum\limits_{b}s_i\big)\times\big(\prod\limits_{a}s_i\prod\limits_{b}s_i\big)\Big] \\ &=\sum\limits_{a\in A}\Big(\sum\limits_{a}s_i\times\prod\limits_{a}s_i\Big)\times\sum\limits_{b\in B}\Big(\prod_{b}s_i\Big)+\sum\limits_{b\in B}\Big(\sum\limits_{b}s_i\times\prod\limits_{b}s_i\Big)\times\sum\limits_{a\in A}\Big(\prod_{a}s_i\Big) \\ \end{aligned} a∈A,b∈B∑(a+b∑si×a+b∏si)=a∈A,b∈B∑[(a∑si+b∑si)×(a∏sib∏si)]=a∈A∑(a∑si×a∏si)×b∈B∑(b∏si)+b∈B∑(b∑si×b∏si)×a∈A∑(a∏si) 维护如下两个变量
F = ∑ a ∈ A ( ∑ a s i × ∏ a s i ) , G = ∑ a ∈ A ( ∏ a s i ) F=\sum\limits_{a\in A}\Big(\sum\limits_{a}s_i\times\prod\limits_{a}s_i\Big), G=\sum\limits_{a\in A}\Big(\prod_{a}s_i\Big) F=a∈A∑(a∑si×a∏si),G=a∈A∑(a∏si) 集合 A + B A+B A+B 的非空子集由三个部分构成:属于 A A A 的非空子集;属于 B B B 的非空子集; A , B A,B A,B 非空子集的并集。则集合信息合并如下
F A + B = F A × G B + F B × G A + F A + F B , G A + B = G A × G B + G A + G B F_{A+B}=F_{A}\times G_{B}+F_{B}\times G_{A}+F_{A}+F_{B}, G_{A+B}=G_{A}\times G_{B}+G_{A}+G_{B} FA+B=FA×GB+FB×GA+FA+FB,GA+B=GA×GB+GA+GB 线段树维护区间信息。枚举子序列的末尾元素,统计包含末尾元素且满足限制条件的 F F F。具体而言,将 [ i − k , i ) [i-k,i) [i−k,i) 与 [ i , i + 1 ) [i,i+1) [i,i+1) 合并,且不计入 [ i − k , i ) [i-k,i) [i−k,i) 的子集贡献。总时间复杂度 O ( N log N ) O(N\log N) O(NlogN)。由于查询区间的规模固定,分块可以做到 O ( N ) O(N) O(N)。
#include <bits/stdc++.h>
using namespace std;
#define rep(i, l, r) for (int i = l, _ = r; i < _; ++i)
typedef long long ll;
const int maxn = 1000005, sg_size = 1 << 21;
int N, K, mod, A[maxn];
struct node
{
int f, g;
} tree[sg_size];
int rd()
{
int x = 0;
char c = 0;
for (; !isdigit(c); c = getchar())
;
for (; isdigit(c); c = getchar())
x = (x << 1) + (x << 3) + c - '0';
return x;
}
void merge(node &chl, node &chr, node &p)
{
p.g = ((ll)chl.g * chr.g % mod + chl.g + chr.g) % mod;
p.f = (((ll)chl.f * chr.g + (ll)chr.f * chl.g) % mod + chl.f + chr.f) % mod;
}
void init(int k = 0, int l = 0, int r = N)
{
if (r - l == 1)
{
tree[k].f = (ll)A[l] * A[l] % mod, tree[k].g = A[l];
return;
}
int chl = (k << 1) + 1, chr = (k << 1) + 2, m = (l + r) >> 1;
init(chl, l, m), init(chr, m, r);
merge(tree[chl], tree[chr], tree[k]);
}
node ask(int a, int b, int k = 0, int l = 0, int r = N)
{
if (a <= l && r <= b)
return tree[k];
int chl = (k << 1) + 1, chr = (k << 1) + 2, m = (l + r) >> 1;
if (b <= m)
return ask(a, b, chl, l, m);
if (m <= a)
return ask(a, b, chr, m, r);
node res, rl = ask(a, b, chl, l, m), rr = ask(a, b, chr, m, r);
merge(rl, rr, res);
return res;
}
int main()
{
N = rd(), K = rd(), mod = rd();
rep(i, 0, N) A[i] = rd();
if (K == 0)
{
int res = 0;
rep(i, 0, N) res = (res + (ll)A[i] * A[i]) % mod;
printf("%d\n", res);
return 0;
}
init();
int res = 0;
rep(i, 0, N)
{
node t = ask(max(0, i - K), i);
int f = (ll)A[i] * A[i] % mod, g = A[i];
res = ((ll)res + f + (ll)t.f * g + (ll)f * t.g) % mod;
}
printf("%d\n", res);
return 0;
}