Centroid Decompositionを使う問題をいくつか解いてみた。
Centroidが何なのか知らない人は
こちら。
Codeforces Round #190 Ciel the Commander
問題
木の各ノードにアルファベット(A-Z)を1つ書きたい。
ただし、同じアルファベットが書かれた2つのノードv, w間のパス上には、2つのノードに書かれたアルファベットより小さい文字が書かれたノードが存在しなければならない。
このようなアルファベットの書き方を求めよ。
解法
Centroid Decompositionして分解されたときの再帰の深さの順に小さい文字を割り振っていけばOK。
ソースコード
#define REP(i,n) for(int i=0; i<(int)(n); i++)
#define FOR(i,b,e) for (int i=(int)(b); i<(int)(e); i++)
#define ALL(x) (x).begin(), (x).end()
int n;
vector<int> edges[100000];
int rk[100000];
int sz[100000];
void szdfs(int v, int par = -1) {
sz[v] = 1;
for (auto &w: edges[v]) {
if (rk[w] || w == par) continue;
szdfs(w, v);
sz[v] += sz[w];
}
}
int centroid(int v, int par, int total) {
for (auto &w: edges[v]) {
if (rk[w] || w == par) continue;
if (2 * sz[w] > total)
return centroid(w, v, total);
}
return v;
}
void solve(int v, int r) {
szdfs(v);
v = centroid(v, -1, sz[v]);
rk[v] = r;
for (auto &w: edges[v]) {
if (rk[w]) continue;
solve(w, r+1);
}
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n;
REP (i, n-1) {
int a, b;
cin >> a >> b;
--a, --b;
edges[a].push_back(b);
edges[b].push_back(a);
}
solve(0, 1);
REP (i, n) {
char c = 'A' + rk[i] - 1;
cout << c << " ";
}
cout << endl;
return 0;
}
Codeforces Round #199 Xenia and Tree
問題
木のノードに色を塗る。はじめノード1は赤色に、それ以外のノードは青色に塗られている。以下のクエリを高速に処理せよ。
1. ある青いノードを赤色に塗る
2. あるノードから赤色のノードまでの最短距離を求める
解法
Centroid Decompositionを使って、バランスした木に構築する。
1. のクエリに対しては指定されたノードを赤く塗り、そのノードからルート方向へ登りながら、 通過したノードにそのノードからそのノード以下の赤ノードまでの最短距離を更新する。
2. のクエリに対してはそのノードからルート方向へ登りながら1.のときに更新した値を用いて最短距離を計算する。
木がバランスしているので、訪れるノード数がlog(n)個程度になるのがポイント。
Centroid Decompositionで作った木と元の木のノード集合は同じだが、枝集合は異なることに注意。ノード間の距離を求める場合は元の木における距離を使わないといけない。
ソースコード
#define REP(i,n) for(int i=0; i<(int)(n); i++)
#define FOR(i,b,e) for (int i=(int)(b); i<(int)(e); i++)
#define ALL(x) (x).begin(), (x).end()
class LCA {
int V, logV;
vector<int> depth;
vector<vector<int> > parent;
void build() {
for (int k = 0; k + 1 < logV; k++) {
for (int v = 0; v < V; v++) {
if (parent[k][v] < 0) parent[k+1][v] = -1;
else parent[k+1][v] = parent[k][parent[k][v]];
}
}
}
public:
LCA(int V) {
this->V = V;
logV = 0;
while (V > (1LL<<logV)) logV++;
this->depth = vector<int>(V);
this->parent = vector<vector<int> >(logV, vector<int>(V));
}
void init(int N, int p[], int d[]) {
for (int i = 0; i < N; i++) {
parent[0][i] = p[i];
depth[i] = d[i];
}
this->build();
}
int query(int u, int v) {
if (depth[u] > depth[v]) swap(u, v);
for (int k = 0; k < logV; k++) {
if ((depth[v] - depth[u]) >> k & 1)
v = parent[k][v];
}
if (u == v) return u;
for (int k = logV-1; k >= 0; k--) {
if (parent[k][u] != parent[k][v]) {
u = parent[k][u];
v = parent[k][v];
}
}
return parent[0][u];
}
};
const int INF = 1<<28;
int n, m;
vector<int> edges[100000];
bool vis[100000];
int p[100000];
int sz[100000];
bool red[100000];
int dist[100000];
int lcap[100000];
int lcad[100000];
LCA lca(100000);
void szdfs(int v, int par = -1) {
sz[v] = 1;
for (auto &w: edges[v]) {
if (vis[w] || w == par) continue;
szdfs(w, v);
sz[v] += sz[w];
}
}
int centroid(int v, int par, int total) {
for (auto &w: edges[v]) {
if (vis[w] || w == par) continue;
if (2 * sz[w] > total)
return centroid(w, v, total);
}
return v;
}
void balanceTree(int v, int par = -1) {
szdfs(v);
v = centroid(v, -1, sz[v]);
p[v] = par;
vis[v] = true;
for (auto &w: edges[v]) {
if (vis[w]) continue;
balanceTree(w, v);
}
}
void lcadfs(int v, int par, int d) {
lcap[v] = par;
lcad[v] = d;
for (auto &w: edges[v]) {
if (w == par) continue;
lcadfs(w, v, d+1);
}
}
void paint(int v) {
red[v] = true;
int w = v;
while (w != -1) {
int u = lca.query(v, w);
int cost = lcad[v] + lcad[w] - 2 * lcad[u];
dist[w] = min(dist[w], cost);
w = p[w];
}
}
int query(int v) {
int ret = INF;
int w = v;
while (w != -1) {
int u = lca.query(v, w);
int cost = lcad[v] + lcad[w] - 2 * lcad[u];
ret = min(ret, dist[w] + cost);
w = p[w];
}
return ret;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n >> m;
REP (i, n-1) {
int a, b;
cin >> a >> b;
--a, --b;
edges[a].push_back(b);
edges[b].push_back(a);
}
balanceTree(0);
fill(dist, dist+n, INF);
lcadfs(0, -1, 0);
lca.init(n, lcap, lcad);
paint(0);
REP (i, m) {
int t, x;
cin >> t >> x;
--x;
if (t == 1)
paint(x);
else
cout << query(x) << endl;
}
return 0;
}
Codeforces Round #372 Digit Tree
問題
木のノードに1-9までの数字が書かれている。
v, w間のパスに含まれるノードの数字をつないで10進数表記の数を作る。その数がMで割り切れるようなv, wのペアの数を求めよ。
解法
uがv, wのLCAとなるような場合のv, wの組み合わせを考える。
v -> u -> wという順序に通るので、v -> uとu -> wの組み合わせを列挙して、2つをつなげるとMで割るようなものを数えればよい。
あとはuをすべてのノードでループしないといけないが、Centroid Decompositionしておくと計算量を抑えることができる。
ソースコード
using namespace std;
#define REP(i,n) for(int i=0; i<(int)(n); i++)
#define FOR(i,b,e) for (int i=(int)(b); i<(int)(e); i++)
#define ALL(x) (x).begin(), (x).end()
int n;
long long M;
vector<pair<int, int> > edges[100000];
bool vis[100000];
int sz[100000];
map<int, int> upcnt;
long long r1, r2;
long long pw[100000], ipw[100000];
void szdfs(int v, int par = -1) {
sz[v] = 1;
for (auto &e: edges[v]) {
int w = e.first;
if (w == par || vis[w]) continue;
szdfs(w, v);
sz[v] += sz[w];
}
}
int centroid(int v, int par, int total) {
REP (i, edges[v].size()) {
int w = edges[v][i].first;
if (w == par || vis[w]) continue;
if (sz[w] * 2 > total)
return centroid(w, v, total);
}
return v;
}
void downdfs(int v, int par, int acc, int d) {
if (acc == 0) ++r2;
r1 += upcnt[(M-acc)*ipw[d]%M];
for (auto &e: edges[v]) {
int u, w;
tie(u, w) = e;
if (u == par || vis[u]) continue;
downdfs(u, v, (10LL*acc+w)%M, d+1);
}
}
void updfs(int v, int par, int acc, int d) {
if (acc == 0) ++r2;
++upcnt[acc];
for (auto &e: edges[v]) {
int u, w;
tie(u, w) = e;
if (u == par || vis[u]) continue;
updfs(u, v, (acc+pw[d]*w)%M, d+1);
}
}
void solve(int v) {
szdfs(v);
v = centroid(v, -1, sz[v]);
upcnt.clear();
REP (i, edges[v].size()) {
int u = edges[v][i].first;
int w = edges[v][i].second;
if (vis[u]) continue;
downdfs(u, v, w%M, 1);
updfs(u, v, w%M, 1);
}
upcnt.clear();
REP (i, edges[v].size()) {
int u = edges[v][edges[v].size()-1-i].first;
int w = edges[v][edges[v].size()-1-i].second;
if (vis[u]) continue;
downdfs(u, v, w%M, 1);
updfs(u, v, w%M, 1);
}
vis[v] = true;
for (auto &e: edges[v]) {
int w = e.first;
if (vis[w]) continue;
solve(w);
}
}
long long modpow(long long x, long long p, long long mod) {
long long ret = 1;
while (p) {
if (p & 1)
ret = ret * x % mod;
x = x * x % mod;
p >>= 1;
}
return ret;
}
long long totient(long long n) {
long long ret = n;
for (long long i = 2; i * i <= n; i++) {
if (n % i == 0) {
ret = ret / i * (i - 1);
while (n % i == 0)
n /= i;
}
}
if (n != 1)
ret = ret / n * (n - 1);
return ret;
}
void init() {
pw[0] = ipw[0] = 1;
long long inv = modpow(10, totient(M)-1, M);
FOR (i, 1, 100000) {
pw[i] = pw[i-1] * 10 % M;
ipw[i] = ipw[i-1] * inv % M;
}
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n >> M;
REP (i, n-1) {
int u,v,w;
cin >> u >> v >> w;
edges[u].emplace_back(v, w);
edges[v].emplace_back(u, w);
}
r1 = r2 = 0;
init();
solve(0);
cout << r1 + r2/2 << endl;
return 0;
}