[关闭]
@skyword 2016-10-04T17:01:00.000000Z 字数 10475 阅读 440

多项式计算--FFT与NTT学习小结

数学 数论


看了两天FFT
一直没有学fft,直到10天前打beijing online contest时,碰到了一个NTT的裸题,发现必须要看看了

hexo博客配置有一些问题,mathjax公式一直不能正常加载,可以访问我在cmdmarkdown上发布的版本 here
这几天看了一些博客,发现它纯粹在数学上需要功底。FFT,NTT的题目中,简单的那部分基本可以分为两个方面:建模列出表达式+套版
理解模板中那些纯粹属于FFT,NTT的做法并不难。难的是怎么用到题目中,这需要数学功底。。

我就不在这里从头写了,网上已经有了很好的博客来讲解FFT
首先,我决定尝试叙述一下逻辑关系:
我们现在要解决的,是高效的计算多项式乘法,传统朴素的方法是 的。
而FFT,即快速傅里叶变换,通过某种技巧,使复杂度降到 , 但是FFT的弱点在于,它的计算在复数域内进行,因此存在精度问题。
于是有了NTT,即快速数论变换。这一变换是对FFT的改进,使得所有计算在,即模p剩余系下考虑。因此,整数意义下的FFT问题可以用NTT来解决,从而避免精度误差。

  1. FFT

    贴出一些文章:

这两篇文章说的很清楚。傅里叶变换建立在一些基础的处理手法和引理上。
首先是多项式的点值表达,很好理解,一个n次一元多项式,与n个不同的点是相互确定的。这个可以从代数的角度上来理解。文章的推导也阐明了这点。
点值表达给我们带来的是,我们想要表述一个多项式,只需任取n个不同的点,就能唯一地表述它。我们想要求一个多项式,只要设法找到n个不同的该多项式经过的点,就可以求出其系数表达。
因此,我们接下来的工作就放在点值上来考虑。经验表明,只有真正想清楚了点值表达是什么以及它对下面的工作起到什么作用,才能更好地理解下面的推导。

教程里已经推导了,单纯的去取n个点,来作点值表达,是可以完成多项式相乘计算的。然而,普遍的时间复杂度还是 的。

到这时候,FFT才真正登场。
FFT的做法是,同样基于点值表达,但是选了一组很巧妙的点,即重单位复根


这组点的点值之间存在着巧妙的关系,使得我们可以用分治的思路来求解所有点值,复杂度降到
思想就是这些,剩下的就是引理,DFT&IDFT,Rader等等具体的一些东西了,上面的文章讲得很清楚了

  1. NTT
    多项式计算终极版(ACdreamer)
    关于gn的取法和原根的理解,这一篇说的比较好:
    原根与NTT
    正如刚才所说,FFT存在精度上的误差,有时候直接影响答案的正确输出。因此,针对整数意义上的问题,我们有NTT
    NTT是一个改进,我们选取一组新的点值

    其中为质数,是2的方幂,且满足

上面的文章证明了,具有同n重单位复根一样的那几条性质,因此可以代替n重单位复根来做计算。

所以要写NTT,只需要改动取 的部分即可。
我们需要的常数是大质数p,和它的原根g,p一般是取费马素数的形式,即:


常用的是 , 对应原根
p并不总是随便取的,有时候题目要求你取某个指定的模数p,如果模数p是费马素数,那么不用做任何改动,求出其原根,对应的做一些更改就好。
原根的求法和更多的一些知识:这里
(这个我还没仔细看。。目前做的题还很少。。)
如果是任意的质数,听说的做法是找两个小的费马素数来分别算,然后再用crt来合并,好吧我还没做这类题。。

目前想得到的一些自己对FFT,NTT的理解就是这样了,代码中有些细节的地方还需要我仔细琢磨,并没有搞得很透彻。

