# [機器學習公開課筆記(5):神經網絡(Neural Network)——學習](https://www.cnblogs.com/python27/p/MachineLearningWeek05.html)
這一章可能是Andrew Ng講得最不清楚的一章,為什么這么說呢?這一章主要講后向傳播(Backpropagration, BP)算法,Ng花了一大半的時間在講如何計算誤差項δ
,如何計算Δ
的矩陣,以及如何用Matlab去實現后向傳播,然而最關鍵的問題——為什么要這么計算?前面計算的這些量到底代表著什么,Ng基本沒有講解,也沒有給出數學的推導的例子。所以這次內容我不打算照著公開課的內容去寫,在查閱了許多資料后,我想先從一個簡單的神經網絡的梯度推導入手,理解后向傳播算法的基本工作原理以及每個符號代表的實際意義,然后再按照課程的給出BP計算的具體步驟,這樣更有助于理解。
## 簡單神經網絡的后向傳播(Backpropagration, BP)算法
### 1\. 回顧之前的前向傳播(ForwardPropagration, FP)算法
FP算法還是很簡單的,說白了就是根據前一層神經元的值,先加權然后取sigmoid函數得到后一層神經元的值,寫成數學的形式就是:
a(1)\=X
z(2)\=Θ(1)a(1)
a(2)\=g(z(2))
z(3)\=Θ(2)a(2)
a(3)\=g(z(3))
z(4)\=Θ(3)a(3)
a(4)\=g(z(4))
### 2\. 回顧神經網絡的代價函數(不含regularization項)
J(Θ)\=?1m\[∑i\=1m∑k\=1Ky(i)klog(hθ(x(i)))k+(1?y(i)k)log(1?(hθ(x(i)))k)\]
### 3\. 一個簡單神經網絡的BP推導過程
BP算法解決了什么問題?我們已經有了代價函數J(Θ)
,接下來我們需要利用梯度下降算法(或者其他高級優化算法)對J(Θ)進行優化從而得到訓練參數Θ,然而關鍵問題是,優化算法需要傳遞兩個重要的參數,一個代價函數J(Θ),另一個是代價函數的梯度?J(Θ)?Θ
,**BP算法其實就是解決如何計算梯度的問題**。
下面我們從一個簡單的例子入手考慮如何從數學上計算代價函數的梯度,考慮如下簡單的神經網絡(為方便起見,途中已經給出了前向傳播(FP)的計算過程),該神經網絡有三層神經元,對應的有兩個權重矩陣Θ(1)
和Θ(2),為計算梯度我們只需要計算兩個偏導數即可:?J(Θ)?Θ(1)和?J(Θ)?Θ(2)
。

首先我們先計算第2個權重矩陣的偏導數,即??Θ(2)J(Θ)
。首先我們需要在J(Θ)和Θ(2)之間建立聯系,很容易可以看到J(Θ)的值取決于hθ(x),而hθ(x)\=a(3), a3又是由z(3)取sigmoid得到,最后z(3)\=Θ(2)×a(2)
,所以他們之間的聯系可以如下表示:

按照求導的鏈式法則,我們可以先求J(Θ)
對z(3)的導數,然后乘以z(3)對Θ(2)
的導數,即
??Θ(2)J(Θ)\=??z(3)J(Θ)×?z(3)?Θ(2)
由z(3)\=Θ(2)a(2)
不難計算?z(3)?Θ(2)\=(a(2))T,令??z(3)J(Θ)\=δ(3)
,上式可以重寫為
??Θ(2)J(Θ)\=δ(3)(a(2))T
接下來僅需要計算δ(3)
即可,由上一章的內容我們已經知道g′(z)\=g(z)(1?g(z)), hθ(x)\=a(3)\=g(z(3)),忽略前面的1/m∑i\=1m
(這里我們只對一個example推導,最后累加即可)
δ(3)\=?J(Θ)z(3)\=(?y)1g(z(3))g′(z(3))?(1?y)11?g(z(3))\[1?g(z(3))\]′\=?y(1?g(z(3)))+(1?y)g(z(3))\=?y+g(z(3))\=?y+a(3)
至此我們已經得到J(Θ)
對Θ(2)
的偏導數,即
?J(Θ)?Θ(2)\=(a(2))Tδ(3)
δ(3)\=a(3)?y
接下來我們需要求J(Θ)
對Θ(1)的偏導數,J(Θ)對Θ(1)
的依賴關系如下:

