[BZOJ5020][THUWC 2017]在美妙的数学王国中畅游(LCT + 一点数学知识)

Address

Solution

  • 如果只有一次函数 a x + b ax+b ,那么这是非常裸的 LCT ,维护 a a 之和与 b b 之和即可
  • 然后你会发现
  • e a 1 x + b 1 × e a 2 x + b 2 = e ( a 1 x + b 1 ) ( a 2 x + b 2 ) = e a 1 a 2 x 2 + ( a 1 b 2 + a 2 b 1 ) x + b 1 b 2 e^{a_1x+b_1}\times e^{a_2x+b_2}=e^{(a_1x+b_1)(a_2x+b_2)}=e^{a_1a_2x^2+(a_1b_2+a_2b_1)x+b_1b_2}
  • 幂中有 x 2 x^2 的项,照这样的话如果一条路径很长,那么使用到的 x x 的最高次幂与路径上的点数相同,无法简单维护
  • sin \sin 甚至没有可加性
  • 回到 a x + b ax+b ,我们可以很容易地推广到,如果不是一次函数而是 m m 次多项式( m m 较小),那么我们同样可以通过维护多项式每一次项系数的和,做到 O ( n m log n ) O(nm\log n) 的优秀复杂度
  • 这让我们思考能不能把 sin \sin exp \exp 转化成多项式
  • 先科普下泰勒展开
  • f ( x ) = i = 0 f ( i ) ( x 0 ) ( x x 0 ) i i ! f(x)=\sum_{i=0}^{\infty}\frac{f^{(i)}(x_0)(x-x_0)^i}{i!}
  • 其中 f ( i ) ( x ) f^{(i)}(x) 表示 f ( x ) f(x) i i 阶导数
  • 一般地,为了让函数转成多项式的形式,我们通常取 x 0 = 0 x_0=0
  • sin ( a x + b ) \sin(ax+b) e a x + b e^{ax+b} 是复合函数
  • 可以按照 h ( x ) = f ( g ( x ) ) , h ( x ) = f ( g ( x ) ) g ( x ) h(x)=f(g(x)),h'(x)=f'(g(x))g'(x) 进行求导
  • 于是我们得到
  • sin ( a x + b ) = i = 0 a i orz ( i , b ) i ! x i \sin(ax+b)=\sum_{i=0}^{\infty}\frac{a^i\text{orz}(i,b)}{i!}x^i
  • 其中
  • orz ( i , x ) = { sin x i   m o d   4 = 0 cos x i   m o d   4 = 1 sin x i   m o d   4 = 0 cos x i   m o d   4 = 3 \text{orz}(i,x)=\begin{cases}\sin x&i\bmod 4=0\\\cos x&i\bmod 4=1\\-\sin x&i\bmod 4=0\\-\cos x&i\bmod 4=3\end{cases}
  • e a x + b = i = 0 a i e b i ! x i e^{ax+b}=\sum_{i=0}^{\infty}\frac{a^ie^b}{i!}x^i
  • 注意到当 i > 16 i>16 时,上面两个多项式的 i i 次项系数已经被 1 i ! \frac 1{i!} 压得非常小,其误差可以忽略
  • 所以我们只需要维护出泰勒展开后的多项式的次数 16 \le16 的项的系数和即可
  • 如果维护多项式的前 m m 项,则复杂度为 O ( n m log n ) O(nm\log n)
  • m m 减小能优化复杂度,但误差会变大
  • m m 16 16 左右较合适

Code

// luogu-judger-enable-o2
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)
#define LCT(y, x) for (y = 0; x; y = x, x = fa[x])

template <class T>
inline void Swap(T &a, T &b) {T t = a; a = b; b = t;}

const int N = 1e5 + 5, M = 18;

int n, m, lc[N], rc[N], fa[N], rev[N], len, que[N];
double val[N][M], sum[N][M];
char op[M];

void getsin(double a, double b, double *val)
{
	int i; double tmp = 1, _sin = sin(b), _cos = cos(b);
	For (i, 0, 16)
	{
		if (i) tmp /= i;
		double st = i & 1 ? _cos : _sin;
		if ((i >> 1) & 1) val[i] = -st * tmp;
		else val[i] = st * tmp;
		tmp *= a;
	}
}

void getexp(double a, double b, double *val)
{
	int i; double tmp = 1, _exp = exp(b);
	For (i, 0, 16)
	{
		if (i) tmp /= i;
		val[i] = _exp * tmp;
		tmp *= a;
	}
}

void getlin(double a, double b, double *val)
{
	int i;
	val[0] = b; val[1] = a;
	For (i, 2, 16) val[i] = 0;
}

int which(int x) {return rc[fa[x]] == x;}

bool isroot(int x)
{
	return !fa[x] || (lc[fa[x]] != x && rc[fa[x]] != x);
}

void down(int x)
{
	if (rev[x])
	{
		rev[x] = 0;
		Swap(lc[x], rc[x]);
		if (lc[x]) rev[lc[x]] ^= 1;
		if (rc[x]) rev[rc[x]] ^= 1;
	}
}

void upt(int x)
{
	int i;
	For (i, 0, 16)
	{
		sum[x][i] = val[x][i];
		if (lc[x]) sum[x][i] += sum[lc[x]][i];
		if (rc[x]) sum[x][i] += sum[rc[x]][i];
	}
}

void rotate(int x)
{
	int y = fa[x], z = fa[y], b = lc[y] == x ? rc[x] : lc[x];
	if (!isroot(y)) (lc[z] == y ? lc[z] : rc[z]) = x;
	fa[x] = z; fa[y] = x; if (b) fa[b] = y;
	if (lc[y] == x) rc[x] = y, lc[y] = b;
	else lc[x] = y, rc[y] = b; upt(y); upt(x);
}

void splay(int x)
{
	int i, y = x;
	que[len = 1] = x;
	while (!isroot(y)) que[++len] = fa[y], y = fa[y];
	Rof (i, len, 1) down(que[i]);
	while (!isroot(x))
	{
		if (!isroot(fa[x]))
		{
			if (which(x) == which(fa[x])) rotate(fa[x]);
			else rotate(x);
		}
		rotate(x);
	}
}

void access(int x)
{
	int y;
	LCT(y, x)
	{
		splay(x); rc[x] = y;
		if (y) fa[y] = x; upt(x);
	}
}

int findroot(int x)
{
	access(x); splay(x);
	while (down(x), lc[x]) x = lc[x];
	return splay(x), x;
}

void makeroot(int x)
{
	access(x); splay(x);
	rev[x] ^= 1;
}

void link(int x, int y)
{
	makeroot(x); fa[x] = y;
}

void cut(int x, int y)
{
	makeroot(x); access(y); splay(y);
	lc[y] = fa[x] = 0; upt(y);
}

double query(int x, int y, double val)
{
	int i; double res = 0, tmp = 1;
	makeroot(x); access(y); splay(y);
	For (i, 0, 16) res += tmp * sum[y][i], tmp *= val;
	return res;
}

int main()
{
	int i, u, v, x; double s, t;
	scanf("%d%d%*s", &n, &m);
	For (i, 1, n)
	{
		scanf("%d%lf%lf", &x, &s, &t);
		if (x == 1) getsin(s, t, val[i]);
		else if (x == 2) getexp(s, t, val[i]);
		else getlin(s, t, val[i]);
		upt(i);
	}
	while (m--)
	{
		scanf("%s%d%d", op + 1, &u, &v);
		if (op[1] == 'a') link(u + 1, v + 1);
		else if (op[1] == 'd') cut(u + 1, v + 1);
		else if (op[1] == 'm')
		{
			scanf("%lf%lf", &s, &t);
			u++; splay(u);
			if (v == 1) getsin(s, t, val[u]);
			else if (v == 2) getexp(s, t, val[u]);
			else getlin(s, t, val[u]);
			upt(u);
		}
		else
		{
			scanf("%lf", &s); u++; v++;
			if (findroot(u) != findroot(v))
				puts("unreachable");
			else printf("%.8lf\n", query(u, v, s));
		}
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xyz32768/article/details/85038925