几道题目:
FFT入门题:
hdu1402
大数乘法,n位的大整数,实际上可以看成n次的多项式。两个大数的相乘,就可以理解成这两个多项式的乘法。在用FFT计算好乘法之外,我们额外做的就是进位和顺序调整了。

  1. #include <cstdio>
  2. #include <cstring>
  3. #include <iostream>
  4. #include <algorithm>
  5. #include <cmath>
  6. #include <vector>
  7. #include <utility>
  8. #include <string>
  9. #include <queue>
  10. //#define maxn 1024
  11. #define LL long long
  12. #define fp freopen("in.txt","r",stdin)
  13. #define fpo freopen("out9.txt","w",stdout)
  14. //#define judge
  15. using namespace std;
  16. const double eps = 1e-10;
  17. const double pi = acos(-1.0);
  18. const int maxn = 250050;
  19. const int INF = 0x3f3f3f3f;
  20. const double lg2 = log(2.0);
  21. const LL MOD = 1e9+7;
  22. #define MAX 5050
  23. struct Complex
  24. {
  25. double re,im;
  26. Complex(double re = 0.0, double im = 0.0)
  27. {
  28. this->re = re;
  29. this->im = im;
  30. }
  31. Complex operator -(const Complex &elem) const
  32. {
  33. return Complex(this->re - elem.re, this->im - elem.im);
  34. }
  35. Complex operator +(const Complex &elem) const
  36. {
  37. return Complex(this->re + elem.re, this->im + elem.im);
  38. }
  39. Complex operator *(const Complex &elem) const
  40. {
  41. return Complex(this->re * elem.re - this->im * elem.im, this->im*elem.re + this->re*elem.im);
  42. }
  43. void val(double re = 0.0, double im = 0.0)
  44. {
  45. this->re = re;
  46. this->im = im;
  47. }
  48. };
  49. Complex A[maxn],B[maxn];
  50. int res[maxn],len,multilen,len1,len2;
  51. char s1[maxn],s2[maxn];
  52. void swap(Complex &a, Complex &b)
  53. {
  54. Complex tmp = a;
  55. a = b, b=tmp;
  56. }
  57. void init()
  58. {
  59. len1 = strlen(s1),len2 = strlen(s2);
  60. multilen = max(len1,len2);
  61. len = 1;
  62. while(len<(multilen<<1)) len<<=1;
  63. for(int i = 0; i <len1; i++)
  64. A[i].val(s1[len1-i-1]-'0',0);
  65. for(int i = 0; i < len2; i++)
  66. B[i].val(s2[len2-i-1]-'0',0);
  67. for(int i = len1; i < len; i++)A[i].val();
  68. for(int i = len2; i < len; i++)B[i].val();
  69. }
  70. void rader(Complex y[])
  71. {
  72. for(int i = 1,j=len>>1,k; i < len-1; i++)
  73. {
  74. if(i<j) swap(y[i],y[j]);
  75. k = len>>1;
  76. while(j >= k)
  77. {
  78. j-=k;
  79. k>>=1;
  80. }
  81. if(j<k)j+=k;
  82. }
  83. }
  84. //op==1 DFT
  85. //op==-1 IDFT
  86. void FFT(Complex y[], int op)
  87. {
  88. rader(y);
  89. for(int h = 2; h <=len; h<<=1)
  90. {
  91. Complex wn(cos(op*2*pi/h),sin(op*2*pi/h));
  92. for(int i = 0; i < len; i +=h)
  93. {
  94. Complex w(1,0);
  95. for(int j = i; j < i+h/2; j++)
  96. {
  97. Complex u = y[j];
  98. Complex t = w * y[j + h/2];
  99. y[j] = u + t;
  100. y[j+h/2] = u - t;
  101. w = w *wn;
  102. }
  103. }
  104. }
  105. if(op==-1)// for IDFT
  106. {
  107. for(int i = 0; i < len ; i++)
  108. {
  109. y[i].re/=len;
  110. }
  111. }
  112. }
  113. void convolution(Complex *A, Complex *B)
  114. {
  115. FFT(A,1),FFT(B,1);
  116. for(int i = 0; i <len; i++)
  117. {
  118. A[i] = A[i] * B[i];
  119. }
  120. FFT(A,-1);
  121. for(int i = 0 ; i < len; i++)
  122. res[i] = (int)(A[i].re +0.5);
  123. }
  124. void adjust(int *arr)
  125. {
  126. for(int i = 0; i < len ; i++)
  127. {
  128. res[i+1]+=res[i]/10;
  129. res[i]%=10;
  130. }
  131. while(--len && res[len]==0);
  132. }
  133. void print(int *arr)
  134. {
  135. for(int i = len; i>=0;i--)
  136. {
  137. printf("%c",arr[i]+'0');
  138. }
  139. printf("\n");
  140. }
  141. int main()
  142. {
  143. while(gets(s1)&&gets(s2))
  144. {
  145. init();
  146. convolution(A,B);
  147. adjust(res);
  148. print(res);
  149. }
  150. }