根據鏈式求導法則有
?J(Θ)?Θ(1)\=?J(Θ)?z(3)?z(3)?a(2)?a(2)?Θ(1)
我們分別計算等式右邊的三項可得:
?J(Θ)?z(3)\=δ(3)
?z(3)?a(2)\=(Θ(2))T
?a(2)?Θ(1)\=?a(2)?z(2)?z(2)?Θ(1)\=g′(z(2))a(1)
帶入后得
?J(Θ)?Θ(1)\=(a(1))Tδ(3)(Θ(2))Tg′(z(2))
令δ(2)\=δ(3)(Θ(2))Tg′(z(2))
, 上式可以重寫為
?J(Θ)?Θ(1)\=(a(1))Tδ(2)
δ(2)\=δ(3)(Θ(2))Tg′(z(2))
把上面的結果放在一起,我們得到J(Θ)
對兩個權重矩陣的偏導數為:
δ(3)\=a(3)?y
?J(Θ)?Θ(2)\=(a(2))Tδ(3)
δ(2)\=δ(3)(Θ(2))Tg′(z(2))
?J(Θ)?Θ(1)\=(a(1))Tδ(2)
觀察上面的四個等式,我們發現
* 偏導數可以由當前層神經元向量a(l)
與下一層的誤差向量δ(l+1)* 相乘得到
* 當前層的誤差向量δ(l)
可以由下一層的誤差向量δ(l+1)與權重矩陣Δl
* 的乘積得到
所以可以從后往前逐層計算誤差向量(這就是***后向傳播***的來源),然后通過簡單的乘法運算得到代價函數對每一層權重矩陣的偏導數。到這里算是終于明白為什么要計算誤差向量,以及為什么誤差向量之間有遞歸關系了。盡管這里的神經網絡十分簡單,推導過程也不是十分嚴謹,但是通過這個簡單的例子,基本能夠理解后向傳播算法的工作原理了。
## 嚴謹的后向傳播算法(計算梯度)
假設我們有m
個訓練example,L
層神經網絡,并且此處考慮正則項,即
J(Θ)\=?1m\[∑i\=1m∑k\=1Ky(i)klog(hθ(x(i)))k+(1?y(i)k)log(1?(hθ(x(i)))k)\]+λ2m∑l\=1L?1∑i\=1sl∑j\=1sl+1(Θ(l)ji)2
初始化:設置Δ(l)ij\=0
(理解為對第l
層的權重矩陣的偏導累加值)
For i = 1 : m
* 設置 a(1)\=X
* 通過前向傳播算法(FP)計算對各層的預測值a(l)
,其中l\=2,3,4,…,L
* 計算最后一層的誤差向量 δ(L)\=a(L)?y
,利用后向傳播算法(BP)從后至前逐層計算誤差向量 δ(L?1),δ(L?1),…,δ(2), 計算公式為δ(l)\=(Θ(l))Tδ(l+1).?g′(z(l))
* 更新Δ(l)\=Δ(l)+δ(l+1)(a(l))T
end // for
計算梯度:
D(l)ij\=1mΔ(l)ij,j\=0
D(l)ij\=1mΔ(l)ij+λmΘ(l)ij,j≠0
?J(Θ)?Θ(l)\=D(l)
## BP實際運用中的技巧
### 1\. 將參數展開成向量
對于四層三個權重矩陣參數Θ(1),Θ(2),Θ(3)
將其展開成一個參數向量,Matlab code如下:?
1thetaVec = [Theta1(:); Theta2(:); Theta3(:)];
### 2\. 梯度檢查
為了保證梯度計算的正確性,可以用數值解進行檢查,根據導數的定義
dJ(θ)dθ≈J(θ+?)?J(θ??)2?
Matlab Code 如下
1234567for i = 1 : n thetaPlus = theta; thetaPlus(i) = thetaPlus(i) + EPS; thetaMinus = theta; thetaMinus(i) = thetaMinus(i) - EPS; gradApprox(i) = (J(thetaPlus) - J(thetaMinus)) / (2 * EPS);end
最后檢查 gradApprox 是否約等于之前計算的梯度值即可。需要注意的是:因為近似的梯度計算代價很大,在梯度檢查后記得**關閉梯度檢查**的代碼。
### 3\. 隨機初始化
初始權重矩陣的初始化應該打破對稱性 (symmetry breaking),避免使用全零矩陣進行初始化。可以采用隨機數進行初始化,即 Θ(l)ij∈\[??,+?\]
## 如何訓練一個神經網絡
1. 隨機初始化權重矩陣
2. 利用前向傳播算法(FP)計算模型預測值hθ(x)
* 計算代價函數J(Θ)
* 利用后向傳播算法(BP)計算代價函數的梯度 ?J(Θ)?Θ(l)
* 利用數值算法進行梯度檢查(gradient checking),**確保正確后關閉梯度檢查**
* 利用梯度下降(或者其他優化算法)求得最優參數Θ
## 附:一個簡短的后向傳播
- BP神經網絡到c++實現等--機器學習“掐死教程”
- 訓練bp(神經)網絡學會“乘法”--用”蚊子“訓練高射炮
- Ann計算異或&前饋神經網絡20200302
- 神經網絡ANN的表示20200312
- 簡單神經網絡的后向傳播(Backpropagration, BP)算法
- 牛頓迭代法求局部最優(解)20200310
- ubuntu安裝numpy和pip3等
- 從零實現一個神經網絡-numpy篇01
- _美國普林斯頓大學VictorZhou神經網絡神文的改進和翻譯20200311
- c語言-普林斯頓victorZhou神經網絡實現210301
- bp網絡實現xor異或的C語言實現202102
- bp網絡實現xor異或-自動錄入輸入(寫死20210202
- Mnist在python3.6上跑tensorFlow2.0一步一坑20210210
- numpy手寫數字識別-直接用bp網絡識別210201