問題概要
素数pが与えられる。サイズnとサイズmの整数列A, Bを考える。但しA, Bは以下のように決まる。
A[0] = A1
A[1] = A2
A[i] = (A3 * A[i-2] + A4 * A[i-1] * A5) % p, i = 2,3,..,n-1
B[0] = B1
B[1] = B2
B[i] = (B3 * A[i-2] + B4 * A[i-1] * B5) % p, i = 2,3,..,m-1
(A[i] * B[j]) mod p < lとなるような(i, j)の組の個数を求めよ。
2 ≤ P < 250,000
1 ≤ L ≤ P
2 ≤ N, M ≤ 10,000,000
0 ≤ A1, A2, A3, A4, A5, B1, B2, B3, B4, B5 < P
方針
ここのサイトがとても分かりやすい。一応日本語で概要だけまとめておく。
素数pには必ず原始根が存在する。以下では、素数pの原始根をgと記す。
x * y (mod p) = z
となるようなx, yを考えよ。という問題は、x, y, zをgのべき乗で表してやると(つまりg
x' * g
y' = g
z'と表してやると)、
x' + y' (mod p) = z'
となるような(x', y')を考えよという問題に読み変えることができる。
よって、まず原始根を求めて[1, p-1]までの値 -> それをgのべき乗で表した場合べきの数はいくらか?というlookupテーブルを作成する。原始根の計算は愚直にやると、p-2個の候補のp-1個のべき乗を調べることになるので、O(p
2)くらいかかる(ただし、2が原始根になる場合もあるので、Ω(p))。エラトステネスの篩 + 位数がφ(p) = p-1の約数になること + 繰り返し二乗法を利用すれば、(p-1)を素因数で割ったものだけを調べることにより、O(p * log
2(p))くらいで計算できそうだったが、予め最小の原始根を調べてみると、せいぜい40程度だったので、愚直な実装のままでいくことにした。
次に、実際にA, Bを生成しながら、どのべき数がどれだけ数列内に含まれているかという情報を保存していく。Aに含まれるべき数がiのものの個数をw[i], Bに含まれるべき数がjのものの個数をv[j]とする。
このとき、Aの要素とBの要素を掛け合わせた結果のべき数がiとなるようなもの組み合わせの数は、
のように表すことができる。これは畳みこみ和の形をしているので、カラツバ法もしくはFFTを用いて高速に計算することができる。今回はFFTで計算した。あとは、g
iがlより小さいかどうか判定して、小さければu[i]を答えに足していくだけ。また、ゼロは原始根のべき乗で表せないので例外的に処理をした。
原始根って何?原始根は知ってるけど、何ですべての素数に原始根が存在するの?という人には、以下の本がお勧めです。
ソースコード
ジャッジ用のテストデータがダウンロードできなかったため、完全な正誤判定はできていない。(ランダムに比較的小さなサイズのテストケースを生成して愚直な解法と比べて検算した。また、大きなテストケースも生成して許容時間内(6分内)にすべてのテストケースを計算できることを確認した。)誰かローカルにテストデータ持ってたら、ください。。
#include <iostream>
#include <vector>
#include <algorithm>
#include <complex>
using namespace std;
/*
* get a primitive root
*/
int getPrimitiveRoot(int p) {
for (int i = 1; i < p; i++) {
long long x = i;
bool ck = true;
for (int j = 1; j < p - 1; j++) {
if (x == 1) {
ck = false;
break;
}
x = x * i % p;
}
if (ck)
return i;
}
return -1;
}
/*
* Fast Fourier Transform
*/
typedef complex<long double> Complex;
const int TIME_TO_FREQ = -1;
const int FREQ_TO_TIME = 1;
void fft(Complex x[], int N, int sgn) {
Complex theta = Complex(0, sgn * 2 * M_PI / N);
for (int m = N; m >= 2; m >>= 1) {
int mh = m >> 1;
for (int i = 0; i < mh; i++) {
Complex w = exp((long double)i * theta);
for (int j = i; j < N; j += m) {
int k = j + mh;
Complex tmp = x[j] - x[k];
x[j] += x[k];
x[k] = w * tmp;
}
}
theta *= 2;
}
int i = 0;
for (int j = 1; j < N - 1; j++) {
for (int k = N >> 1; k > (i ^= k); k >>= 1) ;
if (j < i) swap(x[i], x[j]);
}
if (sgn == FREQ_TO_TIME) for (int i = 0; i < N; i++) x[i] /= N;
}
/*
* main solver
*/
const int MAX_SEQ_SIZE = 10000000;
const int MAX_PRIME_NUM = 1<<18;
long long a[MAX_SEQ_SIZE];
long long b[MAX_SEQ_SIZE];
int discreteLog[MAX_PRIME_NUM];
int powVal[MAX_PRIME_NUM];
Complex w[2*MAX_PRIME_NUM];
Complex v[2*MAX_PRIME_NUM];
Complex u[2*MAX_PRIME_NUM];
long long solve(int p, int l, int n, int m) {
// calculate a primitive root on modular p
// and make a look-up table (n, discrete log(n))
int g = getPrimitiveRoot(p);
long long x = 1;
for (int i = 0; i < p - 1; i++) {
discreteLog[x] = i;
powVal[i] = x;
x = x * g % p;
}
// calculate convolutions with FFTs
int N = 1;
while (N <= p)
N *= 2;
N *= 2;
for (int i = 0; i < N; i++) {
w[i] = 0;
v[i] = 0;
}
long long ret = 0;
long long z1 = 0;
long long z2 = 0;
for (int i = 0; i < n; i++) {
if (!a[i]) {
ret += m;
++z1;
}
else
w[discreteLog[a[i]]].real()++;
}
for (int i = 0; i < m; i++) {
if (!b[i]) {
ret += n;
++z2;
}
else
v[discreteLog[b[i]]].real()++;
}
ret -= z1 * z2;
fft(w, N, TIME_TO_FREQ);
fft(v, N, TIME_TO_FREQ);
for (int i = 0; i < N; i++)
u[i] = w[i] * v[i];
fft(u, N, FREQ_TO_TIME);
for (int i = 0; i < p-1; i++) {
if (powVal[i] >= l)
continue;
if (i == p-2)
ret += (long long)(u[i].real() + 0.5);
else
ret += (long long)(u[i].real() + u[i+p-1].real() + 0.5);
}
return ret;
}
/*
* main (for input and output only)
*/
#include <float.h>
int main() {
int T;
cin >> T;
for (int t = 0; t < T; t++) {
cerr << "Case #" << t+1 << " solving....." << endl;
int p, l;
cin >> p >> l;
// generate sequence A
int n, a3, a4, a5;
cin >> n >> a[0] >> a[1] >> a3 >> a4 >> a5;
for (int i = 2; i < n; i++)
a[i] = (a[i - 2] * a3 + a[i - 1] * a4 + a5) % p;
// generate sequence B
int m, b3, b4, b5;
cin >> m >> b[0] >> b[1] >> b3 >> b4 >> b5;
for (int i = 2; i < m; i++)
b[i] = (b[i - 2] * b3 + b[i - 1] * b4 + b5) % p;
long long ret = solve(p, l, n, m);
cout << "Case #" << t+1 << ": " << ret << endl;
}
return 0;
}