PythonでEMアルゴリズムを可視化してみた(混合正規分布)

可視化してみました。

f:id:kujira16:20151206210030g:plain

解説

以下の動画が一番分かりやすいです。

www.youtube.com

Q関数を最大化する\lambda, \mu, \sigmaの求め方は,実際に解いてみると,本当にラグランジュの未定乗数法を使うだけという感じでした。ただし \sum_h \lambda_h=1 という制約を組み込むのを忘れずに…

\sigma_hについて求めるところはちょっと迷ったので,導出を載せておきます。

{
\displaystyle
\begin{align}
&\frac{\partial}{\partial\sigma_h}\log P(x_i|h,\theta) \\
=&\frac{\partial}{\partial\sigma_h}\log \left(\frac{1}{\sqrt{2\pi\sigma_h^2}}\exp\left(-\frac{(x_i-\mu_h)^2}{2\sigma_h^2}\right)\right) \\
=&\frac{\partial}{\partial\sigma_h}\left(-\log\sqrt{2\pi\sigma_h^2}-\frac{(x_i-\mu_h)^2}{2\sigma_h^2}\right) \\
=&-\frac{1}{\sigma_h}+\frac{(x_i-\mu_h)^2}{\sigma_h^3} \\
=&\frac{-\sigma_h^2+(x_i-\mu_h)^2}{\sigma_h^3}
\end{align}
}
{
\displaystyle
\begin{align}
&\frac{\partial}{\partial\sigma_h}L(\theta,\alpha)\\
=&\sum_{i=1}^N\bar{P}(h|x_i;\theta')\frac{\partial}{\partial\sigma_h}\log P(x_i|h,\theta)\\
=&\sum_{i=1}^N\bar{P}(h|x_i;\theta')\frac{-\sigma_h^2+(x_i-\mu_h)^2}{\sigma_h^3}=0
\end{align}
}
{\displaystyle
\sigma_h^2\sum_{i=1}^N\bar{P}(h|x_i;\theta')=\sum_{i=1}^N\bar{P}(h|x_i;\theta')(x_i-\mu_h)^2
}
{\displaystyle
\sigma_h^2=\frac{\sum_{i=1}^N\bar{P}(h|x_i;\theta')(x_i-\mu_h)^2}{\sum_{i=1}^N\bar{P}(h|x_i;\theta')}
}

ソースコード

Pythonで書きました。movieディレクトリにpngを連番で出力します。アニメーションGIFを作るには以下のサイトを参照してください。

アニメーションGIFの作り方 ImageMagick編

hに関してのforループが残っていてnumpy力の低さが感じられますが,自信がある人は消してみてください。

# coding: UTF-8
import os
import shutil
import scipy
import scipy.stats
import matplotlib.pyplot as plt


def make_gmm(k, size):
    # ランダム生成だと良い感じの分布ができなかった
    # r = scipy.stats.dirichlet.rvs([1]*k).flatten()
    r = [0.35, 0.65]
    # mu = scipy.stats.norm.rvs(size=k)
    mu = [-0.3, 0.5]
    # sigma = scipy.sqrt(scipy.stats.invgamma.rvs(1, size=k))
    sigma = [0.15, 0.3]
    X = []
    for h in range(k):
        add = scipy.stats.norm.rvs(
            loc=mu[h],
            scale=sigma[h],
            size=int(size * r[h]))
        X.append(add)
    X = scipy.concatenate(X)
    return X


def estimate_gmm(X, k):
    def e_step(lamb, mu, sigma):
        '''
        Returns
        -------
        p_bar : shape[k, n_samples]
        '''
        p_bar = scipy.empty((k, len(X)))
        for h in range(k):
            norm = scipy.stats.norm(loc=mu[h], scale=sigma[h])
            p_bar[h, :] = lamb[h] * norm.pdf(X)
        p_bar /= p_bar.sum(axis=0)
        return p_bar

    def m_step(p_bar):
        '''
        Returns
        -------
        lamb  : shape[k]
        mu    : shape[k]
        sigma : shape[k]
        '''
        def next_lamb():
            '''
            Returns
            -------
            lamb : shape[k]
            '''
            lamb = p_bar.sum(axis=1)
            lamb /= lamb.sum()
            return lamb

        def next_mu():
            '''
            Returns
            -------
            mu : shape[k]
            '''
            mu = scipy.empty((k, ))
            for h in range(k):
                numer = scipy.dot(p_bar[h, :], X)
                denom = p_bar[h, :].sum()
                mu[h] = numer / denom
            return mu

        def next_sigma(mu):
            '''
            Returns
            -------
            sigma : shape[k]
            '''
            sigma = scipy.empty((k, ))
            for h in range(k):
                numer = scipy.dot(p_bar[h, :], (X - mu[h]) ** 2)
                denom = p_bar[h, :].sum()
                sigma[h] = scipy.sqrt(numer / denom)
            return sigma

        lamb = next_lamb()
        mu = next_mu()
        sigma = next_sigma(mu)
        return lamb, mu, sigma

    def log_likelihood(lamb, mu, sigma):
        rvs = []
        for h in range(k):
            rvs.append(scipy.stats.norm(loc=mu[h], scale=sigma[h]))
        l = scipy.empty((k, len(X)))
        for h in range(k):
            l[h, :] = lamb[h] * rvs[h].pdf(X)
        return scipy.log(l.sum(axis=0)).sum()

    # 初期値
    lamb = scipy.ones((k, )) / k
    mu = scipy.stats.uniform.rvs(loc=-1, scale=2, size=k)  # [-1, 1]
    sigma = scipy.sqrt(scipy.stats.invgamma.rvs(1, size=k))
    l_prev = log_likelihood(lamb, mu, sigma)
    while True:
        p_bar = e_step(lamb, mu, sigma)
        lamb, mu, sigma = m_step(p_bar)
        l = log_likelihood(lamb, mu, sigma)
        yield lamb, mu, sigma, l
        if scipy.absolute(l_prev - l) < 0.001:
            break
        l_prev = l


def main():
    # クラスタ数
    k = 2
    # データ数
    n_samples = 1000
    X = make_gmm(k, n_samples)

    outdir = 'movie'
    if os.path.isdir(outdir):
        shutil.rmtree(outdir)
    os.mkdir(outdir)

    i = 0
    for lamb, mu, sigma, l in estimate_gmm(X, k):
        plt.clf()
        plt.xlim(X.min(), X.max())
        plt.hist(X, normed=True, facecolor='none', bins=20)
        for h in range(k):
            xs = scipy.linspace(X.min(), X.max())
            ys = lamb[h] * scipy.stats.norm.pdf(xs, loc=mu[h], scale=sigma[h])
            plt.plot(xs, ys)
        plt.suptitle('log Likelihood = {:.2f}'.format(l))
        filename = os.path.join(outdir, '{:0>3}.png'.format(i))
        i += 1
        plt.savefig(filename)

if __name__ == '__main__':
    main()

追記

GMM以外での典型的な適用例を知りたい場合は以下のサイトがおすすめです。

yamaguchiyuto.gitbooks.io

参考文献

言語処理のための機械学習入門 (自然言語処理シリーズ)

言語処理のための機械学習入門 (自然言語処理シリーズ)