背景
PyStanでMCMCするときには標準出力に有益なログが出力されます。
たとえば,サンプリングの途中経過が出力されます。
Iteration: 800 / 2000 [ 40%] (Warmup) Iteration: 800 / 2000 [ 40%] (Warmup) Iteration: 600 / 2000 [ 30%] (Warmup) Iteration: 800 / 2000 [ 40%] (Warmup) Iteration: 1000 / 2000 [ 50%] (Warmup) Iteration: 1000 / 2000 [ 50%] (Warmup) Iteration: 1001 / 2000 [ 50%] (Sampling) Iteration: 1001 / 2000 [ 50%] (Sampling) Iteration: 1000 / 2000 [ 50%] (Warmup) Iteration: 1001 / 2000 [ 50%] (Sampling) Iteration: 800 / 2000 [ 40%] (Warmup) Iteration: 1200 / 2000 [ 60%] (Sampling) Iteration: 1200 / 2000 [ 60%] (Sampling)
最近になって,サンプリング終了までの時間の予測も出力されるようになりました。
Gradient evaluation took 3e-05 seconds 1000 transitions using 10 leapfrog steps per transition would take 0.3 seconds. Adjust your expectations accordingly!
これらのログはMCMCによるサンプリングが正しく行えているかどうかチェックするのに非常に有用ですが,ログを出力させたくない場合もあります。私の場合だと,Thompson samplingのための事後分布の推定にStanを使っているのですが,アームを引くたびに事後分布の推定を行う必要があり,何度もStanを呼び出すことになるのでStanのログ出力だけでコンソール出力が埋まってしまうのに悩まされていました。
ダメな解決法1
PyStanのドキュメント を読んでみると,StanModel.sampling
のオプションとして refresh
というパラメータが見つかります。これはサンプリングの途中経過を表す Iteration: 1001 / 2000 [ 50%] (Sampling)
という出力の頻度を調整するパラメータで,0に設定すると途中経過が出力されなくなります。このパラメータの設定によって出力が抑制されるものの,サンプリング終了時に出力される以下のようなログは表示されたままで,かなりの行数が使われてしまい邪魔だと感じてしまいます。
Elapsed Time: 0.01544 seconds (Warm-up) 0.013689 seconds (Sampling) 0.029129 seconds (Total)
ダメな解決法2
Pythonの contextlib
モジュールには redirect_stdout
というコンテクストマネージャがあり,with文で指定した範囲内で stdout
への出力を別のストリームに変更することができます。以下のような関数を作って stdout
への出力を /dev/null
に変更してみました。
@contextlib.contextmanager def suppress_stdout(): with open(os.devnull, "w") as f: with contextlib.redirect_stderr(f): yield
以下のように使います。
data = ... stan_model = ... with suppress_stdout(): fit = stan_model.sampling(data=data, seed=0)
この方法ではPyStanの出力は抑制できません。PyStanで呼び出されたサンプラーはPythonコードの子プロセスとして起動するので,Pythonコード中の stdout
が引き継がれるわけではないからです。
うまくいった解決法
以下のGitHub Issuesに「ファイルディスクリプタを書き換えれば良いよ」という書き込みがあったので試してみました。
今回は stdout
への出力を抑制したいだけで stderr
の出力はそのまま表示されてほしいので,処理を若干書き換えました。
@contextlib.contextmanager def suppress_stdout(): null_fd = os.open(os.devnull, os.O_RDWR) save_fd = os.dup(1) os.dup2(null_fd, 1) yield os.dup2(save_fd, 1) os.close(null_fd) os.close(save_fd)
以下のように使います。
data = ... stan_model = ... with suppress_stdout(): fit = stan_model.sampling(data=data, seed=0)
この方法でPyStanの出力を抑制することができました。