読者です 読者をやめる 読者になる 読者になる

A*アルゴリズムについて整理

アルゴリズム C++

辞書を参考にしながら。NAISTのI期辺りでやったはずなんだが、かなりすっ飛んでいる。

デジタル人工知能学事典 [CD-ROM付]

デジタル人工知能学事典 [CD-ROM付]

A*アルゴリズムとは?

グラフ探索アルゴリズムの一つ。「開始ノードから現在位置に至るまでのコスト」と「現在位置からゴールまでの推定コスト」の2つのコストを用いてadmissibleな条件(後述)の元でコストが最小であるような経路を効率的に見つけることができるアルゴリズムである。1960年代に開発されたアルゴリズムであるが、50年経った今でもばしばし使われている。

現在いるノードをp、開始ノードからpまでの最小コストをg(p)、pからゴールまでの最小コストをh(p)と書くとすればpを経由して開始ノードからゴールに向かう最小コストは
f(p) = g(p) + h(p)
と書ける。が、h(p)はまだたどってないから分からん、ということでその推定コスト\hat{h}(p)を用いることにすると最小コストの推定値は
\hat{f}(p) = g(p) + \hat{h}(p)
となる。A*アルゴリズムはこの\hat{f}(p)が最小になるようなノードpを展開していくようなアルゴリズムである。ただし、\hat{h}(p)はなんでもいいってわけじゃなくって\hat{h}(p) \leq h(p)でないと最適なものが見つかるかは保証できない。この条件をadmissiblityという。

Admissiblityについて

Admissiblityのような条件なしには「必ず」最小コストのものを見つけるというのは厳しい、というのは直感的に想像がつく。が、なぜ\hat{h}(p) \leq h(p)の条件なのか?については触れておく必要があるだろう。デジタル人工知能学事典 [CD-ROM付]の25pに簡単な証明が載っているのでそれを参考にする。

p_kをゴールとし、開始ノードsからp_kまでの経路を(s, p_1, \cdots, p_k)とするA*がこの経路のうち(s, p_1, \cdots, p_i)まで展開していたとする。このとき、次のノードp_{i+1}も必ず展開されることを確認すればよい。ノードqを最適経路に含まれていないノードだとすると
\hat{f}(p_{i+1}) = g(p_{i+1}) + \hat{h}(p_{i+1}) \leq g(p_{i+1}) + h(p_{i+1}) < g(q)
となり、従ってp_{i+1}よりも先にqが選択されて、探索が終了することはない。ゆえに、すべてのノードpで\hat{h}(p) \leq h(p)ならば、常に最短路経由のゴールを見い出して終了する。

具体的なアルゴリズム

  1. 開始ノードsをOPENに入れる。その最適コストの推定値は\hat{f}(s) = \hat{h}(s)とする。
  2. OPENが空なら目標ノードは見つからず、探索は失敗。終了。
  3. OPENから\hat{f}(p)が最小となるノードpを取り出してCLOSEに移す。
  4. pが目標ノードなら探索成功。ゴールからポインタを開始ノードまでたどることで解の経路を得て終了。
  5. pを展開し、すべての継続ノードを生成し、各々の\hat{f}値を計算する。
  6. 継続ノードのうちOPEN、CLOSEいずれにも含まれていないものをOPENに入れる。そのノードからpへポインタを設定する。
  7. 継続ノードでOPEN、CLOSEのいずれかすでに含まれているものは、\hat{f}の旧値と5で計算した\hat{f}の新値とを比較し、新値が小さければ\hat{f}を新値に更新して、その継続ノードのポインタを新たな親ノードpに付け替える。継続ノードがCLOSEに含まれていて\hat{f}の新値が旧値より小になったならば、その継続ノードをOPENに移す。
  8. 2へ。

C++によるA*アルゴリズムの実装

人工知能論辺りで実装したC++のコードが手元にあったので参考のために載せておく。

#include <iostream>
#include <vector>
#include <boost/bind.hpp>
#include <tr1/memory>

class Node {
public:
  int x, y, cost; // costはこのノードにくるまでの最短パス上での累積和(f'(n))
  std::tr1::shared_ptr<Node> parent;
  Node(int x_, int y_, int cost_) {
    x = x_; y = y_; cost = cost_;
  };
};

class A_Star {
  typedef std::tr1::shared_ptr<Node> NodePointer;
  typedef std::vector<NodePointer> List;
 public:
  int start_pos_x, start_pos_y, goal_pos_x, goal_pos_y;
  int **cost, **heuristic;
  List L1, L2;

