備忘録的な

プログラミングや機械学習に関する備忘録

RでStan(基本)

RからStanを使うための勉強をしていきます.

rstanとRtoolsのインストール

> install.packages('rstan', repos='https://cloud.r-project.org/', dependencies=TRUE)
> pkgbuild::has_build_tools(debug=TRUE)
> install.packages("rstudioapi")

テキストにはrstudioapiが必要とは書いてありませんでしたが,RStudioを使って並列計算をする場合に必要なようです.
また,
C:\RBuildTools\4.0\usr\bin
C:\RBuildTools\4.0\mingw64\bin
にパスを通す必要がありました.

実装

Stanファイル

// 使用するデータ
data {
  int<lower=0> N;
  vector[N] sales;
}

// 事後分布を得たいパラメータの一覧
parameters {
  real mu;
  real<lower=0> sigma;
}

// 売り上げという観測データは平均mu,標準偏差sigmaの正規分布から得られた
// ということを表すモデル
model {
  for (i in 1:N){
    sales[i] ~ normal(mu, sigma);
  }
}

modelブロックはforループを使わずに
sales ~ normal(mu, sigma);
と書くこともできます.このような書き方をベクトル化と呼びます.ベクトル化すると計算が速くなることがあるそうです.

Rファイル

library(rstudioapi)
library(rstan)

# 計算の高速化
rstan_options(auto_write=TRUE) # 2回目以降のコンパイルを不要にする
options(mc.cores=parallel::detectCores()) # 計算の並列化

beer_df <- read.csv("2-4-1-beer-sales-1.csv")

data_list <- list(sales=beer_df$sales, N=nrow(beer_df))

mcmc_result <- stan(
  file = "2-4-1-calc-mean-variance.stan",
  data = data_list,
  seed = 1,
  chains = 4,
  iter = 2000,
  warmup = 1000,
  thin = 1
)

stanメソッドに渡すデータはデータリスト形式で,Stanファイルに記載した変数が格納されている必要があります.
iterは生成する乱数の個数です.warmupは乱数の初期値依存性を減らすために切り捨てるサンプルの数です.つまり2000個生成した乱数の内最初の1000個を捨てることになります.
乱数は自己相関を持つこともあるため,thinを1より大きい値にするとthin個の乱数の内1つだけを採用するように間引くことで自己相関の影響を削減できます.
収束の評価を行うために乱数生成を複数セット行います.このセット数をchainsで設定します.

結果の確認

> print(
+   mcmc_result,
+   probs = c(0.025, 0.5, 0.975)
+ )
Inference for Stan model: 2-4-1-calc-mean-variance.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

         mean se_mean   sd    2.5%     50%   97.5% n_eff Rhat
mu     102.18    0.03 1.82   98.57  102.20  105.73  3277    1
sigma   18.20    0.02 1.29   15.95   18.12   21.00  2858    1
lp__  -336.44    0.02 0.97 -338.98 -336.12 -335.48  1774    1

> traceplot(mcmc_result)

実行時に以下のような警告がでましたが原因がわかりませんでした.エラーではないのでとりあえず無視してよいのでしょうか?

 警告メッセージ: 
 system(paste(CXX, ARGS), ignore.stdout = TRUE, ignore.stderr = TRUE) で: 
  'C:/RBUILD~1/4.0/usr/mingw_/bin/g++' not found

n_effはMCMCにおける有効サンプルサイズだそうです.これが少ないようであればモデルの改善が必要であり,100くらいあることが望ましいそうです.
収束判定指標Rhatは1.1未満であることが求められます.
トレースプロットは4本のチェーンがまじりあっていればOKです.
inc_warmup=TRUEとするとwarmup期間を含めて描画できます.