ちょっと前に実装してたんだけど、メモを書くがてら公開してみる。やりたいこととしてはnested Chinse Restaurant Processまで行きたいんだけど、ノンパラベイズ初心者なので一番取りかかりやすいであろうDirichlet Process Mixture(DPM)を文書モデルでやってみたという感じです。HDPではなくDPMとしてコーパスをモデル化するので、いくつ文書があろうがそれらは全部一つの文書として取り扱います(というか、そういう形でしか取り扱えません。扱いたかったらHDPの世界へ行こう)。
目的
やってみようと思った理由はいくつかあって
- 実際に自分で把握できるミニマムな*1ノンパラベイズのプログラムをgetする
- 実際に書いてみることでノンパラベイズのプログラムではまりやすいところを知る
- コード書く段階までやってみないと分かったつもりになっていることが多い
などなどです。実際に動かしてみるとサンプリングとかよりクラスタ削除した際のラベルの付け換えに一番時間くったりしているっぽいけど、とりあえず動いているのでまあいいことにしておこう(ぇ。
DPMの簡単な復習
DPMのグラフィカルモデルは以下のように書き表わすことができる。
Hが基底分布で、今回は文書を取り扱うのでHは具体的ディリクレ分布となる。は集中度パラメータで、こいつが大きいほど新しいクラスタができやすくなる。今回コードを書いて実験したのはこのパラメータの影響度が実際にどんなもんかを身を持って確かめたかったというのもある。(連続分布である)Hから生成した分布Gは離散分布となる。Gが離散分布となることからサンプル毎にパラメータを用意しておいても実際には同じ値が使われやすいような効果が起きる。
モデルとしてはこれで何も不足していないのだが、実際にGibbs Samplingをするときには補助変数を導入したり、基底分布を具体的に書き換えたほうがやりやすいことが多いので書き換える。
文書の区切りがないようなLDAでトピック数は無限大といった感じのイメージと言えば伝わりやすいだろうか。ベイズなので、パラメータとについて積分消去したものを考える。LDA Gibbsのときと同様に考えて、サンプリングの式は以下の2つのfactorに分解して考えることができる。
2つのfactorを、文書の場合について書きくだしてみたのがこれ。新しいクラスタができるところも実際に書きくだしてみるとそんなに怖くない、ということが分かる。
対数尤度
Gibbs Samplingの最中は対数尤度がある程度単調に増加していくことを確かめたいので、その式の導出。まず、事後分布は
というように2つのfactorに分解することができる。最初のfactorについてはLDAの要領で、以下のポリア分布が出てくる形で書ける。
ここで、はクラスタkに属している単語wの数。次に第二項だが、ここはLDAのように行かない。CRPの定義を生かすために隠れ変数のN個の同時分布をchain ruleを使ってばらすと以下のように書き表わすことができる。
ここで、はクラスタkに属している単語の数。プログラムを書くときにはこの対数のものが分かっていると便利なのでメモ。階乗とかが出てくるので3秒くらい焦るが普通にやればできる。
テストセットパープレキシティ
新しいデータが新しいクラスタに所属するかそうでないかのどっちにすればいいかについて結構悩むのだが、Bleiの"Variational methods for the Dirichlet Process"とかを見ていると"note that the next component can take on K + 1 possible values"と書いてあるので新しいデータがK+1番目のクラスタを作ることを考慮してテストセットパープレキシティを導出してみる。
テストセットパープレキシティの定義式は
であるが、これを隠れ変数を周辺化したものだと考えると
と書ける(p(z)のところがちょっと自身ない...)。
実験
コードはgithubに上げてあります(DPMなのにdirichlet_processというプロジェクトの名前のままだった...)。
ややこしいですが、このエントリでのハイパーパラメータとコードでのハイパーパラメータがちょっと違ってて
- エントリ => コード
- =>
- =>
という風になっているので注意。プロの方々はハイパーパラメータも最適化するらしいですが、今回はそれは置いとく。
実験の結果では、iteration毎の
- クラスタの数
- 対数尤度
- テストセットパープレキシティ
を表示しています。クラスタ数は大体のところで頭打ちになっていて、対数尤度も途中から安定してるので多分当ってて、テストセットパープレキシティもほどほどな値になっている。ということできっと動いているんでしょう、たぶん。
ただ、バグっているのかそもそもこういうものなのかよく分かっていないところが一点。p(w|z)で各クラスタで特有な単語を列挙してみた際にそれぞれのクラスタ間であまち違いが出ていないという結果になってしまう。なんでだろー。
*1:DPを使ったコードでいくつか公開されているものがあるけど、DPM以上に複雑なものが多いので、把握しずらいことが多かった。あとmatlabで公開されているものも多かったが、C++での例が欲しかった