logsumexpを使って乱数生成 + X^2検定

空港で暇にしていたので書いてみる(出発時間が遅れた)。

takanori-i君がベイジアンHMMを作っているらしく、相談に乗る。確率の積が入ってきて、数値計算で死んでしまうときがあるとのことだったので、logsumexpについて教える。logsumexpについては高村本が分かりやすい。

後で他の人にも教えることがありそうな気がしてきたので、コードがあると教えやすい。ということでメモがてら残しておく。0を1/15の確率で、1を2/15の確率で...4を5/15の確率でという具合の多項分布から生成してくるという単純なもの。Samplerクラスというのがlogsumexpしたやつをdecodeするようなクラスと考えてもらえばよい。

そのままではあまり面白くない(?)ので、生成してきたサンプルが思った通りの離散分布から生成されているかX^2検定で確かめる、というのをやっておく。実はこういう例ってあんまりWebとかでは見ない。研究のコードだとそもそもテスト書かれている例のほうが少な(ry。X^2検定等を使った乱数などのテストについてはBeatiful Testingに少し載っているので参考にするとよい。

#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <math.h>

template <class T>
class Sampler {
public:
  T sample(const std::vector<std::pair<T, double> >& posts, const double psum) {
	const double r = (double) rand() / RAND_MAX;
	T next_z = posts[0].first;
	double prev = exp(posts[0].second - psum);
	if (r > prev) {
	  for (unsigned int i = 1; i < posts.size(); i++) {
		prev = exp(posts[i].second - psum) + prev;
		if (r <= prev) {
		  next_z = posts[i].first;
		  break;
		}
	  }
	}
	return next_z;
  };
};

double logsumexp (double x, double y, bool flg) {
  if (flg) return y; // init mode
  if (x == y) return x + 0.69314718055;  // log(2)
  double vmin = std::min(x, y);

  double vmax = std::max(x, y);
  if (vmax > vmin + 50) {
    return vmax;
  } else {
    return vmax + log (exp (vmin - vmax) + 1.0);
  }
};

int main(int argc, char** argv) {
  std::vector<std::pair<int, double> > v;
  v.push_back(std::make_pair(0, 1.0));
  v.push_back(std::make_pair(1, 2.0));
  v.push_back(std::make_pair(2, 3.0));
  v.push_back(std::make_pair(3, 4.0));
  v.push_back(std::make_pair(4, 5.0));
  std::vector<std::pair<int, double> > posts(v.size());
  double psum = 0.0;
  for(unsigned int i = 0; i < v.size(); i++) {
	posts[i].first = v[i].first;
	posts[i].second = log(v[i].second);
	psum = logsumexp(psum, posts[i].second, (i == 0));
  }

  Sampler<int> sampler;
  std::vector<int> table(v.size(), 0);
  const int N = 100000;
  for (int i = 0; i < N; i++) {
	table[sampler.sample(posts, psum)]++;
  }

  double chi_square = 0.0;
  for (unsigned int k = 0; k < v.size(); k++) {
	const int Ok = table[k];
	const double Ek = (double) N * v[k].second / 15.0; // (1 + 2 + 3 + 4 + 5 = 15)
	chi_square += (Ok - Ek) * (Ok - Ek) / Ek;
  }
  // Chi_square_test with K &#8211; 1 degrees of freedom
  // http://kusuri-jouhou.com/statistics/bunpuhyou.html
  if (chi_square < 9.488) {
	std::cout << "成功!!" << std::endl;
  } else {
	std::cout << "失敗!!" << std::endl;
  }

  return 0;
}

言語処理のための機械学習入門 (自然言語処理シリーズ)

言語処理のための機械学習入門 (自然言語処理シリーズ)

Beautiful Testing

Beautiful Testing