- 入力用と出力用バッファのデータ移動
- 出力用バッファは毎回newしている
- ビット逆転で毎回Nビットなめている
- 同じWを何回も計算している
この辺の処理を工夫して、定数項を小さくします。
まず、一つ目と二つ目は、in-placeに変換を行うことで解決できます。つまり、入力のデータを直接いじって入力用バッファに結果を書き込みます。三つ目は、ループを一つ増やしてinnermostなloopの外で計算することで解決します。(余談ですが、exp()とかhypot()は相当重い処理です。)
四つ目は、ビット逆転では、LSBとMSBが逆転しているということを利用すれば、効率的に書くことができます。(spaghetti sourceのを参考にしました。)
- #include <complex>
- typedef complex<double> comp;
- /*
- * sgn -1 : time to freq
- * 1 : freq to time
- */
- void FFT(comp x[], int n, int sgn) {
- comp theta = comp(0, sgn*2*PI/n);
- for (int m = n; m >= 2; m>>=1) {
- int mh = m >> 1;
- for (int i = 0; i < mh; i++) {
- comp W = exp(1.0*i*theta);
- for (int j = i; j < n; j += m) {
- int k = j + mh;
- comp x_t = x[j] - x[k];
- x[j] += x[k];
- x[k] = W * x_t;
- }
- }
- theta *= 2;
- }
- int i = 0;
- for (int j = 1; j < n - 1; j++) {
- for (int k = n >> 1; k > (i ^= k); k >>= 1);
- if (j < i) swap(x[i], x[j]);
- }
- }
前回は、時間間引き(時間成分を偶数、奇数に分ける)のFFTを書きましたが、上のソースは周波数間引き(周波数成分を偶数、奇数に分ける)で書いています。どちらも同じように書けますが、バタフライ演算の向きが逆になります。
ビット逆転のところが興味深いですが、ビット演算を使わない書くと処理の意味が分かります。LSBとMSBを逆に考えて1をインクリメントしています。
- void bit_reverse(int x[], int N) {
- int r = 0;
- for (int i = 1; i < N-1; i++) {
- int k = N/2;
- while (r >= k) {
- r -= k;
- k /= 2;
- }
- r += k;
- if (r > i)
- swap(x[i], x[r]);
- }
- }
さて、最終テーマであるFFTを用いた多倍長整数乗算の高速化です。
普通に多倍長整数の乗算をやると、O(n^2)の時間が必要です(nは桁数)。100万桁同士の掛け算をすると、30分~1時間程度かかるのではないかと思います。FFTを使うと、O(n log n)で乗算ができます。100万桁同士の掛け算でも数秒で終わる計算量です。
数学的な詳しい解説は、ここのサイトを参照。ここでは、大まかな処理だけをまとめます。
「文字列xとyが与えられる。z = x * yを求める。」とします。
- 上記のFFTが使用できるように信号長を2のべき乗とします。N >= max(len(x), len(y))となるような最小の2のべき乗Nを求めます。
- 文字列を信号列に変換します。x_t[]、y_t[]にそれぞれx, yのdigitsをLSBから詰めていきます。x_t[i], i=len(x), len(x)+1, ..., 2*N-1、および、y_t[i], i=len(y), len(y)+1, ...,2*N-1には0を入れます。
- x_t[]、y_t[]を長さ2*Nの信号としてフーリエ変換します。
- z_t[i] = x_t[i] * y_t[i], i = 0,1,..., 2*N-1とします。
- z_t[]を逆フーリエ変換します。
- z_t[]はx*yと一致しています。ただし、繰り上がりの処理を忘れずに。
以下ソース。SPOJのNot So Fast Multiplicationを解きました。
- char num1[1<<16], num2[1<<16];
- comp xt[1<<16], yt[1<<16], zt[1<<16];
- int carried[1<<16];
- char ans[1<<16];
- /*
- * conversion from a string to a signal sequence
- */
- void convert(const char x[], comp *ret, int len) {
- for (int i = 0; i < len; i++)
- ret[i] = x[i]-'0';
- for (int i = 0; i < len; i++)
- ret[len+i] = 0;
- }
- /*
- * conversion from a signal sequence to an integer
- */
- void toInteger(int N) {
- for (int i = 0; i < N; i++)
- carried[i] = (int)(zt[i].real()/(N) + 0.5);
- for (int i = 0; i < N; i++) {
- if (carried[i] >= 10) {
- carried[i+1] += carried[i] / 10;
- carried[i] %= 10;
- }
- }
- int last = N-1;
- for (; last >= 0 && !carried[last]; last--)
- ;
- for (int i = 0; i <= last; i++)
- ans[last-i] = (char)(carried[i]+'0');
- if (last == -1)
- ans[0] = '0';
- }
- /*
- * O(n log n) multiplication
- */
- void multiply(char x[], char y[]) {
- int len1 = strlen(x);
- int len2 = strlen(y);
- int N = max(len1, len2);
- int b = 1;
- while (b < N)
- b *= 2;
- N = b;
- reverse(x, x+len1);
- reverse(y, y+len2);
- for (int i = len1; i < N; i++)
- x[i] = '0';
- for (int i = len2; i < N; i++)
- y[i] = '0';
- convert(x, xt, N);
- convert(y, yt, N);
- FFT(xt, 2*N, -1);
- FFT(yt, 2*N, -1);
- for (int i = 0; i < 2*N; i++)
- zt[i] = xt[i] * yt[i];
- FFT(zt, 2*N, 1);
- toInteger(2*N);
- }
- int main() {
- int n;
- scanf("%d", &n);
- while (n--) {
- scanf("%s %s", num1, num2);
- memset(ans, 0, sizeof(ans));
- multiply(num1, num2);
- printf("%s\n", ans);
- }
- return 0;
- }
0 件のコメント:
コメントを投稿