整系数Discrete Fourier transform技巧

Discrete Fourier transform

向量的DFT 定义为:

其中是primitive root of unity of order ,即:。 通常取

Inverse transform

性质: (j模n为0时取1,否则取0)

因此

假设加减乘除时间复杂度为,朴素实现的时间复杂度为。 一个fast Fourier transform算法的时间复杂度为。 radix-2的fast Fourier transform算法最常见。

注意使用的是DFT用的系数的倒数。如果使用DFT实现IDFT可以先翻转向量:

1
2
3
4
5
6
reverse(a+1, a+n);

// compute DFT

for (long i = 0; i < n; i++)
a[i] *= 1.0/n;

Decimation in time FFT

Cooley-Tukey是一种radix-2分治算法,把偶数下标的子序列的Fourier transform与奇数下标的子序列的Fourier transform,用时间合并。

Decimation in time在变换之后对下标进行bitreverse操作。

Decimation in frequency FFT

Sande-Tukey是另一种radix-2分治算法,把前一半子序列的Fourier transform与后一半子序列的Fourier transform,用时间合并。

Decimation in frequency在变换之前对下标进行bitreverse操作。如果Decimation in time的DFT与Decimation in frequency的IDFT一起用,可以省略两处bitreverse操作。

Number theoretic transform

DFT可以用于复数以外的其他ring,常用于

使用128 bits模数需要高效的u64*u64%u64,其中模数是常数。

硬件除法指令(32 bits、64 bits)

DIV指令性能很低。

1
2
3
4
5
6
extern inline u64 mul_mod(u64 a, u64 b, u64 m)
{
u64 r;
asm("mulq %2\n\tdivq %3" : "=&d"(r), "+a"(a) : "rm"(b), "rm"(m) : "cc");
return r;
}

到AVX-512也没有提供把两个64 bits乘数的积放在一个128 bits寄存器的指令,GCC没有提供用乘法、移位等模拟的u128除以u64的常量除法。

64位mantissa浮点数(32 bits、64 bits)

当模数时可以用64位mantissa浮点数计算u64*u64%u64

由等式

两边模,得

即用u64乘法计算的低64位,减去的低64位。其中,可以用64位mantissa浮点数(Intel x87 80-bit double-extended precision)计算,再round成u64

round时若向上取整了,减数会大于被减数。若,可以根据差的符号位判断。

1
2
3
4
5
u64 mul_mod(u64 a, u64 b, u64 P)
{
u64 x = a*b, r = x - P*u64((long double)a*b/P+0.5);
return i64(r) < 0 ? r + P : r;
}

存储的倒数,用(long double)a*b*Q代替(long double)a*b/P能快些。此时会引入额外的误差,Matters Computational说适用于,原因不明。

编译器生成的常量除法(32 bits)

对于固定的模,GCC/llvm可以生成u64%u32的高效代码。llvm的lib/Transforms/Utils/IntegerDivision.cpp

Montgomery modular multiplication+Barret reduction(32 bits)

Faster arithmetic for number-theoretic transforms,用乘法和移位代替除法。

快速计算u64%u32,用乘法和移位代替除法。设为大于等于的整数,为负整数。

的估计值,若大于等于则减去的倍数。

因此

,则,即估计值最多小2,最多两个conditional move指令(if (r >= P) r -= P;)即可修正余数。

,则估计值最多小1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
extern inline u64 barrett_30_2(u64 a, u64 P, u64 M)
{ // 2^29 < P < 2^30, M = floor(2^61/P)
u64 r = a-((a>>28)*M>>33)*P;
if (r >= P) r -= P;
return r;
}

extern inline u64 barrett_30_1(u64 a, u64 P, u64 M)
{ // 2^29 < P < 2^30, M = floor(2^60/P)
u64 r = a-((a>>29)*M>>31)*P;
if (r >= P) r -= P;
if (r >= P) r -= P;
return r;
}

当模数为常量时,该算法不如编译器生成的常量除法。若模数不固定时可以考虑使用。

Cyclic convolution

两个长为的序列的cyclic convolution的长度也是。第项定义为:

Linear convolution

两个长为的序列的linear convolution的长度是

项总是0。多项式乘法是一种常见的linear convolution应用。

Zero pad原序列到长度后,计算cyclic convolution即可得到linear convolution。

范围和精度

向量卷积需要注意精度问题。

