Page List

Search on the blog

2017年3月31日金曜日

PyStan入門(1)とりあえず動かす

 緑本を読んで、MCMCで統計モデルのパラメータ推定をすることに興味がわいてきたので、PyStanをさわっていこうと思う。
 まずは、とりあえずインストールして動かしてみました。

PyStanのインストール
pipで入ります。

$ pip install pystan

モデルの記述
stanのモデルは.stanという拡張子のファイルに書くのが一般的なようです。他にもPythonのコードにベタ書きするという方法もあるようです。
ここではstanファイルにモデルを記述することにし、以下の内容を8schools.stanというファイルに記述します。

data {
  int<lower=0> J; // number of schools 
  real y[J]; // estimated treatment effects
  real<lower=0> sigma[J]; // s.e. of effect estimates 
}
parameters {
  real mu; 
  real<lower=0> tau;
  real eta[J];
}
transformed parameters {
  real theta[J];
  for (j in 1:J)
    theta[j] <- mu + tau * eta[j];
}
model {
  eta ~ normal(0, 1);
  y ~ normal(theta, sigma);
}

上のモデルを以下のPythonコードから参照します。以下のコードをsample.pyという名前のファイルに記述します。
import pystan
import matplotlib.pyplot as plt

schools_dat = {
 'J': 8,
 'y': [28,  8, -3,  7, -1,  1, 18, 12],
 'sigma': [15, 10, 16, 11,  9, 11, 10, 18]
}

fit = pystan.stan(file='8schools.stan', data=schools_dat, iter=1000, chains=4)

print(fit)
fit.plot()
plt.show()

実行結果
実行してみます。