51nod上有一道同样的题目,好像规模更大,我用fft和ntt做了两次,有趣的是NTT跑的比FFT快了不少。

hdu4609
n条线段,有各自的长度,任取三条,求能组成三角形的组合数目。
这个题就不那么裸了,用到FFT的地方是,用num[]数组来存任意取两条线段能组成长度为i的方案数,基于两边之和大于第三边的原则,枚举三角形中的最长边来计数。然后去掉几类重复的情况,个人觉得这是更考察功底的地方。。kuangbin的题解

  1. /*************************************************************************
  2. > File Name: hdu4609.cpp
  3. > Author: skyword
  4. > Mail: skywordsun@gmail.com
  5. > Created Time: 2016年10月02日 星期日 11时11分32秒
  6. ************************************************************************/
  7. #include <iostream>
  8. #include <cstdio>
  9. #include <cstring>
  10. #include <cmath>
  11. #include <vector>
  12. #include <set>
  13. #include <algorithm>
  14. #include <queue>
  15. using namespace std;
  16. const double pi = acos(-1.0);
  17. #define maxn 400050
  18. #define ll long long
  19. #define fp freopen("in.txt","r",stdin);
  20. struct Complex
  21. {
  22. double re,im;
  23. Complex(double re=0.0, double im = 0.0)
  24. {
  25. this->re = re; this->im = im;
  26. }
  27. Complex operator +(const Complex &b)
  28. {
  29. return Complex(this->re + b.re, this->im + b.im);
  30. }
  31. Complex operator -(const Complex &b)
  32. {
  33. return Complex(this->re - b.re , this->im - b.im);
  34. }
  35. Complex operator *(const Complex &b)
  36. {
  37. return Complex(this->re*b.re - this->im*b.im , this->re*b.im + this->im*b.re);
  38. }
  39. };
  40. void rader(Complex y[],int len)
  41. {
  42. int i,j,k;
  43. for(i = 1,j=len/2; i<len-1;i++)
  44. {
  45. if(i<j) swap(y[i],y[j]);
  46. k = len/2;
  47. while(j >= k)
  48. {
  49. j-=k;
  50. k/=2;
  51. }
  52. if(j<k) j += k;
  53. }
  54. }
  55. void fft(Complex y[], int len , int op)
  56. {
  57. rader(y,len);
  58. for(int h = 2; h <= len; h<<=1)
  59. {
  60. Complex wn(cos(-op*2*pi/h),sin(-op*2*pi/h));
  61. for(int j = 0; j < len; j+=h)
  62. {
  63. Complex w(1,0);
  64. for(int k = j; k < j+h/2; k++)
  65. {
  66. Complex u = y[k];
  67. Complex t = w*y[k+h/2];
  68. y[k] = u+t;
  69. y[k+h/2] = u-t;//butterfly op
  70. w = w* wn;
  71. }
  72. }
  73. }
  74. if(op == -1)
  75. {
  76. for(int i =0;i<len;i++)
  77. {
  78. y[i].re /= len;
  79. }
  80. }
  81. }
  82. Complex x[maxn];
  83. int a[maxn];
  84. ll num[maxn],sum[maxn];
  85. int t,n;
  86. int main()
  87. {
  88. scanf("%d",&t);
  89. while(t--)
  90. {
  91. scanf("%d",&n);
  92. memset(num,0,sizeof(num));
  93. for(int i = 0; i<n;i++)
  94. {
  95. scanf("%d",&a[i]);
  96. num[a[i]]++;
  97. }
  98. sort(a,a+n);
  99. int len1 = a[n-1]+1;
  100. int len = 1; // multilength
  101. while(len < 2*len1) len <<= 1;
  102. //init
  103. for(int i = 0; i<len1 ; i++) x[i] = Complex(num[i],0);
  104. for(int i = len1; i<len; i++) x[i] = Complex(0,0);
  105. fft(x,len,1);
  106. for(int i= 0; i<len;i++) x[i] = x[i] * x[i];
  107. fft(x,len,-1);
  108. for(int i = 0; i < len; i++)
  109. {
  110. num[i] = (ll)(x[i].re + 0.5);
  111. }
  112. len = 2*a[n-1];
  113. for(int i =0; i<n;i++)
  114. num[a[i]*2]--;
  115. for(int i = 0; i <=len;i++) num[i]/=2;
  116. sum[0] = 0;
  117. for(int i = 1; i<=len;i++) sum[i] = sum[i-1] + num[i];
  118. ll cnt = 0;
  119. for(int i = 0; i<n;i++)
  120. {
  121. cnt += (sum[len]-sum[a[i]]);
  122. cnt -= (ll)(n-1-i)*i;
  123. cnt -= (n-1);
  124. cnt -= (ll)(n-1-i)*(n-2-i)/2;
  125. }
  126. ll all = (ll)n*(n-1)*(n-2)/6;
  127. double ans = (double)cnt/all;
  128. printf("%.7lf\n",ans);
  129. }
  130. }