  A_Star(std::pair<int, int> start, std::pair<int, int> goal, int **cost_,  int **heuristic_) {
    start_pos_x = start.first; start_pos_y = start.second;
    goal_pos_x = goal.first; goal_pos_y = goal.second;
    cost = cost_;
    heuristic= heuristic_;
  };

  bool Compare(const NodePointer& a, const NodePointer& b) {
    return a->cost + heuristic[a->x][a->y] < b->cost + heuristic[b->x][b->y];
  };

  bool is_goal_node(NodePointer node) {
    return node->x == goal_pos_x && node->y == goal_pos_y;
  };

  bool is_included_in(List list, NodePointer node) {
    for(List::iterator it = list.begin();it != list.end(); ++it) {
      if (node->x == (*it)->x && node->y == (*it)->y) return true;
    }
    return false;
  };

  // これまでにノードが生成されていたらそれを返し、まだ生成されていなかったら新たに生成
  NodePointer get_node(NodePointer source_node, int target_x, int target_y) {
    NodePointer node = NodePointer(new Node(target_x, target_y, 
					    source_node->cost + cost[target_x][target_y]));
	
    if (is_included_in(L1, node)) { 
      for(List::iterator it = L1.begin(); it != L1.end(); ++it) {
	if (target_x == (*it)->x && target_y == (*it)->y) return *it;
      }
    } else if (is_included_in(L2, node)) {
      for(List::iterator it = L2.begin(); it != L2.end(); ++it) {
	if (target_x == (*it)->x && target_y == (*it)->y) return *it;
      }
    } else { 
      // not generated yet...
    }
    return node;
  };

  List expand(NodePointer node) {
    List result;
    
    int x = node->x; int y = node->y; 
    // 上から時計回り(xのほうが縦、yのほうが横)
    int x1 = -1; int y1 = 0; // 上
    int x2 = 0; int y2 = 1; // 右
    int x3 = 1; int y3 = 0; // 下
    int x4 = 0; int y4 = -1; // 左

    if (0 <= node->x + x1) { // 上
      result.push_back(get_node(node, x + x1, y + y1));
    }
    if (y + y2 <= 4) { // 右
      result.push_back(get_node(node, x + x2, y + y2));
    }
    if (x + x3 <= 4) { // 下
      result.push_back(get_node(node, x + x3, y + y3));
    }
    if (0 <= y + y4) { // 左
      result.push_back(get_node(node, x + x4, y + y4));
    }
    return result;
  };

  void search() {
    // 探索の初期点をL1に入れる
    L1.push_back(NodePointer(new Node(start_pos_x, start_pos_y, 
				      cost[start_pos_x][start_pos_y])));
    while(true) {
      if (L1.empty()) {
	std::cout << "Search failed..." << std::endl;
	break;
      } else {
	// リストL1の先頭の節点nと取り除き、リストL2に入れる
	NodePointer n = L1.front();
	L2.push_back(n);
	List::iterator it = L1.begin();
	L1.erase(it);

	if (is_goal_node(n)) break;

	List expanded_nodes = expand(n);

	std::cout << "Target Node x: " << n->x << ", y: " << n->y 
		  << ", cost: " << n->cost 
		  << "(" << n->cost + heuristic[n->x][n->y] << ")" << std::endl;

	for(it = expanded_nodes.begin(); it != expanded_nodes.end(); ++it) {
	  NodePointer n_i = (*it);

	  std::cout << "\tExpanded Node x: " << n_i->x << ", y: " << n_i->y 
		    << ", cost: " << n_i->cost 
		    << "(" << n_i->cost + heuristic[n_i->x][n_i->y] << ")" << std::endl;
		
	  int f_prime = n->cost + cost[n_i->x][n_i->y];
	  if (! is_included_in(L1, n_i) && ! is_included_in(L2, n_i)) { // n_iがL1にもL2にも含まれない
	    std::cout << "\tCase 1" << std::endl;
	    n_i->cost = f_prime;
	    n_i->parent = n;
	    L1.push_back(n_i);
	  } else if(is_included_in(L1, n_i)) { // n_iがL1に含まれている
	    if (f_prime < n_i->cost) {
	      std::cout << "\tCase 2(" << f_prime << ", " << n_i->cost << ")" << std::endl;
	      n_i->cost = f_prime;
	      n_i->parent = n;
	    }
	  } else if(is_included_in(L2, n_i)) { // n_iがL2に含まれている
	    if (f_prime < n_i->cost) {
	      std::cout << "\tCase 3(" << f_prime << ", " << n_i->cost << ")" << std::endl;
	      n_i->cost = f_prime;
	      // n_iをL2から取り除く
	      int index = 0;
	      for(int i = 0; i < (int) L2.size(); ++i) {
		NodePointer tmp = L2.at(i);
		if (n_i->x == tmp->x && n_i->y == tmp->y) {
		  index = i;
		  L1.push_back(tmp);
		}
	      }
	      List::iterator tmp = L2.begin();
	      L2.erase(tmp + index);
	    }
	  } else {
	    // do nothing
	  }
	  sort(L1.begin(), L1.end(), boost::bind(&A_Star::Compare, this, _1, _2));
	}
      }
    }
  };
  void back_track() {
    List path;
    NodePointer node = L2.back();
    NodePointer parent = node->parent;
  
    path.push_back(node);
    // Goalからバックトラック
    while(parent != NULL) {
      node = parent;
      parent = node->parent;
      path.push_back(node);
    }

    reverse(path.begin(), path.end()); // Startから出力させたいので、逆向きに
    for(List::iterator it = path.begin(); it != path.end(); ++it) {
      std::cout << "(" << (*it)->x << ", " << (*it)->y << ") " << (*it)->cost << std::endl;
    }
  };
};