$ python sample.py
DIAGNOSTIC(S) FROM PARSER:
Warning (non-fatal): assignment operator <- deprecated in the Stan language; use = instead.

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_286b3180dfa752c4cfedaf0241add0e4 NOW.
Iteration:   1 / 1000 [  0%]  (Warmup) (Chain 1)
Iteration:   1 / 1000 [  0%]  (Warmup) (Chain 3)
Iteration:   1 / 1000 [  0%]  (Warmup) (Chain 2)
Iteration:   1 / 1000 [  0%]  (Warmup) (Chain 0)
Iteration: 100 / 1000 [ 10%]  (Warmup) (Chain 1)
Iteration: 100 / 1000 [ 10%]  (Warmup) (Chain 2)
Iteration: 100 / 1000 [ 10%]  (Warmup) (Chain 3)
Iteration: 100 / 1000 [ 10%]  (Warmup) (Chain 0)
Iteration: 200 / 1000 [ 20%]  (Warmup) (Chain 1)
Iteration: 200 / 1000 [ 20%]  (Warmup) (Chain 0)
Iteration: 200 / 1000 [ 20%]  (Warmup) (Chain 3)
Iteration: 300 / 1000 [ 30%]  (Warmup) (Chain 0)
Iteration: 300 / 1000 [ 30%]  (Warmup) (Chain 3)
Iteration: 300 / 1000 [ 30%]  (Warmup) (Chain 1)
Iteration: 200 / 1000 [ 20%]  (Warmup) (Chain 2)
Iteration: 400 / 1000 [ 40%]  (Warmup) (Chain 1)
Iteration: 400 / 1000 [ 40%]  (Warmup) (Chain 0)
Iteration: 400 / 1000 [ 40%]  (Warmup) (Chain 3)
Iteration: 500 / 1000 [ 50%]  (Warmup) (Chain 3)
Iteration: 501 / 1000 [ 50%]  (Sampling) (Chain 3)
Iteration: 300 / 1000 [ 30%]  (Warmup) (Chain 2)
Iteration: 500 / 1000 [ 50%]  (Warmup) (Chain 0)
Iteration: 500 / 1000 [ 50%]  (Warmup) (Chain 1)
Iteration: 501 / 1000 [ 50%]  (Sampling) (Chain 0)
Iteration: 501 / 1000 [ 50%]  (Sampling) (Chain 1)
Iteration: 600 / 1000 [ 60%]  (Sampling) (Chain 3)
Iteration: 600 / 1000 [ 60%]  (Sampling) (Chain 0)
Iteration: 400 / 1000 [ 40%]  (Warmup) (Chain 2)
Iteration: 600 / 1000 [ 60%]  (Sampling) (Chain 1)
Iteration: 700 / 1000 [ 70%]  (Sampling) (Chain 3)
Iteration: 700 / 1000 [ 70%]  (Sampling) (Chain 1)
Iteration: 500 / 1000 [ 50%]  (Warmup) (Chain 2)
Iteration: 501 / 1000 [ 50%]  (Sampling) (Chain 2)
Iteration: 700 / 1000 [ 70%]  (Sampling) (Chain 0)
Iteration: 800 / 1000 [ 80%]  (Sampling) (Chain 3)
Iteration: 800 / 1000 [ 80%]  (Sampling) (Chain 1)
Iteration: 900 / 1000 [ 90%]  (Sampling) (Chain 3)
Iteration: 600 / 1000 [ 60%]  (Sampling) (Chain 2)
Iteration: 800 / 1000 [ 80%]  (Sampling) (Chain 0)
Iteration: 1000 / 1000 [100%]  (Sampling) (Chain 3)
# 
#  Elapsed Time: 0.044725 seconds (Warm-up)
#                0.037462 seconds (Sampling)
#                0.082187 seconds (Total)
# 
Iteration: 900 / 1000 [ 90%]  (Sampling) (Chain 1)
Iteration: 900 / 1000 [ 90%]  (Sampling) (Chain 0)
Iteration: 700 / 1000 [ 70%]  (Sampling) (Chain 2)
Iteration: 1000 / 1000 [100%]  (Sampling) (Chain 1)
# 
#  Elapsed Time: 0.05144 seconds (Warm-up)
#                0.040057 seconds (Sampling)
#                0.091497 seconds (Total)
# 
Iteration: 1000 / 1000 [100%]  (Sampling) (Chain 0)
# 
#  Elapsed Time: 0.04517 seconds (Warm-up)
#                0.046124 seconds (Sampling)
#                0.091294 seconds (Total)
# 
Iteration: 800 / 1000 [ 80%]  (Sampling) (Chain 2)
Iteration: 900 / 1000 [ 90%]  (Sampling) (Chain 2)
Iteration: 1000 / 1000 [100%]  (Sampling) (Chain 2)
# 
#  Elapsed Time: 0.053818 seconds (Warm-up)
#                0.040396 seconds (Sampling)
#                0.094214 seconds (Total)
# 
Inference for Stan model: anon_model_286b3180dfa752c4cfedaf0241add0e4.
4 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=2000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         7.81    0.17   5.09  -2.41   4.61   7.77  11.13  17.63    907    1.0
tau        6.63    0.24    5.6   0.29   2.45   5.33   9.17  21.19    531    1.0
eta[0]     0.42    0.02   0.94  -1.53  -0.17   0.44   1.05   2.17   1701    1.0
eta[1]     0.01    0.02   0.87  -1.67  -0.55 3.2e-3    0.6   1.74   1751    1.0
eta[2]    -0.23    0.02   0.89  -1.95  -0.81  -0.22   0.38   1.51   1866    1.0
eta[3]    -0.01    0.02    0.9  -1.82  -0.61-8.1e-3   0.58   1.79   1669    1.0
eta[4]    -0.33    0.02   0.87   -2.1  -0.87  -0.34   0.21   1.44   1603    1.0
eta[5]    -0.21    0.02    0.9   -1.9  -0.82  -0.23   0.39   1.55   1599    1.0
eta[6]     0.34    0.02    0.9  -1.42  -0.24   0.36   0.95   2.07   1537    1.0
eta[7]     0.06    0.02   0.94  -1.84  -0.56   0.07   0.66   1.87   1604    1.0
theta[0]  11.61    0.24   8.19  -1.43   6.15  10.55  15.65  31.15   1211    1.0
theta[1]   7.84    0.14   6.35  -5.56   4.05   7.76   11.7  20.83   1988    1.0
theta[2]   5.97     0.2    7.7  -11.7   1.85   6.38  10.91  20.08   1505    1.0
theta[3]   7.61    0.15   6.45  -5.51   3.46   7.67  11.86  20.35   1897    1.0
theta[4]   5.13    0.14   6.45  -8.65   0.99   5.58   9.56   17.0   2000    1.0
theta[5]   6.21    0.17   6.51  -7.34   2.29   6.46  10.55  18.69   1522    1.0
theta[6]  10.64    0.17   6.88  -1.89   6.09  10.06  14.68  26.06   1623    1.0
theta[7]   8.33     0.2   7.94  -7.88   3.77   8.07  12.54  25.84   1522    1.0
lp__      -4.82    0.11   2.66 -10.43  -6.54  -4.59  -2.84  -0.42    598    1.0

Samples were drawn using NUTS at Sat Apr  1 00:55:24 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

C++のコンパイラが走って、MCMCのサンプリングが走って、パラメータの分布が出たようです。続いて各パラメータの分布図とiterationごとの変化がプロットされます。
細かいことは置いといて、何やら楽しそうなことが出来るということは分かりました。
次回は自分で簡単な乱数生成モデルを作って、それをモデリングしてパラメータ推定ができるかどうかを試してみたいと思います。

参考
Getting started — PyStan 2.14.0.0 documentation

0 件のコメント:

コメントを投稿