[Matlab] 重刻類神經網路

人工神經網路 (Artificial Neural Network, ANN) 透過集合眾多數值之神經鍵與神經元(節點)建構出虛擬之類神經網路,為仿生物智慧原型技術之一,藉由神經鍵儲存抽象認知之信息與激發傳遞消息,可進行資料擬合、族群分類、認知辨識、趨勢預測甚至自動控制等領域當中,具有人造之學習、歸納與判斷能力。

類神經網路目前已有眾多種類與架構(詳見wiki),除此之外,神經鍵之如何調整與修正將決定類神經網路之學習效率 。大型程式語言 Matlab 已內建多種類神經網路分析技術,除支援多核心與GPU運算之外,亦針對各種不同反饋神經鍵值之技術進行研究。局部最佳化方法 Levenberg-Marquardt 因具備較快速與穩定之收斂性,已成為 Matlab 訓練網路之預設方式。

為能在更多平台上進行類神經網路之應用並期望後續與 Matlab 分離,先採 Matlab 以簡易、不使用太多內建函數之方式撰寫倒傳遞類神經網路,並採 Levenberg-Marquardt 進行神經鍵值之修正,原始碼與測試資料如下:

[下載]

本類神經網路系統共有5個m檔,其內容為:

1.建構類神經網路
[net]=ANN_Net(nlist,epochs,trys,evals)

輸入參數
nlist:神經元數量。為一維向量,給予各層神經元個數。若欲建立3層架構(輸入層、1隱藏層、輸出層),而輸入層有3種資料,輸出層有2種資料,隱藏層給10個神經元,則 nlist = [3 10 2]。

epochs:總訓練次數。一次即使用全部資料針對全部神經鍵調整一次。

trys:嘗試次數。每次訓練中,依照梯度搜尋更佳之神經鍵值時,允許最大失敗次數。

evals:總核心分析次數。搜尋過程中會執行類神經網路計算目前誤差量,允許最大執行次數。

輸出參數
net:儲存所建構出之類神經網路資訊。



2.由給定資料訓練網路
[net,f]=ANN_Train(net,data)

輸入參數
net:儲存所建構出之類神經網路資訊。

data:欲訓練之資料。可採文字檔以直排方式儲存,若目前資料為3種輸入2種輸出,則文字檔中給予 3+2=5直排資料,再透過 Matlab 內建 load 函數載入進來即可。載入後矩陣 data 維度需為 (n,5),其中 n 為資料點數。

輸出參數
net:儲存所建構出之類神經網路資訊,將添加訓練後資訊如神經鍵值等。

f:訓練成果。若訓練資料提供了2種輸出,則訓練後亦回傳2種成果。



3.由已訓練之網路推估回想
[f]=ANN_Sim(net,data)

輸入參數
net:儲存所建構出之類神經網路資訊。
data:欲回想 (Recall) 之資料。注意若訓練資料為3種輸入(3直排),則回想資料亦同樣需提供3種(也是3直排),但長度可不一樣。

輸出參數
f:回想成果。若訓練資料提供了2種輸出(2直排),故回想只能提供2種輸出(2直排)。



4.局部最佳化方法Levenberg-Marquardt進行神經鍵值修正
[x,val]=ANN_LM(Fk,x0,epochs,trys,evals)

輸入參數
Fk:待分析之函數。

x0:初始值向量。此處表示神經鍵值。

epochs:總訓練次數。

trys:嘗試次數。每次訓練中,依照梯度搜尋更佳之神經鍵值時,允許最大失敗次數。

evals:總核心分析次數。搜尋過程中會執行類神經網路計算目前誤差量,允許最大執行次數。

輸出參數
x:較佳值向量。此處表示修正後之神經鍵值。

val:總誤差量。此處表示執行 Fk 後之向量取長度值,故 Fk 回傳值需為類神經網路之輸出結果與真值之誤差向量。



5.正規化資料還原
[ff]=ANN_Rescale(net,f)

輸入參數
net:儲存所建構出之類神經網路資訊。

f:正規化後之輸出值。

輸出參數
ff:還原後之輸出值。透過原先正規化所紀錄之各資料之最大最小值,即可進行還原工作。



測試範例
共有3個,以下只列出 test3,其餘測試範例詳見原始檔案。

%總共4層,1輸入層、2隱藏層、1輸出
%神經元數量向量為 [輸入2種資料 第1隱藏層10顆神經元 第2隱藏層10顆神經元 輸出1種資料]
%輸入2種資料表示各點 x,y 座標資料,輸出1種資料表示各點對應之 z 座標 
[net]=ANN_Net([2 10 10 1],1000,1000,16000);

%載入欲訓練之資料
datatr=load('ex3_train.txt');

%進行類神經網路訓練
[net,ftr]=ANN_Train(net,datatr);

%載入欲回想之資料
datarecall=load('ex3_recall.txt');

%進行類神經網路回想
[frecall]=ANN_Sim(net,datarecall);


訓練用資料之三維展示圖:


訓練後成果之三維展示圖:


回想後成果之三維展示圖:


留言