読者です 読者をやめる 読者になる 読者になる

言語モデルの準備

自然言語処理 C++

自然言語処理特論で使うやつの準備の準備くらいの。

準備として青空文庫のテキストを食わせる。で、何か入力の文を与えるとUnigram、Bigram、Trigramの言語モデルでのその文が出てくる確率を計算する。確率は非常に小さくなるので、出力するところでは対数を取っている。

ここでは、「首をひねって考えた」というそれっぽい文を入力として与えている。三連鎖であるところのTrigramが一番確率として大きくなっている。正しい文に高い確率を付与している、という意味ではTrigram > Bigram > Unigramな感じであるように思うが、「私は首をひねって考えた」のような分にしてしまうと、Trigramではマイナス無限大に飛んでしまうことがある。いわゆる0頻度問題。性能が高い*1モデルは頑健性がない、ということがなんとなく分かる実験ができた。

/Users/syou6162/cpp% g++ -Wall ngram.cpp `mecab-config --cflags` `mecab-config --libs` -lboost_system-mt -lboost_filesystem-mt
/Users/syou6162/cpp% ./a.out
今日 : 1 / 4
今日は : 1 / 1
今日は晴 : 1 / 1
==================================================
首 (268/1126629)
を (28470/1126629)
ひねっ (6/1126629)
て (31981/1126629)
考え (350/1126629)
た (30489/1126629)
-39.4131
首, を (163/268)
を, ひねっ (4/28470)
ひねっ, て (4/6)
て, 考え (16/31981)
考え, た (65/350)
-19.0569
首, を, ひねっ (4/163)
を, ひねっ, て (3/4)
ひねっ, て, 考え (1/4)
て, 考え, た (1/16)
-8.15402

とりえあず、ゼロ頻度を解決すべく前回の授業でやったスムージングを使ってもうちょっとましにしようと思います。

以下コード。

#include <iostream>
#include <string>
#include <math.h>
#include <vector>
#include <map>
#include <mecab.h>
#include <boost/foreach.hpp>
#include <boost/tr1/tuple.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/filesystem/fstream.hpp>
#include <boost/filesystem/operations.hpp>

using namespace std;
using namespace std::tr1;
using namespace boost;
using namespace boost::filesystem;

vector<string> str2vec(const string str) {
  // stringを形態素解析して分解する
  vector<string> result;
  MeCab::Tagger *tagger = MeCab::createTagger("-O wakati");
  const MeCab::Node *node = tagger->parseToNode( str.c_str() );
  for( node=node->next; node->next; node=node->next ){ 
	// 形態素解析されて出てきた文字列(品詞の情報ではなく)
	char *s = new char [node->length + 1];
	strncpy(s, node->surface, node->length);
	s[node->length] = '\0';
	result.push_back(s);
  } 	
  delete tagger;
  return result;
}

template <class T>
class Ngram {
public:
  void add(const T t);
  int denominator(const T t); // 分母
  int numerator(const T t); // 分子
private:
};

class Unigram : public Ngram<string> {
public:
  Unigram() {
	this->N = 0;
  };
  void add(const string text) {
	BOOST_FOREACH(string s, str2vec(text)) {
	  C_w[s]++;
	  N++;
	}
  };
  int denominator() {
	return N;
  };
  int numerator(string s) {
	return C_w[s];
  }
private:
  map<string, int> C_w; // C(w)
  int N; 
};

class Bigram : public Ngram<pair<string, string> > {
public:
  void add(const string text) {
	vector<string> v = str2vec(text);
	pair<string, string> p;
	BOOST_FOREACH(p , convert(v)) {
	  C_b[p.first]++;
	  C_b_w[p]++;
	}
  };
  vector<pair<string, string> > convert(const vector<string> v) { // bigramなvectorに変換
	vector<pair<string, string> > result;
	int length = v.size();
	if (length <= 1) return result;
	for (int i = 0; i < length - 1; i++) {
	  result.push_back(make_pair(v.at(i), v.at(i + 1)));
	}
	return result;
  };
  int denominator(string s) {
	return C_b[s];
  };
  int numerator(pair<string, string> p) {
	return C_b_w[p];
  }
private:
  map<pair<string, string>, int> C_b_w; // C(b, w)
  map<string, int> C_b; // C(b)
};

