Page List

Search on the blog

2017年4月1日土曜日

PyStan入門(2)線形モデル

 PyStanで線形モデルのパラメータ推定をしてみた。
データyはNormal(w*x + b, σ2)から生成されるとしてモデルを組んで、パラメータw, b, σの分布を求めてみた。

分かったこと
  • pystan.stan()メソッドの引数で以下を指定できる
    • iter 各チェーンが行うサンプリング回数(warmup含む)
    • warmup 初回の捨てるサンプリング数
    • chain チェーンの数
  • fit.extract()メソッドでサンプリング値を取れる
    • permuted引数=Trueとすると、チェーンをマージしたサンプル列を取得できる
    • permuted引数=Falseとすると、(iter, chain, parameters)の次元のnumpy.ndarrayを取得できる
コード

data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y; 
} parameters {
    real w;
    real b;
    real<lower=0> sigma;
} model {
    y ~ normal(w * x + b, sigma);
}

import numpy as np
import pystan
import matplotlib.pyplot as plt
import pandas as pd


def gen():
  N = 30
  w = 10
  b = 3
  sigma = 0.5

  x = np.random.rand(N)
  err = sigma * np.random.randn(N)
  y = w * x + b + err

  return (x,y)


def train(x, y):
  data = {
    'N': len(x),
    'x': x,
    'y': y
  }
  fit = pystan.stan(file='linear_model.stan', data=data, iter=2000, warmup=1000, chains=4)
  return fit


def predict(param, x):
  return param['w'] * x + param['b']


if __name__ == '__main__':
  x, y = gen()
  fit = train(x, y)
  fit.plot()
  plt.show()

  xt = np.linspace(0, 1, 100)
  params = pd.DataFrame({
    'w': fit.extract(permuted=True).get('w'),
    'b': fit.extract(permuted=True).get('b')
  })

  median_params = params.median()
  yt = predict(median_params, xt)
  yt_lower = [np.percentile(predict(params, x), 2.5) for x in xt]
  yt_upper = [np.percentile(predict(params, x), 97.5) for x in xt]

  plt.plot(x, y, 'bo')
  plt.plot(xt, yt, 'k')
  plt.fill_between(xt, yt_lower, yt_upper, facecolor='lightgrey', edgecolor='none')
  plt.show()

結果
まずトレースプロットと、各パラメータの分布。それっぽい分布になっている。
次に学習結果を使って予測値のプロット。実線はパラメータの中央値を使ってyを予測した値。グレーの部分は得られたパラメータのサンプルを使って予測値を計算したときの2.5%パーセンタイルと97.5パーセンタイルの値以内の領域。

0 件のコメント:

コメントを投稿