使用complex<double>计算convolution,需要保证结果每一项的实数部分在(Number.MIN_SAFE_VALUENumber.MAX_SAFE_INTEGER)范围内,是double能精确表示的最大整数。采取round half to even规则,均表示为,无法区分。

设每项系数的绝对值小于等于,那么convolution结果每一项绝对值小于等于,若则可放心使用complex<double>

complex<double>还要受到浮点运算误差影响。根据Roundoff Error Analysis of the Fast Fourier Transform,没仔细看,relative error均值为log2(n)*浮点运算精度*变换前系数最大值,对于结果,这个量达到就很可能出错。增长速度可以看作是,不如。因此通常不必担心浮点运算的误差。

对于模的number theoretic transform,,若则可放心使用。

1004535809 (=479*2**21+1), 998244353 (=119*2**23+1), 897581057 (=107*2**23+1),这三个数均小于,两倍不超过INT32_MAX(两个+-P之间的数加减不会超出int32_t表示范围),且可表示为,其中为2的幂,适合用作number theoretic transform的模。3是它们共同的原根。 另外可以选取880803841 (=105*2**23+1),26是一个原根。

系数取自的uniform distribution,则系数均值为,方差为。若把系数平移至,则系数均值,方差为。若其中之一independent and identically distributed,则方差会很小。可以用Chebyshev's inequality等估计系数绝对值,上界可以减小。即使不是independent and identically distributed,也可以用来计算,是independent and identically distributed uniform

下面考察模P多项式乘法,碰到精度问题时的两种应对方案:sqrt decomposition、Chinese remainder theorem。

方案0:sqrt decomposition(FFT, NTT)

适用于FFT与NTT。 取为接近的整数,分解,则:

适当选择可以使的系数小于等于,convolution结果系数最大值为,比原来的小。

求出后,计算等式右边四个convolution,带权相加即得到原convolution。

如上朴素方案需要4次长为的DFT、1次长为的inverse DFT。

一种优化方案是使用Toom-2 (Karatsuba)计算,可以减少为3次DFT、1次inverse DFT。

容易扩展到cube root decomposition等。对于number theoretic transform,分成份需要个DFT和个IDFT,不如用Chinese remainder theorem。

优化0:实部表示低位、虚部表示高位

FFT可以计算复数的DFT,但在朴素的多项式乘法中,FFT只作用于实数向量,虚数部分浪费了。 sqrt decomposition把一个系数拆分成两项,我们可以把两项装载到实数部分和虚数部分。 具体方法如下:

取正整数接近且是一个尽可能小的正整数。

分解。 考察如下的变换:

注意每一项绝对值的值域变为倍了,因此计算对精度有更高的要求。取较小的可以降低精度要求。

右边同余于。 提取虚部与实部,将虚部除以再乘以,再加上实部即得:

正是我们需要的形式。

这个优化需要2次长为的DFT、1次长为的inverse DFT。

优化1:正交计算两个实系数向量DFT(FFT)

这是另一种利用向量的虚数部分的方法。

的共轭的DFT可由的DFT求出:

换言之,计算一个复数向量的DFT,可以通过简单变换,得到实数部分向量的DFT与实数部分向量的DFT。买一送一。

分解后,根据上面的公式,用计算,同法计算。然后用计算出,同法计算出

需要2次长为的DFT、2次长为的inverse DFT。

奇偶项优化(FFT)

该优化可以和其他方式叠加。

看作多项式,同样地,看作多项式

偶次项, 奇次项,同样地,定义的系数为的系数为,令其长为,高位用填充。

用正交计算两个实系数向量DFT的方式,用2次长度为(之前都是 )的DFT计算, 循环右移1位的DFT的第项等于,因此根据的DFT的系数可以得到的DFT的系数。

构造长为的向量

的实部为结果的偶次项系数,虚部为结果的奇次项系数。

需要2次长为的DFT、1次长为的inverse DFT。

方案1:Chinese remainder theorem(NTT)

适用于NTT。 取个可用于number theoretic transform的质数,使,计算个NTT,之后用Chinese remainder theorem合并。

有至少两种算法。