class Trigram : public Ngram<tuple<string, string, string> > {
public:
  void add(string text) {
	vector<string> v = str2vec(text);
	tuple<string, string, string> t;
	BOOST_FOREACH(t , convert(v)) {
	  pair<string, string> tmp = make_pair(get<0>(t), get<1>(t));
	  C_a_b[tmp]++;
	  C_a_b_w[t]++;
	}
  };
  vector<tuple<string, string, string> > convert(const vector<string> v) {
	vector<tuple<string, string, string> > result;
	int length = v.size();
	if (length <= 2) return result;
	for (int i = 0; i < length - 2; i++) {
	  result.push_back(make_tuple(v.at(i), v.at(i + 1), v.at(i + 2)));
	}
	return result;
  }
  int denominator(pair<string, string> p) {
	return C_a_b[p];
  };
  int numerator(tuple<string, string, string> t) {
	return C_a_b_w[t];
  }
private:
  map<tuple<string, string, string>, int> C_a_b_w; // C(a, b, w)
  map<pair<string, string>, int> C_a_b; // C(a, b)
};

int main(int argc, char *argv[]) {
  setlocale(LC_CTYPE, ""); // wcharの付近の変換で必要
  Unigram unigram;
  unigram.add("今日は晴だ");
  cout << "今日 : " << unigram.numerator("今日") 
	   << " / " << unigram.denominator() << endl;

  Bigram bigram;
  bigram.add("今日は晴だ");
  cout << "今日は : " << bigram.numerator(make_pair("今日", "は")) 
	   << " / " << bigram.denominator("今日") << endl;

  Trigram trigram;
  trigram.add("今日は晴だ");
  cout << "今日は晴 : " << trigram.numerator(make_tuple("今日", "は", "晴")) 
	   << " / " << trigram.denominator(make_pair("今日", "は")) << endl;
  
  cout << "==================================================" << endl;
  
  std::string dir = "/Users/syou6162/dbcls/aozora/";
  path fullPath = complete(path(dir, native));
  directory_iterator end;
  int i = 0;
  for (directory_iterator it(fullPath); it !=end; ++it) {
	//	cout << i << " : " << (dir + it->leaf()) << endl;
	std::ifstream fis((dir + it->leaf()).c_str());
	string line;
	while(getline(fis, line)) {
	  unigram.add(line);
	  bigram.add(line);
	  trigram.add(line);
	}
	i++;
	if (i > 100) {
	  break;
	}
  }

  string text = "首をひねって考えた";

  double prob = 0.0;
  BOOST_FOREACH(string s, str2vec(text)) {
	prob += log((double) unigram.numerator(s) / (double) unigram.denominator());
	cout << s
		 << " (" << unigram.numerator(s) << "/" <<  unigram.denominator() << ")"
		 << endl;
  }
  cout << prob << endl;

  prob = 0.0;
  pair<string, string> p;
  BOOST_FOREACH(p, bigram.convert(str2vec(text))) {
	prob += log((double) bigram.numerator(p) / (double) bigram.denominator(p.first));
	cout << p.first << ", " << p.second 
		 << " (" << bigram.numerator(p) << "/" <<  bigram.denominator(p.first) << ")"
		 << endl;
  }
  cout << prob << endl;

  prob = 0.0;
  tuple<string, string, string> t;
  BOOST_FOREACH(t, trigram.convert(str2vec(text))) {
	prob += log((double) trigram.numerator(t) / 
				(double) trigram.denominator(make_pair(get<0>(t), get<1>(t))));
	cout << get<0>(t) << ", " << get<1>(t) << ", " << get<2>(t)
		 << " (" << trigram.numerator(t) << "/" 
		 <<  trigram.denominator(make_pair(get<0>(t), get<1>(t))) << ")"
		 << endl;
  }
  cout << prob << endl;
  return 0;
}

*1:ここでは、パープレキシティも使っていないので、性能もへったくれもないのだが