最急降下法によるパラメータの推定

最小2乗法は知ってるので、ぶっとばしてパーセプトロン。

線形の和に対して、閾値関数*1やロジステック関数をかましてやったものが、ここで言うyのことなのか。で、なんか説明があって、次へ行く。

ADALINE

このモデルでの学習は、教師の答えとネッ トワークの出力との平均2乗誤差を最小とするような結合重み(a_0,\cdots,a_M)を最急降下法によって求めるものです。

よし、日本語でおk。irisのデータを使うらしいよ。

と思ったら

最小2乗法は解析的にパラメータ計算するのしか知らないを思い出したので、「最急降下法によるパラメータの推定」のところから。

データの読み込み。

students <- textConnection("
生徒番号	ボール投げ	握力	身長	体重
1	22	28	146	34
2	36	46	169	57
3	24	39	160	48
4	22	25	156	38
5	27	34	161	47
6	29	29	168	50
7	26	38	154	54
8	23	23	153	40
9	31	42	160	62
10	24	27	152	39
11	23	35	155	46
12	27	39	154	54
13	31	38	157	57
14	25	32	162	53
15	23	25	142	32
")
close.connection(textConnection(students))
students <- read.table(students,header=T)

で、適当に計算させる。a0に関してはupdateさせなかった。てか、入れると変になる(ほとんど更新されないとか)。updateさせると変になる理由がいまいち分かってない。あとで考える。

y <- function(y,x1,x2,x3,a1,a2,a3){
  return(mean(y) + a1*(x1-mean(x1)) + a2*(x2-mean(x2)) + a3*(x3-mean(x3)))
}

#この辺はevalとかで動的に作ったほうがいいけど、まあ今はいいや
deda1 <- function(y,x1,x2,x3,a1,a2,a3){
  -2/15 * sum((y-y(y,x1,x2,x3,a1,a2,a3)) * x1) 
}

deda2 <- function(y,x1,x2,x3,a1,a2,a3){
  -2/15 * sum((y-y(y,x1,x2,x3,a1,a2,a3)) * x2) 
}

deda3 <- function(y,x1,x2,x3,a1,a2,a3){
  -2/15 * sum((y-y(y,x1,x2,x3,a1,a2,a3)) * x3) 
}

students[c(-1,-2)]<- students[c(-1,-2)]/100
students

#てきとーな初期値とパラメータ
a1 <- 1
a2 <- 1
a3 <- 1
alpha <- 10

for(i in seq(1000)){
  a1 <- a1 - alpha * deda1(students$ボール投げ,students$握力,students$身長,students$体重,a1,a2,a3)
  a2 <- a2 - alpha * deda2(students$ボール投げ,students$握力,students$身長,students$体重,a1,a2,a3)
  a3 <- a3 - alpha * deda3(students$ボール投げ,students$握力,students$身長,students$体重,a1,a2,a3)
}

てか、あれか。行列使って計算しろよ、って話ですね。

それぞれのパラメータ。こんな感じになる。a0だけはa1、a2、a3の関係式から求めてやった。

> (a0 <- mean(students$ボール投げ) - a1*mean(students$握力) - a2*mean(students$身長) - a3*mean(students$体重))
[1] -13.21730
> a1
[1] 20.13769
> a2
[1] 17.10246
> a3
[1] 12.49428

解析的に求めたパラメータの値。さっきのがまあまあよさげなことであることが分かる。

> lm(ボール投げ~握力+身長+体重,data=students)

Call:
lm(formula = ボール投げ ~ 握力 + 身長 + 体重, data = students)

Coefficients:
(Intercept)         握力         身長         体重  
     -13.22        20.14        17.10        12.49  

*1:indicator functionの0じゃなくて-1バージョンみたいなやつ