問題
A, B, Kが与えられる。以下の条件を満たす非負整数(x, y)の組み合わせ数を求めよ。
- x < A
- y < B
- (x & y) < K
ただし、
1 <= A <= 10^9,
1 <= B <= 10^9,
1 <= K <= 10^9.
解法
rng..58さんの解法(LSBを固定して再帰)がとても綺麗で分かりやすかった。こんなにシンプルに書けるのかという感じ。正答者の多くはビット単位のDPで解いているようだった。何を状態数としてDPしているのか分からなかったのでじっくりと考えてみた。
簡単のため以下のような問題を考える。
正の正数Aが与えられる。x < Aを満たす非負整数xはいくつあるか?
答えはA個なのですが、この問題をわざわざビット単位で処理して解くことを考えてみます。
#include <iostream> using namespace std; int main() { long long A; cin >> A; int a[32] = {}; for (int i = 0; i < 32; i++) a[32-1-i] = A >> i & 1; for (int i = 0; i < 32; i++) cout << a[i] << " "; cout << endl; int dp[32+1][2] = {}; // dp[bit][LSBs are arbitrary or not?] dp[0][0] = 1; for (int k = 0; k < 32; k++) { // k : bit (notice that we process from MSB to LSB) for (int i = 0; i < 2; i++) { // i : the next bit is arbitrary or not? int jt = i == 1 ? 1 : a[k]; // jt : the maximum next bit for (int j = 0; j <= jt; j++) { // j : next bit dp[k+1][i | a[k] > j] += dp[k][i]; } } } cout << dp[32][1] << endl; // how many non-negative integers less than A are there? // "less than" means "not used up to A" -> LSBs are arbitrary. return 0; }MSBからLSBの方向に処理していきます。
状態数を(現在の桁, 次の桁以降を任意の値に出来るか?)とするのがポイントです。
ビットkからビット(k+1)への遷移を考えます。
もし、ビットkの時点で次の桁以降を任意に決められるのであれば、x[k+1]のビットは{0, 1}のどちらでも選べます。もし、ビットkの時点で次の桁以降を任意に決められないのであれば、x[k+1]のビットは、a[k+1]の値まで選べます。このx[k+1]のビットの最大値が上のソースコードのjtです。
jtが決まると、[0, jt]まででループを回して、x[k+1]のビットの値を決めます。
もし、ビットkの時点で次の桁以降の値を任意に決められるのであれば、k+1以降も同様に次の桁以降の値を任意に決めることが出来ます。
ビットkの時点で次の桁以降の値を任意に決められない場合は、
a[k+1] = 1、かつ、x[k+1] = 0
とした場合に次の桁以降の値を任意に決められることになります。
上のソースコードの dp[k+1][i | a[k] > j] += dp[k][i]; の部分がそれに対応します。
問題の解は、dp[32][1]になります。これはx < Aなので、ぎりぎり目一杯Aに寄っておらずもし次のビットが存在していれば、任意に決められるという意味で2つ目の添字は1にします。ちなみにdp[32][0] = 1です。これはa[k] = 1となるビットすべてでx[k] = 1としたことを意味しています。
状態数の取り方が特殊で難しいですが、おもしろいパターンのDPです。ここまで分かれば、元々の問題はこれの応用で解けます。
using namespace std; #define ALL(x) (x).begin(), (x).end() #define EACH(itr,c) for(__typeof((c).begin()) itr=(c).begin(); itr!=(c).end(); itr++) #define FOR(i,b,e) for (int i=(int)(b); i<(int)(e); i++) #define MP(x,y) make_pair(x,y) #define REP(i,n) for(int i=0; i<(int)(n); i++) void solve() { int A, B, K; cin >> A >> B >> K; int ar[32], br[32], kr[32]; // bit expression: ar[0] = MSB, ar[31] = LSB for (int i = 0; i < 32; i++) { ar[31-i] = A >> i & 1; br[31-i] = B >> i & 1; kr[31-i] = K >> i & 1; } long long dp[32+1][2][2][2] = {}; dp[0][0][0][0] = 1; for (int d = 0; d < 32; d++) { for (int i = 0; i < 2; i++) { // next ar bit is arbitrary? for (int j = 0; j < 2; j++) { // next br bit is arbitrary? for (int k = 0; k < 2; k++) { // next kr bit is arbitrary? if (dp[d][i][j][k] == 0) continue; int at = i == 1 ? 1 : ar[d]; int bt = j == 1 ? 1 : br[d]; for (int a = 0; a <= at; a++) { // next ar bit for (int b = 0; b <= bt; b++) { // next br bit int kk = a & b; // next kk bit if (k == 0 && kk > kr[d]) continue; dp[d+1][i | ar[d] > a][j | br[d] > b][k | kr[d] > kk] += dp[d][i][j][k]; } } } } } } cout << dp[32][1][1][1] << endl; } int main() { ios_base::sync_with_stdio(0); int T; cin >> T; REP (i, T) { cerr << "Case #" << i+1 << ": " << endl; cout << "Case #" << i+1 << ": "; solve(); } return 0; }
0 件のコメント:
コメントを投稿