题目描述
牛牛有一颗大小为n的神奇Link-Cut 数组,数组上的每一个节点都有两种状态,一种为link状态,另一种为cut状态。数组上任意一对处于link状态的无序点对(即(u,v)和(v,u)被认为是同一对)会产生dis(u,v)的link能量,dis(u,v)为数组上u到v的距离。
我们定义整个数组的Link能量为所有处于link状态的节点产生的link能量之和。
一开始数组上每个节点的状态将由一个长度大小为n的01串给出,'1'表示Link状态,'0'表示Cut状态。
牛牛想要知道一开始,以及每次操作之后整个数组的Link能量,为了避免这个数字过于庞大,你只用输出答案对1e9+7取余后的结果即可。
思路
把一个大区间分成两个小区间计算,例如:11011111,可以分开成1101和1111,那么大区间的Link值就是左边区间内的Link加右区间的Link,再加上左边区间和右边区间互相的贡献,所以可以用线段树维护一下每个区间内的Link值,左右区间互相的贡献也很好计算。
设他们的下标从1开始,那么就有
1 | 1 | 0 | 1 | 1 | 1 | 1 | 1 |
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
可以看出,左右两边互相的贡献就是
把左边区间的点的坐标和右边区间点的坐标提取出来,可以发现,结果就是
右区间所有为1 的点的下标之和乘以左边为1 的点的个数,减去左区间所有为1 的点的下标和乘以右区间为1 的点的个数,即
因此只需要线段树去维护各区间内的为1 的点的个数和为1 的点的下标和即可,每次查询的时候更新完直接输出根节点的sum值即可(注意取模)
#include <bits/stdc++.h>
#pragma warning (disable:4996)
#pragma warning (disable:6031)
#define mem(a, b) memset(a, b, sizeof a)
using namespace std;
const int N = 1e5 + 10;
typedef long long ll;
const int mod = 1e9 + 7;
char s[N];
struct p {
ll l, r, possum, cnt, sum;
};
struct SegementTree {
p c[N * 4];
SegementTree() {
mem(c, 0);
}
void build(ll l, ll r, ll k) {
c[k].l = l;
c[k].r = r;
c[k].possum = 0;
c[k].sum = 0;
c[k].cnt = 0;
if (l == r) {
if (s[l - 1] == '0')return ;
c[k].cnt = 1;
c[k].possum = l;
c[k].sum = 0;
return ;
}
ll mid = (l + r) / 2;
build(l, mid, k << 1);
build(mid + 1, r, k << 1 | 1);
c[k].cnt = c[k << 1].cnt + c[k << 1 | 1].cnt;
c[k].possum = c[k << 1].possum + c[k << 1 | 1].possum;
c[k].sum = c[k << 1].sum + c[k << 1 | 1].sum;
c[k].sum %= mod;
c[k].sum += c[k << 1 | 1].possum * c[k << 1].cnt % mod - c[k << 1].possum * c[k << 1 | 1].cnt % mod;
c[k].sum = (c[k].sum + mod) % mod;
}
void update1(ll ind, ll k) {
// 0->1
if (c[k].l == c[k].r) {
c[k].cnt = 1;
c[k].possum = c[k].l;
c[k].sum = 0;
return ;
}
ll mid = c[k].l + c[k].r;
mid /= 2;
if (ind <= mid)update1(ind, k << 1);
else update1(ind, k << 1 | 1);
c[k].possum = c[k << 1].possum + c[k << 1 | 1].possum;
c[k].cnt = c[k << 1].cnt + c[k << 1 | 1].cnt;
c[k].sum = c[k << 1].sum + c[k << 1 | 1].sum;
c[k].sum %= mod;
c[k].sum += c[k << 1 | 1].possum * c[k << 1].cnt % mod - c[k << 1].possum * c[k << 1 | 1].cnt % mod;
c[k].sum = (c[k].sum + mod) % mod;
}
void update2(ll ind, ll k) {
// 1->0
if (c[k].l == c[k].r) {
c[k].cnt = 0;
c[k].possum = 0;
c[k].sum = 0;
return ;
}
ll mid = c[k].l + c[k].r;
mid /= 2;
if (ind <= mid)update2(ind, k << 1);
else update2(ind, k << 1 | 1);
c[k].possum = c[k << 1].possum + c[k << 1 | 1].possum;
c[k].cnt = c[k << 1].cnt + c[k << 1 | 1].cnt;
c[k].sum = c[k << 1].sum + c[k << 1 | 1].sum;
c[k].sum %= mod;
c[k].sum += c[k << 1 | 1].possum * c[k << 1].cnt % mod - c[k << 1].possum * c[k << 1 | 1].cnt % mod;
c[k].sum = (c[k].sum + mod) % mod;
}
};
SegementTree st;
int main()
{
int n;
scanf("%d", &n);
scanf("%s", s);
st.build(1, n, 1);
printf("%lld\n", st.c[1].sum);
int q;
scanf("%d", &q);
for (int i = 1; i <= q; i++) {
ll a, b;
scanf("%lld %lld", &a, &b);
if (a == 1) {
st.update1(b, 1);
}
else st.update2(b, 1);
printf("%lld\n", st.c[1].sum);
}
return 0;
}