半教師あり学習でollを使いたいので、ollをライブラリとして使ってみる

バイトにて、岡野原さんが作られているオンライン学習をサポートした機械学習ライブラリ「oll」をめちゃくちゃ使わせてもらっている。

自然言語処理のような大規模かつスパースな例だと、SVMよりめちゃくちゃ早い&同精度ということは昔書いた。

ollにはoll_trainとoll_testという*1プログラムが付属しており、大抵の場合はコンパイルしたあとに、これに投げれば解決する。

半教師あり学習で使いたい

バイト@DBCLSでライフサイエンス分野の専門用語辞書を作ろうとしているわけですが、自分で書いているテキストから特徴ベクトルを作り出すプログラムがようやくまともに動き初めてきた*2。で、その作った素性をollに投げている。しかしながら、テストデータ数 / 学習データ数 = 100くらい差があり、圧倒的にラベル付きデータが足りない…!!まあ、足りないからやるわけですが。機械学習やっている人には半教師あり学習なアプローチが使えるだろうということが容易に想像されると思うので、ちょっとそのアプローチを試してみたい。半教師あり学習については、以下が分かりやすい*3

「分かりやすい」といいながら、ほとんど分かってないのだが、その中でも簡単そうな方法として"Self-Training"という手法があった。スライドより引用すると

  • A classifier is trained with a small amount of labeled data
  • The classifier is then used to classify the unlabeled data
  • Typically the most confident unlabeled points, along with the predicted labels are incorporated into the training set
  • The classifier is re-trained and the procedure is repeated

となっていて、日本語にしてみると

  • 少数のラベル付きデータで学習に学習させる
  • その学習機でラベルなしデータを識別させる
  • 識別されたラベルなしデータの中で(確率値とかでもって)信頼性が高いなと思えるものについて、識別されたラベル(識別の問題だったら1とか-1)を付与、学習データに混ぜる
  • 混ぜて増えた学習データを使って、また学習機を学習させる!

という感じ。難しいことは抜きにして何がうれしいかというと「自分で新しくimplementationする必要がなく、既存のソフトウェアのみで簡単にやれる」ということである。ということで、ollにラベルなしデータを識別させて、(閾値を越えるような|上位、下位何%の)データに対して、ラベルを振って、また学習というのをやればできそうである!俺でもできる!!

何回も繰り返し学習させるということでollをライブラリ的に使えるとうれしいだろうなと思ったので、oll_trainとかoll_testを参考にしながらC++でちょっと書いてみた(半教師ありまではまだ書いてない→書きました)。

プログラム

#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <map>
#include <oll/oll.hpp>

int main(int argc, char *argv[]) {
  typedef std::vector<std::pair<int, float> > fv_t; // feature vector
  typedef std::vector<float> fvec;

  std::ifstream training_file("./training.txt");

  size_t lineN = 0;
  std::string line;
  std::vector<std::pair<fv_t, int> > examples;

  oll_tool::oll oll;
	
  while (getline(training_file, line)){
	lineN++;
	
	fv_t fv;
	int y = 0;

	oll.parseLine(line, fv, y);
	examples.push_back(std::make_pair(fv, y));
  }

  oll_tool::PA1_s a;

  std::random_shuffle(examples.begin(), examples.end());
  for (size_t j = 0; j < examples.size(); j++){
	oll.trainExample(a, examples[j].first, examples[j].second);
  }

  std::ifstream test_file("./test.txt");
  while (getline(test_file, line)){
	fv_t fv;
	int y = 0;
	oll.parseLine(line, fv, y);
	std::cout << "value : " << oll.classify(fv) << std::endl;
  }  

  return 0;
};

追記

半教師ありっぽくしてみた。

#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <map>
#include <oll/oll.hpp>

class OLLWrapper {
  typedef std::vector<std::pair<int, float> > fv_t; // feature vector
  typedef std::vector<float> fvec;
public:
  OLLWrapper() {};
  void add(const std::string& line) {
	fv_t fv;
	int y = 0;
	oll.parseLine(line, fv, y);
	this->do_add(fv, y);
	examples.push_back(std::make_pair(fv, y));
  };
  void train() {
	std::random_shuffle(examples.begin(), examples.end());
	for (size_t j = 0; j < examples.size(); j++){
	  oll.trainExample(a, examples[j].first, examples[j].second);
	}
  };
  float classify(const std::string& line) {
	fv_t fv;
	int y = 0;
	oll.parseLine(line, fv, y);
	return oll.classify(fv);
  };
  void add_unlabeled_data(const std::string& line, int y) {
	fv_t fv;
	int tmp = 0; // ダミー
	oll.parseLine(line, fv, tmp);
	this->do_add(fv, y);
	examples.push_back(std::make_pair(fv, y));
  };
private:
  std::vector<std::pair<fv_t, int> > examples;
  oll_tool::oll oll;
  oll_tool::PA1_s a;
  void do_add(fv_t& fv, int& y) {
	examples.push_back(std::make_pair(fv, y));
  };
};

int main(int argc, char *argv[]) {
  std::ifstream training_file("./training.txt");

  oll_tool::oll oll;
  OLLWrapper wrapper;

  std::string line;
  while (getline(training_file, line)){
	wrapper.add(line);
  }

  wrapper.train();

  std::ifstream test_file("./unknown.txt");
  std::vector<std::string> v;

  while (getline(test_file, line)){
	v.push_back(line);
  }

  for (int i = 0; i < 10; i++) {
	std::cerr << "iteration : " << i << std::endl;
	for (std::vector<std::string>::iterator it = v.begin(); it != v.end(); it++){
	  std::string line = *it;
	  float result = wrapper.classify(line);
	  if (result > 3.5) {
		wrapper.add_unlabeled_data(line, 1);
	  } else if (result < - 3.5) {
		wrapper.add_unlabeled_data(line, -1);
	  }
	}
  }

  for (std::vector<std::string>::iterator it = v.begin(); it != v.end(); it++){
	std::string line = *it;
	float result = wrapper.classify(line);
	std::cout << result << std::endl;
  }
  return 0;
};

Effective Modern C++: 42 Specific Ways to Improve Your Use of C++11 and C++14

Effective Modern C++: 42 Specific Ways to Improve Your Use of C++11 and C++14

*1:oll_lineというのもあるが

*2:どれだけ時間かかってるんだ、、、

*3:こういうのって、Tutorialとかつけてぐぐると分かりやすいの出てきますよね。余談