hdu5829
这是今年多校第八场的1009.
题意就叙述的很拗口。。我真的感觉有些出题人该好好润色一下英语表达了。。
题目给了一个n元素的数集,考虑任意的非空子集,考虑其前k大元素(如果子集本身元素数目小于k,就指的是其中全部元素)之和。对每一个给定的k,对所有子集,求sum的总和,即

懒得写了。。网上有题解。。
这个题用NTT做蛮好。更难的点是如何处理出卷积的形式。我自己琢磨了挺久的。。什么时候单独写个题解好了。。挺考验熟练度和功底

  1. #include <cstdio>
  2. #include <cstring>
  3. #include <iostream>
  4. #include <algorithm>
  5. #include <cmath>
  6. #include <vector>
  7. #include <utility>
  8. #include <string>
  9. #include <queue>
  10. //#define maxn 1024
  11. #define LL long long
  12. #define fp freopen("in.txt","r",stdin)
  13. #define fpo freopen("out9.txt","w",stdout)
  14. //#define judge
  15. using namespace std;
  16. const double eps = 1e-10;
  17. const double pi = acos(-1.0);
  18. const int maxn = (1e5+20);
  19. const int INF = 0x3f3f3f3f;
  20. const int p = 998244353;
  21. const int G = 3;
  22. const double lg2 = log(2.0);
  23. const LL MOD = 1e9+7;
  24. #define MAX 5050
  25. LL A[maxn<<2],B[maxn<<2];
  26. LL quick_mod(LL a, LL b, LL m)
  27. {
  28. LL ans = 1;
  29. while(b)
  30. {
  31. if(b&1) ans = ans * a % m;
  32. a = a * a%m;
  33. b >>= 1;
  34. }
  35. return ans;
  36. }
  37. void rader(LL y[], int len)
  38. {
  39. for(int i = 1,j=len/2; i < len-1; i++)
  40. {
  41. if(i < j) swap(y[i], y[j]);
  42. int k = len/2;
  43. while(j >= k)
  44. {
  45. j-=k;
  46. k/=2;
  47. }
  48. if(j < k)j+=k;
  49. }
  50. }
  51. void NTT(LL y[], int len , int op)
  52. {
  53. rader(y,len);
  54. for(int h = 2; h <= len; h <<=1)
  55. {
  56. LL wn = quick_mod(G,(p-1)/h,p);
  57. if(op == -1)
  58. {
  59. wn = quick_mod(wn,p-2,p);
  60. }
  61. // now wn is the rotation factor.
  62. for(int j = 0; j <len; j+=h)
  63. {
  64. LL w = 1;
  65. for(int k = j; k < j + h/2; k++)
  66. {
  67. LL u = y[k];
  68. LL t = (w * y[k + h/2])%p;
  69. y[k] = (u + t)%p;
  70. y[k + h/2] = (u - t + p)%p;
  71. w = w * wn % p;
  72. }
  73. }
  74. }
  75. // for IDFT(or maybe we call it IFNT)
  76. if(op==-1)
  77. {
  78. LL inv = quick_mod(len , p-2, p);
  79. for(int i = 0; i <len; i++)
  80. y[i] = y[i] * inv % p;
  81. }
  82. }
  83. int n,t,a[maxn],ans[maxn];
  84. LL fac[maxn],tfac[maxn],inv_fac[maxn],inv_tfac[maxn];
  85. void init()
  86. {
  87. fac[0] = tfac[0] = inv_fac[0] = inv_tfac[0] =1;
  88. for(int i = 1; i < maxn ; i++)
  89. {
  90. fac[i] = fac[i-1] * i % p;
  91. tfac[i] = 2 * tfac[i-1] %p;
  92. inv_fac[i] = quick_mod(fac[i], p - 2, p);
  93. inv_tfac[i] = quick_mod(tfac[i], p - 2, p);
  94. }
  95. }
  96. int main()
  97. {
  98. init();
  99. scanf("%d",&t);
  100. while(t--)
  101. {
  102. scanf("%d",&n);
  103. for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
  104. sort(a+1, a+1+n,greater<int>());
  105. int len = 1;
  106. while(len < ((n<<1)+1) ) len <<= 1;
  107. for(int i = 0; i < len ; i++)
  108. {
  109. if(i <= n)
  110. A[i] = tfac[n-i] * inv_fac[i] % p;
  111. else A[i] = 0;
  112. if(i <= n && i >= 1)
  113. B[i] = a[i] * fac[i-1] % p;
  114. else B[i] = 0;
  115. }
  116. reverse(B+1,B+1+n);
  117. NTT(A,len,1);
  118. NTT(B,len,1);
  119. for(int i = 0; i<len;i++)
  120. {
  121. A[i] = A[i] * B[i] % p;
  122. }
  123. NTT(A, len ,-1);
  124. for(int i = 1; i<=n; i++)
  125. {
  126. ans[i] =((inv_tfac[i] * inv_fac[i-1])%p) * A[n-i+1] %p;
  127. }
  128. for(int i = 1; i <=n;i++)
  129. {
  130. ans[i] = (ans[i] + ans[i-1]) % p;
  131. }
  132. for(int i = 1; i <=n;i++)
  133. {
  134. printf("%d ",ans[i]);
  135. }
  136. puts("");
  137. }
  138. }

