読者です 読者をやめる 読者になる 読者になる

初めてのEMアルゴリズム with R

R 機械学習

この前はEMアルゴリズムがどんな感じのメカニズムで、どんな性質を持っているか簡単に書いた。

というわけで、ちょちょいとRで書いてみることにした。お題はありがちな混合正規分布。

混合正規分布について

確率変数x_1\mathcal{N}(x_1|-5, 1)x_2\mathcal{N}(x_2|5, 16)で、それぞれ0.3、0.7で生成されているというような分布が真の分布だとしよう。図で書くとこんな感じの密度関数である。
f:id:syou6162:20161010132345p:plain
図を書くためのRのスクリプト。

mixture_gaussian <- function(x) {
  pi_0 <- 0.3
  ifelse(runif(1) < pi_0, rnorm(1, -5, 1), rnorm(1, 5, 4))
}

N <- 1000
x <- sapply(1:N, mixture_gaussian)
plot(density(x)) # ここでdensityを使っちゃ本当はだめだけど。。。

混合正規分布のEMアルゴリズムによるパラメータ推定

この分布を混合正規分布によって推定しよう。求めるパラメータとしては、二つのガウス分布の平均、分散パラメータと混合係数である。EMアルゴリズムを使ったupdateの式はQ関数を微分して0と置いたものから導出することができる。混合ガウスでのupdateの式はパターン認識と機械学習 下 - ベイズ理論による統計的予測のp154に載っているので、今回はそれを使う。

  • Eステップでデータが与えられたもとでのk番目のガウス分布が選択される確率(=事後確率)、つまり負担率を計算する
  • Mステップでは、その負担率が最大になるように、updateしていく

そして、対数尤度の変化を見て、収束しているようであれば打ち切るようにしている。まあ、なんというかそのままRに落としてあげるだけで簡単に計算できる。

log_likelihood <- function(x, mu, sigma, pi) {
  sum(log(pi[1] * dnorm(x, mu[1], sqrt(sigma[1])) +
          pi[2] * dnorm(x, mu[2], sqrt(sigma[2]))))
}

mu <- c(-1, -1)
sigma <- c(1, 2)
pi <- c(0.5, 0.5)
gamma_0 <- c()
gamma_1 <- c()
n_k <- c()

log_likelihood_history <- c()

for(step in 1:1000) {
  old_log_likelihood <- log_likelihood(x, mu, sigma, pi)
  log_likelihood_history <- c(log_likelihood_history, old_log_likelihood)
  
  # E-step
  # gamma_0はクラス0の混合係数、gamma_1はクラス1の混合係数
  gamma_1 <- pi[1] * dnorm(x, mu[1], sqrt(sigma[1])) /
    (pi[1] * dnorm(x, mu[1], sqrt(sigma[1])) + pi[2] * dnorm(x, mu[2], sqrt(sigma[2])))
  gamma_2 <- 1 - gamma_1
  
  # M-step
  n_k[1] <- sum(gamma_1)
  n_k[2] <- sum(gamma_2)
  mu[1] <-  sum(gamma_1 * x) / n_k[1]
  mu[2] <-  sum(gamma_2 * x) / n_k[2]
  sigma[1] <- sum(gamma_1 * (x - mu[1])^2) / n_k[1]
  sigma[2] <- sum(gamma_2 * (x - mu[2])^2) / n_k[2]
  pi[1] <- n_k[1] / N
  pi[2] <- 1 - pi[1]
  if(abs(log_likelihood(x, mu, sigma, pi) - old_log_likelihood) < 0.001){
    break    
  }
}

こうしてEMアルゴリズムで3種類のパラメータが推定されたわけだが、そのパラメータの値は次のようになった。まあ、大体あってそうですね。

> mu
[1] -5.071632  5.132440
> sigma
[1]  0.9019544 16.1150782
> pi
[1] 0.2999747 0.7000253

EMアルゴリズムの単調増加性について

さて、EMアルゴリズムの性質として、ステップごとに対数尤度関数が増加、または変化しない(=単調非減少)という性質がある。これにより、EMアルゴリズムは局所的最適解に収束するという性質を持っている(凸関数なら大域的最適解に収束!!)。ということでステップ数と対数尤度関数をplotして、減少したりしていないか確認してみる(たぶんEMアルゴリズムのプログラムが妥当な感じかはこういう感じで確認できるはず)。
f:id:syou6162:20161010132438p:plain

max_log_likelihood <- log_likelihood(x, c(-5, 5), c(1, 16), c(0.3, 0.7))
plot(log_likelihood_history, pch=1, type="b", ylab="対数尤度")
abline(max_log_likelihood, 0, col="red", lwd=3)

結構早い感じで、最大のところに近づいているみたいですね(一般には遅いですが)。共役勾配法やニュートン法だと(目的関数が2次関数でなければ)、探索方向がよい方向を向いているとは限らないので、ちゃんと山を登っていけるEMアルゴリズムはそういう意味で安心して使うことができる。

というわけで、混合ガウスのパラメータをEMアルゴリズムで推定する話はおしまい。

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

  • 作者: C.M.ビショップ,元田浩,栗田多喜夫,樋口知之,松本裕治,村田昇
  • 出版社/メーカー: 丸善出版
  • 発売日: 2012/02/29
  • メディア: 単行本
  • 購入: 6人 クリック: 14回
  • この商品を含むブログを見る