int main(int argc, char **argv) {
  {
	int *maze_cost_array[5];
	int maze_cost_array0[5] = {6, 5, 8, 9, 0};
	int maze_cost_array1[5] = {3, 1, 3, 9, 2};
	int maze_cost_array2[5] = {0, 2, 2, 2, 3};
	int maze_cost_array3[5] = {1, 3, 1, 2, 4};
	int maze_cost_array4[5] = {0, 0, 4, 2, 1};

	maze_cost_array[0] = maze_cost_array0;
	maze_cost_array[1] = maze_cost_array1;
	maze_cost_array[2] = maze_cost_array2;
	maze_cost_array[3] = maze_cost_array3;
	maze_cost_array[4] = maze_cost_array4;

	int *manhattanma_distance[5];
	int manhattanma_distance0[5] = {4, 3, 2, 1, 0};
	int manhattanma_distance1[5] = {5, 4, 3, 2, 1};
	int manhattanma_distance2[5] = {6, 5, 4, 3, 2};
	int manhattanma_distance3[5] = {7, 6, 5, 4, 3};
	int manhattanma_distance4[5] = {8, 7, 6, 5, 4};

	manhattanma_distance[0] = manhattanma_distance0;
	manhattanma_distance[1] = manhattanma_distance1;
	manhattanma_distance[2] = manhattanma_distance2;
	manhattanma_distance[3] = manhattanma_distance3;
	manhattanma_distance[4] = manhattanma_distance4;
  
	A_Star a_star(std::pair<int, int>(4, 0), 
				  std::pair<int, int>(0, 4), 
				  maze_cost_array, manhattanma_distance);
	a_star.search();
	a_star.back_track();
  };

  {
	int *maze_cost_array[5];
	int maze_cost_array0[5] = {0, 0, 2, 0, 0};
	int maze_cost_array1[5] = {3, 10, 3, 900, 200};
	int maze_cost_array2[5] = {10, 1, 2, 2, 3};
	int maze_cost_array3[5] = {1, 2, 1, 2, 4};
	int maze_cost_array4[5] = {0, 3, 6, 2, 1};

	maze_cost_array[0] = maze_cost_array0;
	maze_cost_array[1] = maze_cost_array1;
	maze_cost_array[2] = maze_cost_array2;
	maze_cost_array[3] = maze_cost_array3;
	maze_cost_array[4] = maze_cost_array4;

	int *manhattanma_distance[5];
	int manhattanma_distance0[5] = {4, 3, 2, 1, 0};
	int manhattanma_distance1[5] = {5, 4, 3, 2, 100};
	int manhattanma_distance2[5] = {11, 6, 50, 4, 3};
	int manhattanma_distance3[5] = {11, 6, 50, 4, 3};
	int manhattanma_distance4[5] = {8, 5, 3, 50, 4};

	manhattanma_distance[0] = manhattanma_distance0;
	manhattanma_distance[1] = manhattanma_distance1;
	manhattanma_distance[2] = manhattanma_distance2;
	manhattanma_distance[3] = manhattanma_distance3;
	manhattanma_distance[4] = manhattanma_distance4;
  
	A_Star a_star(std::pair<int, int>(4, 0), 
				  std::pair<int, int>(0, 4), 
				  maze_cost_array, manhattanma_distance);
	a_star.search();
	a_star.back_track();
  };
  
  return 0;
};