Page List

Search on the blog

2011年8月10日水曜日

サルでも分かるFFT(4)

前回「サルでも分かるFFT(3)」で、FFTのC++による実装を書きました。この実装は、O(n log n)で動くのですが、まだまだ遅いです。
  • 入力用と出力用バッファのデータ移動
  • 出力用バッファは毎回newしている
  • ビット逆転で毎回Nビットなめている
  • 同じWを何回も計算している
この辺の処理を工夫して、定数項を小さくします。

 まず、一つ目と二つ目は、in-placeに変換を行うことで解決できます。つまり、入力のデータを直接いじって入力用バッファに結果を書き込みます。三つ目は、ループを一つ増やしてinnermostなloopの外で計算することで解決します。(余談ですが、exp()とかhypot()は相当重い処理です。)
四つ目は、ビット逆転では、LSBとMSBが逆転しているということを利用すれば、効率的に書くことができます。(spaghetti sourceのを参考にしました。)

  1. #include <complex>  
  2.   
  3. typedef complex<double> comp;  
  4.   
  5. /* 
  6.  * sgn -1 : time to freq 
  7.  *      1 : freq to time 
  8.  */  
  9. void FFT(comp x[], int n, int sgn) {  
  10.     comp theta = comp(0, sgn*2*PI/n);  
  11.   
  12.     for (int m = n; m >= 2; m>>=1) {  
  13.         int mh = m >> 1;  
  14.   
  15.         for (int i = 0; i < mh; i++) {  
  16.             comp W = exp(1.0*i*theta);  
  17.   
  18.             for (int j = i; j < n; j += m) {  
  19.                 int k = j + mh;  
  20.                 comp x_t = x[j] - x[k];  
  21.                 x[j] += x[k];  
  22.                 x[k] = W * x_t;  
  23.             }  
  24.         }  
  25.         theta *= 2;  
  26.     }  
  27.   
  28.     int i = 0;  
  29.     for (int j = 1; j < n - 1; j++) {  
  30.         for (int k = n >> 1; k > (i ^= k); k >>= 1);  
  31.         if (j < i) swap(x[i], x[j]);  
  32.     }  
  33. }  


前回は、時間間引き(時間成分を偶数、奇数に分ける)のFFTを書きましたが、上のソースは周波数間引き(周波数成分を偶数、奇数に分ける)で書いています。どちらも同じように書けますが、バタフライ演算の向きが逆になります。

ビット逆転のところが興味深いですが、ビット演算を使わない書くと処理の意味が分かります。LSBとMSBを逆に考えて1をインクリメントしています。

  1. void bit_reverse(int x[], int N) {  
  2.     int r = 0;  
  3.   
  4.     for (int i = 1; i < N-1; i++) {  
  5.         int k = N/2;  
  6.   
  7.         while (r >= k) {  
  8.             r -= k;  
  9.             k /= 2;  
  10.         }  
  11.         r += k;  
  12.   
  13.         if (r > i)  
  14.             swap(x[i], x[r]);  
  15.     }  
  16. }  

さて、最終テーマであるFFTを用いた多倍長整数乗算の高速化です。
普通に多倍長整数の乗算をやると、O(n^2)の時間が必要です(nは桁数)。100万桁同士の掛け算をすると、30分~1時間程度かかるのではないかと思います。FFTを使うと、O(n log n)で乗算ができます。100万桁同士の掛け算でも数秒で終わる計算量です。

数学的な詳しい解説は、ここのサイトを参照。ここでは、大まかな処理だけをまとめます。
「文字列xとyが与えられる。z = x * yを求める。」とします。
  1. 上記のFFTが使用できるように信号長を2のべき乗とします。N >= max(len(x), len(y))となるような最小の2のべき乗Nを求めます。
  2. 文字列を信号列に変換します。x_t[]、y_t[]にそれぞれx, yのdigitsをLSBから詰めていきます。x_t[i], i=len(x), len(x)+1, ..., 2*N-1、および、y_t[i], i=len(y), len(y)+1, ...,2*N-1には0を入れます。
  3. x_t[]、y_t[]を長さ2*Nの信号としてフーリエ変換します。
  4. z_t[i] = x_t[i] * y_t[i], i = 0,1,..., 2*N-1とします。
  5. z_t[]を逆フーリエ変換します。
  6. z_t[]はx*yと一致しています。ただし、繰り上がりの処理を忘れずに。
以下ソース。SPOJのNot So Fast Multiplicationを解きました。

  1. char num1[1<<16], num2[1<<16];  
  2. comp xt[1<<16], yt[1<<16], zt[1<<16];  
  3. int carried[1<<16];  
  4. char ans[1<<16];  
  5.   
  6. /* 
  7.  * conversion from a string to a signal sequence 
  8.  */  
  9. void convert(const char x[], comp *ret, int len) {  
  10.     for (int i = 0; i < len; i++)  
  11.         ret[i] = x[i]-'0';  
  12.   
  13.     for (int i = 0; i < len; i++)  
  14.         ret[len+i] = 0;  
  15. }  
  16.   
  17. /* 
  18.  * conversion from a signal sequence to an integer 
  19.  */  
  20. void toInteger(int N) {  
  21.     for (int i = 0; i < N; i++)  
  22.         carried[i] = (int)(zt[i].real()/(N) + 0.5);  
  23.   
  24.     for (int i = 0; i < N; i++) {  
  25.         if (carried[i] >= 10) {  
  26.             carried[i+1] += carried[i] / 10;  
  27.             carried[i] %= 10;  
  28.         }  
  29.     }  
  30.   
  31.     int last = N-1;  
  32.     for (; last >= 0 && !carried[last]; last--)  
  33.         ;  
  34.   
  35.     for (int i = 0; i <= last; i++)  
  36.         ans[last-i] = (char)(carried[i]+'0');  
  37.     if (last == -1)  
  38.         ans[0] = '0';  
  39. }  
  40.   
  41. /* 
  42.  * O(n log n) multiplication 
  43.  */  
  44. void multiply(char x[], char y[]) {  
  45.     int len1 = strlen(x);  
  46.     int len2 = strlen(y);  
  47.   
  48.     int N = max(len1, len2);  
  49.   
  50.     int b = 1;  
  51.     while (b < N)  
  52.         b *= 2;  
  53.     N = b;  
  54.   
  55.     reverse(x, x+len1);  
  56.     reverse(y, y+len2);  
  57.   
  58.     for (int i = len1; i < N; i++)  
  59.         x[i] = '0';  
  60.     for (int i = len2; i < N; i++)  
  61.         y[i] = '0';  
  62.   
  63.     convert(x, xt, N);  
  64.     convert(y, yt, N);  
  65.   
  66.     FFT(xt, 2*N, -1);  
  67.     FFT(yt, 2*N, -1);  
  68.   
  69.     for (int i = 0; i < 2*N; i++)  
  70.         zt[i] = xt[i] * yt[i];  
  71.   
  72.     FFT(zt, 2*N, 1);  
  73.     toInteger(2*N);  
  74. }  
  75.   
  76. int main() {  
  77.     int n;  
  78.     scanf("%d", &n);  
  79.   
  80.     while (n--) {  
  81.         scanf("%s %s", num1, num2);  
  82.   
  83.         memset(ans, 0, sizeof(ans));  
  84.         multiply(num1, num2);  
  85.         printf("%s\n", ans);  
  86.     }  
  87.     return 0;  
  88.   
  89. }  

0 件のコメント:

コメントを投稿