- 節点の高さごとに軸を選んで、軸と交わる超平面でk次元空間を二つに分割する。
- 1.で選んだ超平面より左にあるデータは左部分木に、右にあるデータは右部分木に格納していく。
- データ集合の中央値を通るように1.の超平面を選ぶことで、平衡二分探索木にすることができる。
これを使うと、nearest neighborで最近傍を見つけるときの計算を高速化することができる。
詳しく知りたいひとは、ここ。
問題
とりあえずソースコードを書く前にk-d treeを用いると高速に解くことができそうな問題を考えてみた。『二次元平面上にN個の点が与えられる。その後L個の点が与えられる。L個の点それぞれについて、先に与えられたN個の点のうちどれに最も近いのかを求めよ。』
straightforwardな解法
普通に書くとこうです。
#include <iostream> #include <vector> using namespace std; int N, L; inline double dist(double x1, double y1, double x2, double y2) { return (x1-x2) * (x1-x2) + (y1-y2) * (y1-y2); } int main() { cin >> N; vector<double> xs(N), ys(N); for (int i = 0; i < N; i++) cin >> xs[i] >> ys[i]; cin >> L; for (int i = 0; i < L; i++) { double x, y; cin >> x >> y; double best = 1e100; int nearest = -1; for (int j = 0; j < N; j++) { if (dist(x, y, xs[j], ys[j]) < best) { best = dist(x, y, xs[j], ys[j]); nearest = j; } } cout << nearest << endl; } return 0; }
k-d treeを用いた解法
k-d treeを使うとこんな感じ。
#include <iostream> #include <vector> #include <algorithm> using namespace std; int N, L; inline double dist(double x1, double y1, double x2, double y2) { return (x1-x2) * (x1-x2) + (y1-y2) * (y1-y2); } struct data { int index; vector<double> v; data(int _d) : v(_d) {} data() {} }; struct vertex { data val; vertex *left; vertex *right; }; class axisSorter { int k; public: axisSorter(int _k) : k(_k) {} double operator()(const data &a, const data &b) { return a.v[k] < b.v[k]; } }; vertex *makeKDTree(vector<data> &xs, int l, int r, int depth) { if (l >= r) return NULL; sort(xs.begin() + l, xs.begin() + r, axisSorter(depth % xs[0].v.size())); int mid = (l+r)>>1; vertex *v = new vertex(); v->val = xs[mid]; v->left = makeKDTree(xs, l, mid, depth+1); v->right = makeKDTree(xs, mid+1, r, depth+1); return v; } inline double dist(const data &x, const data &y) { double ret = 0; for (int i = 0; i < (int)x.v.size(); i++) ret += (x.v[i] - y.v[i]) * (x.v[i] - y.v[i]); return ret; } inline double sq(double x) { return x*x; } data query(data &a, vertex *v, int depth) { int k = depth % a.v.size(); if ((v->left != NULL && a.v[k] < v->val.v[k]) || (v->left != NULL && v->right == NULL)) { data d = query(a, v->left, depth+1); if (dist(v->val, a) < dist(d, a)) d = v->val; if (v->right != NULL && sq(a.v[k] - v->val.v[k]) < dist(d, a)) { data d2 = query(a, v->right, depth+1); if (dist(d2, a) < dist(d, a)) d = d2; } return d; } else if (v->right != NULL) { data d = query(a, v->right, depth+1); if (dist(v->val, a) < dist(d, a)) d = v->val; if (v->left != NULL && sq(a.v[k] - v->val.v[k]) < dist(d, a)) { data d2 = query(a, v->left, depth+1); if (dist(d2, a) < dist(d, a)) d = d2; } return d; } else return v->val; } int main() { cin >> N; vector<data> xs(N, data(2)); for (int i = 0; i < N; i++) { xs[i].index = i; cin >> xs[i].v[0] >> xs[i].v[1]; } vertex *root = makeKDTree(xs, 0, N, 0); cin >> L; for (int i = 0; i < L; i++) { data a(2); cin >> a.v[0] >> a.v[1]; data nearest = query(a, root, 0); cout << nearest.index << endl; } return 0; }
速度比較
N=1,000,000、L=100,000とすると、straightforward解は240秒弱。対してk-d treeを使った解法は8秒くらい。30倍くらい速い。もうちょっと真面目にコードを書けばさらに10倍くらいは速くできると思う。(多分)
0 件のコメント:
コメントを投稿