hihocoder1388 : Periodic Signal
这题是今年北京网络赛的F题
经过处理可以知道,核心在于计算


处理一下发现,这是个循环卷积。同样是经过一些变换,得到可以用NTT计算的形式,这个。。以后单独写题解吧。。
目前我查到的做法其实挺多的。。①单纯用FFT做的话,会损失精度,但是只是个别低位处出错,算出的k值依然是正确的,所以耍个赖,用FFT算出最优解的k值,然后手动算一下就行 ②NTT,不会丢精度,一个难点是找足够大的质数p,这里找到的是 ,我不知道他们怎么找到这个数的。。试出来的罢。
它的原根(之一)是3

  1. #include <cstdio>
  2. #include <cstring>
  3. #include <iostream>
  4. #include <algorithm>
  5. #include <cmath>
  6. #include <vector>
  7. #include <utility>
  8. #include <string>
  9. #include <queue>
  10. //#define maxn 1024
  11. #define LL long long
  12. #define fp freopen("in.txt","r",stdin)
  13. #define fpo freopen("out9.txt","w",stdout)
  14. //#define judge
  15. using namespace std;
  16. const double eps = 1e-10;
  17. const double pi = acos(-1.0);
  18. const int maxn = 200020;
  19. const int INF = 0x3f3f3f3f;
  20. const LL p = 180143985094819841LL;
  21. const int G = 3;
  22. const double lg2 = log(2.0);
  23. const LL MOD = 1e9+7;
  24. #define MAX 5050
  25. LL wn[20];
  26. LL mul(LL x,LL y)
  27. {
  28. return (x*y-(LL)(x / (long double)p*y+1e-3)*p +p)%p;
  29. }
  30. LL quick_mod(LL a, LL b, LL m)
  31. {
  32. LL ans = 1;
  33. while(b)
  34. {
  35. if(b&1) ans = mul(ans , a );
  36. a = mul(a , a);
  37. b >>= 1;
  38. }
  39. return ans;
  40. }
  41. void getwn()
  42. {
  43. for(int i = 1; i <=18;i++)
  44. {
  45. int t = 1<<i;
  46. wn[i] = quick_mod(G,(p-1)/t,p);
  47. }
  48. }
  49. void rader(LL y[], int len)
  50. {
  51. for(int i = 1,j=len/2; i < len-1; i++)
  52. {
  53. if(i < j) swap(y[i], y[j]);
  54. int k = len/2;
  55. while(j >= k)
  56. {
  57. j-=k;
  58. k/=2;
  59. }
  60. if(j < k)j+=k;
  61. }
  62. }
  63. void NTT(LL y[], int len , int op)
  64. {
  65. rader(y,len);
  66. int id = 0;
  67. for(int h = 2; h <= len; h <<=1)
  68. {
  69. id++;
  70. // now wn is the rotation factor.
  71. for(int j = 0; j <len; j+=h)
  72. {
  73. LL w = 1;
  74. for(int k = j; k < j + h/2; k++)
  75. {
  76. LL u = y[k];
  77. LL t = mul(y[k + h/2], w);
  78. y[k] = (u + t)%p;
  79. y[k + h/2] = (u - t + p)%p;
  80. w = mul(w , wn[id]);
  81. }
  82. }
  83. }
  84. // for IDFT(or maybe we call it IFNT)
  85. if(op==-1)
  86. {
  87. for(int i = 1; i < len/2; i++)
  88. swap(y[i], y[len-i]);
  89. LL inv = quick_mod(len , p-2, p);
  90. for(int i = 0; i <len; i++)
  91. y[i] = mul(y[i] , inv );
  92. }
  93. }
  94. int t,n;
  95. LL a[60060],b[60060];
  96. LL A[maxn],B[maxn],C[maxn],ans;
  97. LL sum = 0;
  98. void init()
  99. {
  100. for(int i = 0; i < maxn;i++)
  101. {
  102. if(i<60060)
  103. A[i]=B[i]=C[i]=a[i]=b[i]=0;
  104. else A[i]=B[i]=C[i]=0;
  105. }
  106. }
  107. int main()
  108. {
  109. getwn();
  110. scanf("%d",&t);
  111. while(t--)
  112. {
  113. init();
  114. sum = 0;
  115. scanf("%d",&n);
  116. for(int i = 0; i < n; i++)
  117. {
  118. scanf("%lld",&a[i]);
  119. sum += a[i]*a[i];
  120. }
  121. for(int i = 0; i < n; i++)
  122. {
  123. scanf("%lld",&b[i]);
  124. sum += b[i]*b[i];
  125. }
  126. int len = 1;
  127. while(len < (n<<1) ) len<<=1;
  128. for(int i = 0; i < n ; i++)
  129. {
  130. A[i] = a[i];
  131. }
  132. for(int i = 0; i < n; i++)
  133. {
  134. B[i] = b[n-1-i];
  135. }
  136. NTT(A,len,1);
  137. NTT(B,len,1);
  138. for(int i = 0; i < len; i++)
  139. {
  140. C[i] = mul(A[i],B[i]);
  141. }
  142. NTT(C,len,-1);
  143. ans = C[n-1];
  144. for(int i = 0 ; i < n-2; i++)
  145. {
  146. ans = max(ans, C[i]+C[i+n]);
  147. }
  148. //cout<<"**"<<sum<<endl;
  149. sum -= (2LL*ans);
  150. printf("%lld\n",sum);
  151. }
  152. }

就先写这么多吧。。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注