经典算法(Gauss's algorithm)

Gauss之前也有很多人提出。

对于每个用Blankinship's algorithm计算

注意可能超出机器single-precision表示范围,该算法不适合求

Garner's algorithm

定义

满足前个方程,满足所有方程。

稍加变形可用于求

原来的每个是精确值,现在只有的结果,因此计算时之前的都不可复用,需要重新计算。时间复杂度上升为

测试

https://gist.github.com/MaskRay/fac2042058dd5d9e59953f18f3f3978a

NTT int使用小于的素数,编译器用乘法模拟常数除法。 NTT long,编译器无法优化常数除法,性能很低,使用浮点mul_mod会略快于128位除以64位的DIV指令。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
n	microseconds	algorithm

131072 3860 NTT dit2 int
131072 6104 FFT dif2
131072 6712 FFT dit2
131072 6912 NTT dif2 int
131072 6936 Montgomery+Barrett NTT dif2 int
131072 9592 NTT dif2 long non-constant P
131072 10122 NTT dit2 long
131072 13169 NTT dif2 int non-constant P
131072 15419 NTT dif2 long

262144 8993 NTT dif2 int
262144 9036 NTT dit2 int
262144 9670 Montgomery+Barrett NTT dif2 int
262144 15484 FFT dit2
262144 17601 FFT dif2
262144 19731 NTT dit2 long
262144 20527 NTT dif2 long non-constant P
262144 21910 NTT dif2 long
262144 29457 NTT dif2 int non-constant P

524288 18502 NTT dif2 int
524288 20110 Montgomery+Barrett NTT dif2 int
524288 23156 NTT dit2 int
524288 39890 FFT dif2
524288 39904 FFT dit2
524288 44145 NTT dif2 long non-constant P
524288 45038 NTT dit2 long
524288 46334 NTT dif2 long
524288 65265 NTT dif2 int non-constant P

1048576 43648 NTT dit2 int
1048576 45704 NTT dif2 int
1048576 46167 Montgomery+Barrett NTT dif2 int
1048576 104362 NTT dit2 long
1048576 107571 NTT dif2 long non-constant P
1048576 119743 FFT dif2
1048576 122029 NTT dif2 long
1048576 122174 FFT dit2
1048576 144370 NTT dif2 int non-constant P

2097152 122989 Montgomery+Barrett NTT dif2 int
2097152 137276 NTT dif2 int
2097152 143955 NTT dit2 int
2097152 293222 FFT dif2
2097152 338580 FFT dit2
2097152 352833 NTT dif2 int non-constant P
2097152 360372 NTT dif2 long non-constant P
2097152 422108 NTT dit2 long
2097152 423817 NTT dif2 long

4194304 455859 NTT dit2 int
4194304 467340 NTT dif2 int
4194304 490114 Montgomery+Barrett NTT dif2 int
4194304 779945 FFT dif2
4194304 839698 FFT dit2
4194304 904096 NTT dit2 long
4194304 956174 NTT dif2 long
4194304 969572 NTT dif2 long non-constant P
4194304 1074858 NTT dif2 int non-constant P

8388608 1052072 NTT dit2 int
8388608 1138089 NTT dif2 int
8388608 1189775 Montgomery+Barrett NTT dif2 int
8388608 1737166 FFT dif2
8388608 1839095 FFT dit2
8388608 2053195 NTT dif2 long
8388608 2072172 NTT dif2 long non-constant P
8388608 2186451 NTT dit2 long
8388608 2893584 NTT dif2 int non-constant P

花哨的Montgomery+Barrett不如常量除法的NTT int,好处是一份代码可以适用于多个模数,而NTT int得用template或其他方式为各个模数生成不同代码。

不受到Level 3 cache制约时,Montgomery NTT只需要NTT int 60%的时间,此时每次重新计算unit root代替lookup table会快些。

一般来说,decimation in frequency(Sande-Tukey,从较大的n计算到较小的n)优于decimation in time(Cooley-Tukey,从较小的n计算到较大的n),可能是因为decimation in frequency的butterfly数据依赖小些。

FFT有超过10%时间花在计算unit roots上,而NTT只有5%。考虑到FFT往往能正交计算两个序列,而NTT只能计算一个,且double有53位精度而NTT int只能使用以下的素数(当前代码只能处理以下的),FFT通常优于NTT。

References

感谢ftiasch老师教导。

uwi,https://www.hackerrank.com/rest/contests/w23/challenges/sasha-and-swaps-ii/hackers/uwi/download_solution:Modified Montgomery+Barrett变体+Garner's algorithm: