求A(x)*B(x)
// fft模板.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include"pch.h"
#include <iostream>
#include<algorithm>
#include<iostream>
#include<string>
#include<cstdio>
using namespace std;
typedef long long ll;
//long long ll;
const int N = 1e5 + 5;//n的范围
const double pi = acos(-1.0);
int a[N], a2[N];
ll num[N * 4];
ll ans[N * 4];
struct Complex
{
double x, y;
Complex(double x1 = 0.0, double y1 = 0.0)
{
x = x1;
y = y1;
}
Complex operator -(const Complex &b)const
{
return Complex(x - b.x, y - b.y);
}
Complex operator +(const Complex &b)const
{
return Complex(x + b.x, y + b.y);
}
Complex operator *(const Complex &b)const
{
return Complex(x*b.x - y * b.y, x*b.y + y * b.x);
}
}x1[N * 4], x2[N * 4], x3[N * 4];
void change(Complex y[], int len)
{
int i, j, k;
for (i = 1, j = len / 2; i < len - 1; i++)
{
if (i < j)swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
j += k;
}
return;
}
void FFT(Complex y[], int len, int on)
{
change(y, len);
for (int h = 2; h <= len; h <<= 1)
{
Complex wn(cos(-on * 2 * pi / h), sin(-on * 2 * pi / h));
for (int j = 0; j < len; j += h)
{
Complex w(1, 0);
for (int k = j; k < j + h / 2; k++)
{
Complex u = y[k];
Complex t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1)
for (int i = 0; i < len; i++)
y[i].x /= len;
return;
}
int main()
{
int n;
memset(num, 0, sizeof(num));
scanf_s("%d", &n);
//scanf("%d", &m);//当两个多项式不同长度时,可补零到相同,这里用相同系数说明
int len1 = 1;
for (int i = 0; i < n; i++)
{
scanf_s("%d", &a[i]);//输入第一个多项式n个系数
}
for (int i = 0; i < n; i++)
{
scanf_s("%d", &a2[i]);//输入第二个多项式n个系数
}
int len = n;
while (len1 < 2 * len)
{
len1 <<= 1;
}
//cout << len1 << endl;
for (int i = 0; i < len; i++)
{
x1[i] = Complex(a[i], 0);
x2[i] = Complex(a2[i], 0);
}
for (int i = len; i < len1; i++)//不满2的指数用0补满
{
x1[i] = Complex(0, 0);
x2[i] = Complex(0, 0);
}
FFT(x1, len1, 1);//系数表示法转为点值表示法
FFT(x2, len1, 1);
for (int i = 0; i < len1; i++)
{
x3[i] = x1[i] * x2[i];
}
FFT(x3, len1, -1);//点值表示法转为系数表示法
for (int i = 0; i < len1; i++)
{
num[i] = ll(x3[i].x + 0.5);//存两个多项式相乘以后第i次方项前的系数。
//cout << num[i] << endl;
}
//比如1234*1234(输入4 3 2 1 ,4 3 2 1)=(1*10^3+2*10^2+3*10^1+4*10^0)*(1*10^3+2*10^2+3*10^1+4*10^0)=num[0]*10^0+num[1]*10^1+....+num[7]*10^7 ;
}
在这里插入代码片
题目链接http://acm.hdu.edu.cn/showproblem.php?pid=4609
大意是输入a数组,问你任选三个数字作为三边能组成三角形的概率。
我们先给他排序一下。
显然有一种办法,一共有多少种组合能组成三角形除以一共有多少种选择办法即是所求,显然一共的选择有n*(n-1)(n-2)/6种。
现在来求有多少种选法可以满足条件
三边之中肯定有一条是最大边,那当我们选择一条为最大边c以后,另外两边a和b怎么选呢?因为此时最大边已经确定,那现在只要满足条件a+b>c即可,
现在问题就转化为求满足的a+b有多少种呢,你可能想问fft不是得出两个多项式相乘后的各项系数吗,跟这个加法有什么关系?如果你把边的长度作为X的指数,把改边的数量作为该项的系数呢?相乘不就相当于相加了吗?
记num[i]表示a+b=i的组数(这里指用fft算出来的结果)(注意:fft是每一项都跟另一个多项式的所有项都各自乘了一次)
这里来消除一些不满足条件或重复的组合,因为不能本身跟本身结合,所以for(i=0;i<n;i++)num[a[i]+a[i]]–;
还有选择第1,3条边和选择第3,1条边是重复的,所以for(int i=0;i<len1;i++)num[i]/=2;
再弄一个前缀和就是两边之和小于等于特定数的组数了。
即len1是相乘后最高次数。
那长度和大于a[i]的取两个的取法是num[len1]-num[a[i]].
则所有选择办法就是for(int i=0;i<n;i++)sum+=num[len1]-num[a[i]].
这里还存在一些不满足条件的:注意:a数组已经提前排序过了
(1)两边都大于a[i]的有(n-i-1)(n-i-2)/2种
(2)一边大于a[i],一遍小于a[i]的有(n-i-1)i;
(3)有一条边是等于a[i]的,有(n-1);
则最终sum应为:for(int i=0;i<n;i++){
sum-=(n-i-1)(n-i-2)/2;
sum-=(n-i-1)*i;
sum-=n-1;
}
代码
#include<algorithm>
#include<iostream>
#include<string>
#include<cstdio>
using namespace std;
typedef long long ll;
//long long ll;
const int N = 1e5 + 5;
const double pi = acos(-1.0);
int a[N];
ll num[N * 4];
ll ans[N * 4];
struct Complex
{
double x, y;
Complex(double x1 = 0.0, double y1 = 0.0)
{
x = x1;
y = y1;
}
Complex operator -(const Complex &b)const
{
return Complex(x - b.x, y - b.y);
}
Complex operator +(const Complex &b)const
{
return Complex(x + b.x, y + b.y);
}
Complex operator *(const Complex &b)const
{
return Complex(x*b.x - y * b.y, x*b.y + y * b.x);
}
}x1[N * 4];
void change(Complex y[], int len)
{
int i, j, k;
for (i = 1, j = len / 2; i < len - 1; i++)
{
if (i < j)swap(y[i], y[j]);
k = len / 2;
while (j >= k)
{
j -= k;
k /= 2;
}
j += k;
}
return;
}
void FFT(Complex y[], int len, int on)
{
change(y, len);
for (int h = 2; h <= len; h <<= 1)
{
Complex wn(cos(-on * 2 * pi / h), sin(-on * 2 * pi / h));
for (int j = 0; j < len; j += h)
{
Complex w(1, 0);
for (int k = j; k < j + h / 2; k++)
{
Complex u = y[k];
Complex t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1)
for (int i = 0; i < len; i++)
y[i].x /= len;
return;
}
int main()
{
int T,n,len1;
scanf("%d",&T);
while (T--)
{
memset(num, 0, sizeof(num));
scanf("%d", &n);
len1 = 1;
for (int i = 0; i < n; i++)
{
scanf("%d",&a[i]);
num[a[i]]++;
}
sort(a, a + n);
int len = a[n - 1]+1;
while (len1<2*len)
{
len1 <<= 1;
}
//cout << len1 << endl;
for (int i = 0; i < len; i++)
{
x1[i]=Complex(num[i], 0);
}
for (int i = len; i < len1; i++)
{
x1[i]=Complex(0, 0);
}
FFT(x1, len1, 1);
for (int i = 0; i < len1; i++)
{
x1[i] =x1[i]*x1[i];
}
FFT(x1, len1, -1);
for (int i = 1; i < len1; i++)
{
num[i] = ll(x1[i].x + 0.5);
}len1 = 2 * a[n - 1];
for (int i = 1; i < n; i++)
{
num[2 * a[i]]--;
}
for (int i = 0; i <=len1; i++)
num[i] /= 2;
num[0] = 0;
for (int i = 1; i <=len1; i++)
{
num[i] += num[i - 1];
}
ll sum = 0;
for (int i = 0; i < n; i++)
{
sum += num[len1] - num[a[i]];
}
// cout << sum << endl;
for (int i = 0; i < n; i++)
{
sum -= ll(n - i - 1)*i;
sum -= ll(n - 1 - i)*(n - i - 2) / 2;
}
sum -= 1ll*n*(n - 1);
// cout << sum << endl;
ll ant = 1ll * n*(n - 1)*(n - 2) / 6;
printf("%.7f\n", (double)sum / ant);
}
}
在这里插入代码片