Search on the blog

2013年12月9日月曜日

集合を扱うときはbitsetが便利

まえがき
プログラムで集合を扱うときはbitmaskを使うと便利なことがあります。
  • ビットが立っていれば要素は集合内に含まれる。
  • ビットが立っていなければ要素は集合内に含まれない。
  • 2つの集合の積集合はandで計算できる。
  • 2つの集合の和集合はorで計算できる。
  • 2つの集合の補集合はnotで計算できる。
など。

 C++の場合はintを使うと要素数が32以下の集合を、long longを使うと要素数が64以下の集合を扱うことができます。

bitsetクラスの使い方
 要素数が64より大きな場合は、bitsetクラスを利用するとよいです。
  • bitを立てるときはsetを使う。
  • bitが立っているかチェックするときはtestを使う。
さらに論理和、論理積、否定、排他的論理和の演算子がオーバーロードされていて、ふつうの整数と同じように扱えます。なんという便利さだ!
#include <bitset>
#include <iostream>

using namespace std;

int main(int argc, char **argv) {
    bitset<16> a, b;

    for (int i = 0; i < 16; i++) {
        if (i % 2)
            a.set(i);
        else
            b.set(i);
    }

    cout << a << endl;            // 1010101010101010
    cout << b << endl;            // 0101010101010101

    cout << a.test(0) << endl;    // 0
    cout << a.test(1) << endl;    // 1
    cout << (~a) << endl;         // 0101010101010101
    cout << (a & b) << endl;      // 0000000000000000
    cout << (a | b) << endl;      // 1111111111111111
    cout << (a ^ b) << endl;      // 1111111111111111

    return 0;
}
 
例題
Facebook Hacker Cup 2014 Round 1「Preventing Alzheimer's」のEgorさんの解答がbitsetをうまく使っています。

 実際の問題は少し複雑なので、要点のみを抜粋した以下の問題を解いてみます。
[2, 100]の整数がn個(2 <= N <= 100)与えられる。これらの数字のいずれとも互いに素となり、かつ、いずれの数字よりも大きい最小の整数x(2 <= x <= 100)を求めて出力せよ。そのような整数が存在しない場合は、-1を出力せよ。

 各整数についてそれと互いに素な整数のみビットを立てるという前処理を行うことで、すっきりとした実装ができます。
#include <bitset>
#include <iostream>
#include <cassert>

using namespace std;

const int MAX_V = 100+1;
bitset<MAX_V> coprime[MAX_V];

int gcd(int a, int b) {
    return b == 0 ? a : gcd(b, a%b);    
}

void init() {
    for (int i = 1; i < MAX_V; i++) {
        for (int j = 1; j < MAX_V; j++) {
            if (gcd(i, j) == 1)
                coprime[i].set(j);
        }
    }
}

int solve() {
    int N;
    cin >> N;

    bitset<MAX_V> bs;
    bs = ~bs;

    int maxv = 0;
    for (int i = 0; i < N; i++) {
        int x;
        cin >> x;
        maxv = max(maxv, x);
        bs &= coprime[x];
    }

    for (int i = maxv+1; i < MAX_V; i++) {
        if (bs.test(i))
            return i;
    }

    return -1;
}

int main(int argc, char **argv) {
    init();

    int T;      // # of test cases
    cin >> T;
    
    for (int i = 0; i < T; i++)
        cout << solve() << endl;

    return 0;
}

0 件のコメント:

コメントを投稿