(1) 最急降下法
予備知識として、最適化アルゴリズムの最急降下法を知る必要があります。
最急降下法は以下のようになります。
(2) Perceptron
パーセプトロンは教師データを用いて学習させた後、2つのデータに線形分離します。
[アルゴリズム]
for 全教師データ
if( wxの符号と教師フラグの符号){
一致 -> continue (何もしない)
一致しない
w<- br="" tx="" w=""> index <- 1="" br="">end
Javaではリンクのようになります。 Rでは下記になります。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x <- matrix(c(1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1), nrow=4, ncol=3, byrow = TRUE) #OR(1,0,0)最初はw0の係数なので1 | |
t <- c(-1, 1, 1, 1) #教師ラベル | |
w <- c(1, 1, 1) #weight(w0, w1, w2) | |
eta <- 0.2 #学習係数 | |
#weightの更新メソッド | |
update <- function(x, t, w) { | |
if (sign(x %*% w) == sign(t)) { | |
return(w) | |
} | |
return(w + eta * x * t) | |
} | |
#全データでweightが更新されなくなるまで繰り返す | |
index <- 1 | |
while (index <= nrow(x)) { | |
tmp <- update(x[index,], t[index], w) | |
if (all(tmp == w)) { # w更新があったか | |
index <- index + 1 # w更新なし | |
} else { # w更新あり | |
w <- tmp | |
index <- 1 # ループの最初に戻す | |
} | |
print(w) | |
} | |
getPredict <- function(x,w){ | |
return (sign(x %*% w)) | |
} | |
plot(x[,2],x[,3], pch = t + 1) | |
abline( c(-w[1] / w[3], -w[2] / w[3])) # y = b + ax abline(b,a) | |
cat("y = ", -w[2] / w[3], "x +", -w[1] / w[3],"\n") | |
cat(c(0,0),"のとき",getPredict(c(1,0,0),w),"\n") | |
cat(c(1,0),"のとき",getPredict(c(1,1,0),w),"\n") | |
cat(c(0,1),"のとき",getPredict(c(1,0,1),w),"\n") | |
cat(c(1,1),"のとき",getPredict(c(1,1,1),w),"\n") |
次に50個の座標を線形分離してみます。今回はy = -x +1 の上下で分離します。
y > -x + 1の時 flag = 1
y < -x + 1の時 flag = -1
利用したファイル
このファイルをRStudioでアップロードしておきます。
[Perceptron.2R]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data <- read.csv("random.csv", fileEncoding = "UTF-8",header = F) | |
x <- data.matrix(data[,1:3]) | |
t <- data[,4] | |
w <- c(1, 1, 1) #weight(w0, w1, w2) | |
eta <- 0.2 #学習係数 | |
#weightの更新メソッド | |
update <- function(x, t, w) { | |
if (sign(x %*% w) == sign(t)) { | |
return(w) | |
} | |
return(w + eta * x * t) | |
} | |
#全データでweightが更新されなくなるまで繰り返す | |
index <- 1 | |
while (index <= nrow(x)) { | |
tmp <- update(x[index,], t[index], w) | |
if (all(tmp == w)) { # w更新があったか | |
index <- index + 1 # w更新なし | |
} else { # w更新あり | |
w <- tmp | |
index <- 1 # ループの最初に戻す | |
} | |
print(w) | |
} | |
getPredict <- function(x,w){ | |
return (sign(x %*% w)) | |
} | |
plot(x[,2],x[,3], pch = t + 1) | |
abline( c(-w[1] / w[3], -w[2] / w[3])) # y = b + ax abline(b,a) | |
cat("y = ", -w[2] / w[3], "x +", -w[1] / w[3],"\n") |
参考
http://smrmkt.hatenablog.jp/entry/2013/11/06/223221
0 件のコメント:
コメントを投稿