前向きアルゴリズム、Vitebiアルゴリズム

Viterbi書くの何回目だろ。。。週末にはバウムウェルチ(Baum-Welch)のアルゴリズムこと前向き後ろ向きアルゴリズムを書きたいところ。
www.yasuhisay.info

# -*- coding: utf-8 -*-
# 確率的言語モデル(東京大学出版)第4章(隠れマルコフモデル)

require 'pp'

# 品詞iから品詞jへの遷移確率
# 番号と品詞の対応付け: 0(名詞)、1(冠詞)、2(動詞)、3(形容詞)、4(前置詞)
a = [[0.3, 0.0, 0.4, 0.1, 0.2],
     [0.7, 0.0, 0.0, 0.3, 0.0],
     [0.3, 0.2, 0.1, 0.2, 0.2],
     [0.5, 0.1, 0.0, 0.4, 0.0],
     [0.6, 0.3, 0.0, 0.1, 0.0]]

# 品詞jから単語iを生成する確率
# 番号と単語の対応付け: 0(Time)、1(flies)、2(like)、3(an)、4(arrow)

b = [[0.6, 0.1, 0.0, 0.0, 0.3],
     [0.0, 0.0, 0.0, 1.0, 0.0],
     [0.1, 0.2, 0.7, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0, 0.0]]

o = [0, 1, 2, 3, 4] # Time flies like an arrow.

# それぞれの品詞が文頭にくる確率(品詞の初期確率)
pi = [0.6, 0.4, 0.0, 0.0, 0.0] 

# 前向きアルゴリズム
# 文が生起した確率を計算(品詞は考えない)
# 様々なパスを足し合わせるが、途中までの累積確率をうまく使う
def forward(a, b, pi, o)
  large_t = o.size - 1
  large_n = pi.size

  puts "T: #{large_t}, N: #{large_n}"

  init = []
  (0..large_n-1).each{|i|
    init.push pi[i] * b[i][0]
  }

  alpha = [init]
  large_t.times do
    alpha.push []
  end

  (0..large_t - 1).each{|t|
    puts "----- t = #{t+2} -----"
    (0..large_n - 1).each{|j|
      tmp = []
      sum = 0 # sum_{i=1}^N alpha_t(i) a_ij
      (0..large_n - 1).each{|i|
        tmp.push "(#{alpha[t][i]} * #{a[i][j]})"
        sum += alpha[t][i] * a[i][j]
      }
      puts "(#{tmp.join(" + ")})"
      puts "\t* #{b[j][t+1]} = #{sum * b[j][t+1]}"
      alpha[t+1][j] = sum * b[j][t+1]
    }
  }

  p = 0
  (0..large_n-1).each{|i|
    p += alpha[large_n-1][i]
  }

  return p
end

# Vitebi Algorithm
# 前向きアルゴリズムではoが生起した確率を求めるが、その生起したパスの中でどのパスが最も確率の高いパスかを計算する
def viterbi(a, b, pi, o)
  large_t = o.size - 1
  large_n = pi.size

  init = []
  
  # psi = [Hash.new, Hash.new, Hash.new, Hash.new, Hash.new]
  psi = [Hash.new]

  (0..large_n-1).each{|i|
    init.push pi[i] * b[i][0]
    psi.push Hash.new
    psi[0][i] = 0
  }

  delta = [init]
  large_t.times do
    delta.push []
  end

  (0..large_t - 1).each{|t|
    (0..large_n - 1).each{|j|
      max = 0 
      (0..large_n - 1).each{|i|
        if delta[t][i] * a[i][j] > max
          max = delta[t][i] * a[i][j] 
          psi[t+1][j] = i
        end
      }
      delta[t+1][j] = max * b[j][t+1]
    }
  }

  p_hat = 0
  q_large_t = 0

  (0..large_n-1).each{|i|
    if delta[large_n-1][i] > p_hat
      p_hat = delta[large_n-1][i] 
      q_large_t  = i
    end
  }

  q_hat = Hash.new
  q_hat[large_t] = q_large_t

  # backtrack
  prev = 0
  (large_t-1).downto(0){|t|
    q_hat[t] = psi[t+1][q_hat[t+1]]

  }
  pp q_hat

  return p_hat
end

puts "=" * 50
puts forward(a, b, pi, o)
puts "=" * 50
puts viterbi(a, b, pi, o)
puts "=" * 50

日本語入力を支える技術 ?変わり続けるコンピュータと言葉の世界 (WEB+DB PRESS plus)

日本語入力を支える技術 ?変わり続けるコンピュータと言葉の世界 (WEB+DB PRESS plus)