Page List

Search on the blog

2013年6月26日水曜日

C++でk-d tree実装

 C++でk-d treeを実装してみた。k-d treeのアイディアはシンプルで簡単に言うと二分探索木のk次元空間拡張バージョン。もうちょっと詳しく言うと、

  1. 節点の高さごとに軸を選んで、軸と交わる超平面でk次元空間を二つに分割する。
  2. 1.で選んだ超平面より左にあるデータは左部分木に、右にあるデータは右部分木に格納していく。
  3. データ集合の中央値を通るように1.の超平面を選ぶことで、平衡二分探索木にすることができる。

これを使うと、nearest neighborで最近傍を見つけるときの計算を高速化することができる。
詳しく知りたいひとは、ここ


問題
とりあえずソースコードを書く前にk-d treeを用いると高速に解くことができそうな問題を考えてみた。

『二次元平面上にN個の点が与えられる。その後L個の点が与えられる。L個の点それぞれについて、先に与えられたN個の点のうちどれに最も近いのかを求めよ。』

straightforwardな解法
普通に書くとこうです。
#include <iostream>
#include <vector>

using namespace std;

int N, L;

inline double dist(double x1, double y1, double x2, double y2) {
    return (x1-x2) * (x1-x2) + (y1-y2) * (y1-y2);
}

int main() {
    cin >> N;
    vector<double> xs(N), ys(N);
    for (int i = 0; i < N; i++)
        cin >> xs[i] >> ys[i];
    cin >> L;
    for (int i = 0; i < L; i++) {
        double x, y;
        cin >> x >> y;
        double best = 1e100;
        int nearest = -1;
        for (int j = 0; j < N; j++) {
            if (dist(x, y, xs[j], ys[j]) < best) {
                best = dist(x, y, xs[j], ys[j]);
                nearest = j;
            }
        }
        cout << nearest << endl;
    }

    return 0;
}
k-d treeを用いた解法
k-d treeを使うとこんな感じ。
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int N, L;

inline double dist(double x1, double y1, double x2, double y2) {
    return (x1-x2) * (x1-x2) + (y1-y2) * (y1-y2);
}

struct data {
    int index;
    vector<double> v;
    data(int _d) : v(_d) {}
    data() {}
};

struct vertex {
    data val;
    vertex *left;
    vertex *right;
};

class axisSorter {
    int k;
public:
    axisSorter(int _k) : k(_k) {}
    double operator()(const data &a, const data &b) {
        return a.v[k] < b.v[k];
    }
};

vertex *makeKDTree(vector<data> &xs, int l, int r, int depth) {
    if (l >= r)
        return NULL;

    sort(xs.begin() + l, xs.begin() + r, axisSorter(depth % xs[0].v.size()));
    int mid = (l+r)>>1;

    vertex *v = new vertex();
    v->val = xs[mid];
    v->left = makeKDTree(xs, l, mid, depth+1);
    v->right = makeKDTree(xs, mid+1, r, depth+1);

    return v;
}

inline double dist(const data &x, const data &y) {
    double ret = 0;
    for (int i = 0; i < (int)x.v.size(); i++)
        ret += (x.v[i] - y.v[i]) * (x.v[i] - y.v[i]);
    return ret;
}

inline double sq(double x) {
    return x*x;
}

data query(data &a, vertex *v, int depth) {
    int k = depth % a.v.size();
    if ((v->left != NULL && a.v[k] < v->val.v[k]) || 
        (v->left != NULL && v->right == NULL)) {
        data d = query(a, v->left, depth+1);
        if (dist(v->val, a) < dist(d, a))
            d = v->val;
        if (v->right != NULL && sq(a.v[k] - v->val.v[k]) < dist(d, a)) {
            data d2 = query(a, v->right, depth+1);
            if (dist(d2, a) < dist(d, a))
                d = d2;
        }
        return d;
    }
    else if (v->right != NULL) {
        data d = query(a, v->right, depth+1);
        if (dist(v->val, a) < dist(d, a))
            d = v->val;
        if (v->left != NULL && sq(a.v[k] - v->val.v[k]) < dist(d, a)) {
            data d2 = query(a, v->left, depth+1);
            if (dist(d2, a) < dist(d, a))
                d = d2;
        }
        return d;
    }
    else
        return v->val;
}

int main() {
    cin >> N;

    vector<data> xs(N, data(2));
    for (int i = 0; i < N; i++) {
        xs[i].index = i;
        cin >> xs[i].v[0] >> xs[i].v[1];
    }

    vertex *root = makeKDTree(xs, 0, N, 0);

    cin >> L;
    for (int i = 0; i < L; i++) {
        data a(2);
        cin >> a.v[0] >> a.v[1];
        data nearest = query(a, root, 0);
        cout << nearest.index << endl;
    }

    return 0;
}


速度比較
N=1,000,000、L=100,000とすると、straightforward解は240秒弱。対してk-d treeを使った解法は8秒くらい。30倍くらい速い。もうちょっと真面目にコードを書けばさらに10倍くらいは速くできると思う。(多分)

0 件のコメント:

コメントを投稿