データyはNormal(w*x + b, σ2)から生成されるとしてモデルを組んで、パラメータw, b, σの分布を求めてみた。
分かったこと
- pystan.stan()メソッドの引数で以下を指定できる
- iter 各チェーンが行うサンプリング回数(warmup含む)
- warmup 初回の捨てるサンプリング数
- chain チェーンの数
- fit.extract()メソッドでサンプリング値を取れる
- permuted引数=Trueとすると、チェーンをマージしたサンプル列を取得できる
- permuted引数=Falseとすると、(iter, chain, parameters)の次元のnumpy.ndarrayを取得できる
data { int<lower=0> N; vector[N] x; vector[N] y; } parameters { real w; real b; real<lower=0> sigma; } model { y ~ normal(w * x + b, sigma); }
import numpy as np import pystan import matplotlib.pyplot as plt import pandas as pd def gen(): N = 30 w = 10 b = 3 sigma = 0.5 x = np.random.rand(N) err = sigma * np.random.randn(N) y = w * x + b + err return (x,y) def train(x, y): data = { 'N': len(x), 'x': x, 'y': y } fit = pystan.stan(file='linear_model.stan', data=data, iter=2000, warmup=1000, chains=4) return fit def predict(param, x): return param['w'] * x + param['b'] if __name__ == '__main__': x, y = gen() fit = train(x, y) fit.plot() plt.show() xt = np.linspace(0, 1, 100) params = pd.DataFrame({ 'w': fit.extract(permuted=True).get('w'), 'b': fit.extract(permuted=True).get('b') }) median_params = params.median() yt = predict(median_params, xt) yt_lower = [np.percentile(predict(params, x), 2.5) for x in xt] yt_upper = [np.percentile(predict(params, x), 97.5) for x in xt] plt.plot(x, y, 'bo') plt.plot(xt, yt, 'k') plt.fill_between(xt, yt_lower, yt_upper, facecolor='lightgrey', edgecolor='none') plt.show()
結果
まずトレースプロットと、各パラメータの分布。それっぽい分布になっている。
次に学習結果を使って予測値のプロット。実線はパラメータの中央値を使ってyを予測した値。グレーの部分は得られたパラメータのサンプルを使って予測値を計算したときの2.5%パーセンタイルと97.5パーセンタイルの値以内の領域。
0 件のコメント:
コメントを投稿