PyStanのログ出力を消す

背景

PyStanでMCMCするときには標準出力に有益なログが出力されます。

たとえば,サンプリングの途中経過が出力されます。

Iteration:  800 / 2000 [ 40%]  (Warmup)
Iteration:  800 / 2000 [ 40%]  (Warmup)
Iteration:  600 / 2000 [ 30%]  (Warmup)
Iteration:  800 / 2000 [ 40%]  (Warmup)
Iteration: 1000 / 2000 [ 50%]  (Warmup)
Iteration: 1000 / 2000 [ 50%]  (Warmup)
Iteration: 1001 / 2000 [ 50%]  (Sampling)
Iteration: 1001 / 2000 [ 50%]  (Sampling)
Iteration: 1000 / 2000 [ 50%]  (Warmup)
Iteration: 1001 / 2000 [ 50%]  (Sampling)
Iteration:  800 / 2000 [ 40%]  (Warmup)
Iteration: 1200 / 2000 [ 60%]  (Sampling)
Iteration: 1200 / 2000 [ 60%]  (Sampling)

最近になって,サンプリング終了までの時間の予測も出力されるようになりました。

Gradient evaluation took 3e-05 seconds
1000 transitions using 10 leapfrog steps per transition would take 0.3 seconds.
Adjust your expectations accordingly!

これらのログはMCMCによるサンプリングが正しく行えているかどうかチェックするのに非常に有用ですが,ログを出力させたくない場合もあります。私の場合だと,Thompson samplingのための事後分布の推定にStanを使っているのですが,アームを引くたびに事後分布の推定を行う必要があり,何度もStanを呼び出すことになるのでStanのログ出力だけでコンソール出力が埋まってしまうのに悩まされていました。

ダメな解決法1

PyStanのドキュメント を読んでみると,StanModel.sampling のオプションとして refresh というパラメータが見つかります。これはサンプリングの途中経過を表す Iteration: 1001 / 2000 [ 50%] (Sampling) という出力の頻度を調整するパラメータで,0に設定すると途中経過が出力されなくなります。このパラメータの設定によって出力が抑制されるものの,サンプリング終了時に出力される以下のようなログは表示されたままで,かなりの行数が使われてしまい邪魔だと感じてしまいます。

 Elapsed Time: 0.01544 seconds (Warm-up)
               0.013689 seconds (Sampling)
               0.029129 seconds (Total)

ダメな解決法2

Pythoncontextlib モジュールには redirect_stdout というコンテクストマネージャがあり,with文で指定した範囲内で stdout への出力を別のストリームに変更することができます。以下のような関数を作って stdout への出力を /dev/null に変更してみました。

@contextlib.contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as f:
        with contextlib.redirect_stderr(f):
            yield

以下のように使います。

data = ...
stan_model = ...
with suppress_stdout():
    fit = stan_model.sampling(data=data, seed=0)

この方法ではPyStanの出力は抑制できません。PyStanで呼び出されたサンプラーPythonコードの子プロセスとして起動するので,Pythonコード中の stdout が引き継がれるわけではないからです。

うまくいった解決法

以下のGitHub Issuesに「ファイルディスクリプタを書き換えれば良いよ」という書き込みがあったので試してみました。

github.com

今回は stdout への出力を抑制したいだけで stderr の出力はそのまま表示されてほしいので,処理を若干書き換えました。

@contextlib.contextmanager
def suppress_stdout():
    null_fd = os.open(os.devnull, os.O_RDWR)
    save_fd = os.dup(1)
    os.dup2(null_fd, 1)
    yield
    os.dup2(save_fd, 1)
    os.close(null_fd)
    os.close(save_fd)

以下のように使います。

data = ...
stan_model = ...
with suppress_stdout():
    fit = stan_model.sampling(data=data, seed=0)

この方法でPyStanの出力を抑制することができました。