# 6.6 通過時間反向傳播
在前面兩節中,如果不裁剪梯度,模型將無法正常訓練。為了深刻理解這一現象,本節將介紹循環神經網絡中梯度的計算和存儲方法,即通過時間反向傳播(back-propagation through time)。
我們在3.14節(正向傳播、反向傳播和計算圖)中介紹了神經網絡中梯度計算與存儲的一般思路,并強調正向傳播和反向傳播相互依賴。正向傳播在循環神經網絡中比較直觀,而通過時間反向傳播其實是反向傳播在循環神經網絡中的具體應用。我們需要將循環神經網絡按時間步展開,從而得到模型變量和參數之間的依賴關系,并依據鏈式法則應用反向傳播計算并存儲梯度。
## 6.6.1 定義模型
簡單起見,我們考慮一個無偏差項的循環神經網絡,且激活函數為恒等映射(`$ \phi(x)=x $`)。設時間步 `$ t $` 的輸入為單樣本 `$ \boldsymbol{x}_t \in \mathbb{R}^d $`,標簽為 `$ y_t $`,那么隱藏狀態 `$ \boldsymbol{h}_t \in \mathbb{R}^h $`的計算表達式為
```[tex]
\boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1},
```
其中`$ \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} $`和`$ \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} $`是隱藏層權重參數。設輸出層權重參數`$ \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} $`,時間步`$ t $`的輸出層變量`$ \boldsymbol{o}_t \in \mathbb{R}^q $`計算為
```[tex]
\boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}.
```
設時間步`$ t $`的損失為`$ \ell(\boldsymbol{o}_t, y_t) $`。時間步數為$T$的損失函數`$ L $`定義為
```[tex]
L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t).
```
我們將`$ L $`稱為有關給定時間步的數據樣本的目標函數,并在本節后續討論中簡稱為目標函數。
## 6.6.2 模型計算圖
為了可視化循環神經網絡中模型變量和參數在計算中的依賴關系,我們可以繪制模型計算圖,如圖6.3所示。例如,時間步3的隱藏狀態`$ \boldsymbol{h}_3 $`的計算依賴模型參數`$ \boldsymbol{W}_{hx} $`、`$ \boldsymbol{W}_{hh} $`、上一時間步隱藏狀態`$ \boldsymbol{h}_2 $`以及當前時間步輸入`$ \boldsymbol{x}_3 $`。
:-: 
<div align=center>圖6.3 時間步數為3的循環神經網絡模型計算中的依賴關系。方框代表變量(無陰影)或參數(有陰影),圓圈代表運算符</div>
## 6.6.3 方法
剛剛提到,圖6.3中的模型的參數是 `$ \boldsymbol{W}_{hx} $`, `$ \boldsymbol{W}_{hh} $` 和 `$ \boldsymbol{W}_{qh} $`。與3.14節(正向傳播、反向傳播和計算圖)中的類似,訓練模型通常需要模型參數的梯度`$ \partial L/\partial \boldsymbol{W}_{hx} $`、`$ \partial L/\partial \boldsymbol{W}_{hh} $`和`$ \partial L/\partial \boldsymbol{W}_{qh} $`。
根據圖6.3中的依賴關系,我們可以按照其中箭頭所指的反方向依次計算并存儲梯度。為了表述方便,我們依然采用3.14節中表達鏈式法則的運算符prod。
首先,目標函數有關各時間步輸出層變量的梯度`$ \partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q $`很容易計算:
```[tex]
\frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}.
```
下面,我們可以計算目標函數有關模型參數`$ \boldsymbol{W}_{qh} $`的梯度`$ \partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} $`。根據圖6.3,`$ L $`通過`$ \boldsymbol{o}_1, \ldots, \boldsymbol{o}_T $`依賴`$ \boldsymbol{W}_{qh} $`。依據鏈式法則,
```[tex]
\frac{\partial L}{\partial \boldsymbol{W}_{qh}}
= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right)
= \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top.
```
其次,我們注意到隱藏狀態之間也存在依賴關系。
在圖6.3中,`$ L $`只通過`$ \boldsymbol{o}_T $`依賴最終時間步$T$的隱藏狀態`$ \boldsymbol{h}_T $`。因此,我們先計算目標函數有關最終時間步隱藏狀態的梯度`$ \partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h $`。依據鏈式法則,我們得到
```[tex]
\frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}.
```
接下來對于時間步`$ t < T $`, 在圖6.3中,`$ L $`通過`$ \boldsymbol{h}_{t+1} $` 和`$ \boldsymbol{o}_t $`依賴`$ \boldsymbol{h}_t $`。依據鏈式法則,
目標函數有關時間步`$ t < T $`的隱藏狀態的梯度`$ \partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h $`需要按照時間步從大到小依次計算:
```[tex]
\frac{\partial L}{\partial \boldsymbol{h}_t}
= \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t}
```
將上面的遞歸公式展開,對任意時間步`$ 1 \leq t \leq T $` ,我們可以得到目標函數有關隱藏狀態梯度的通項公式
```[tex]
\frac{\partial L}{\partial \boldsymbol{h}_t}
= \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}.
```
由上式中的指數項可見,當時間步數 `$ T $` 較大或者時間步 `$ t $` 較小時,目標函數有關隱藏狀態的梯度較容易出現衰減和爆炸。這也會影響其他包含`$ \partial L / \partial \boldsymbol{h}_t $`項的梯度,例如隱藏層中模型參數的梯度`$ \partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} $`和`$ \partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} $`。
在圖6.3中,`$ L $`通過`$ \boldsymbol{h}_1, \ldots, \boldsymbol{h}_T $`依賴這些模型參數。
依據鏈式法則,我們有
```[tex]
\begin{aligned}
\frac{\partial L}{\partial \boldsymbol{W}_{hx}}
&= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right)
= \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\
\frac{\partial L}{\partial \boldsymbol{W}_{hh}}
&= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right)
= \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top.
\end{aligned}
```
我們已在3.14節里解釋過,每次迭代中,我們在依次計算完以上各個梯度后,會將它們存儲起來,從而避免重復計算。例如,由于隱藏狀態梯度`$ \partial L/\partial \boldsymbol{h}_t $`被計算和存儲,之后的模型參數梯度`$ \partial L/\partial \boldsymbol{W}_{hx} $`和`$ \partial L/\partial \boldsymbol{W}_{hh} $`的計算可以直接讀取`$ \partial L/\partial \boldsymbol{h}_t $`的值,而無須重復計算它們。此外,反向傳播中的梯度計算可能會依賴變量的當前值。它們正是通過正向傳播計算出來的。
舉例來說,參數梯度`$ \partial L/\partial \boldsymbol{W}_{hh} $`的計算需要依賴隱藏狀態在時間步`$ t = 0, \ldots, T-1 $`的當前值`$ \boldsymbol{h}_t $`(`$ \boldsymbol{h}_0 $`是初始化得到的)。這些值是通過從輸入層到輸出層的正向傳播計算并存儲得到的。
## 小結
* 通過時間反向傳播是反向傳播在循環神經網絡中的具體應用。
* 當總的時間步數較大或者當前時間步較小時,循環神經網絡的梯度較容易出現衰減或爆炸。
------------
> 注:本節與原書基本相同,[原書傳送門](https://zh.d2l.ai/chapter_recurrent-neural-networks/bptt.html)
- Home
- Introduce
- 1.深度學習簡介
- 深度學習簡介
- 2.預備知識
- 2.1環境配置
- 2.2數據操作
- 2.3自動求梯度
- 3.深度學習基礎
- 3.1 線性回歸
- 3.2 線性回歸的從零開始實現
- 3.3 線性回歸的簡潔實現
- 3.4 softmax回歸
- 3.5 圖像分類數據集(Fashion-MINST)
- 3.6 softmax回歸的從零開始實現
- 3.7 softmax回歸的簡潔實現
- 3.8 多層感知機
- 3.9 多層感知機的從零開始實現
- 3.10 多層感知機的簡潔實現
- 3.11 模型選擇、反向傳播和計算圖
- 3.12 權重衰減
- 3.13 丟棄法
- 3.14 正向傳播、反向傳播和計算圖
- 3.15 數值穩定性和模型初始化
- 3.16 實戰kaggle比賽:房價預測
- 4 深度學習計算
- 4.1 模型構造
- 4.2 模型參數的訪問、初始化和共享
- 4.3 模型參數的延后初始化
- 4.4 自定義層
- 4.5 讀取和存儲
- 4.6 GPU計算
- 5 卷積神經網絡
- 5.1 二維卷積層
- 5.2 填充和步幅
- 5.3 多輸入通道和多輸出通道
- 5.4 池化層
- 5.5 卷積神經網絡(LeNet)
- 5.6 深度卷積神經網絡(AlexNet)
- 5.7 使用重復元素的網絡(VGG)
- 5.8 網絡中的網絡(NiN)
- 5.9 含并行連結的網絡(GoogLeNet)
- 5.10 批量歸一化
- 5.11 殘差網絡(ResNet)
- 5.12 稠密連接網絡(DenseNet)
- 6 循環神經網絡
- 6.1 語言模型
- 6.2 循環神經網絡
- 6.3 語言模型數據集(周杰倫專輯歌詞)
- 6.4 循環神經網絡的從零開始實現
- 6.5 循環神經網絡的簡單實現
- 6.6 通過時間反向傳播
- 6.7 門控循環單元(GRU)
- 6.8 長短期記憶(LSTM)
- 6.9 深度循環神經網絡
- 6.10 雙向循環神經網絡
- 7 優化算法
- 7.1 優化與深度學習
- 7.2 梯度下降和隨機梯度下降
- 7.3 小批量隨機梯度下降
- 7.4 動量法
- 7.5 AdaGrad算法
- 7.6 RMSProp算法
- 7.7 AdaDelta
- 7.8 Adam算法
- 8 計算性能
- 8.1 命令式和符號式混合編程
- 8.2 異步計算
- 8.3 自動并行計算
- 8.4 多GPU計算
- 9 計算機視覺
- 9.1 圖像增廣
- 9.2 微調
- 9.3 目標檢測和邊界框
- 9.4 錨框
- 10 自然語言處理
- 10.1 詞嵌入(word2vec)
- 10.2 近似訓練
- 10.3 word2vec實現
- 10.4 子詞嵌入(fastText)
- 10.5 全局向量的詞嵌入(Glove)
- 10.6 求近義詞和類比詞
- 10.7 文本情感分類:使用循環神經網絡
- 10.8 文本情感分類:使用卷積網絡
- 10.9 編碼器--解碼器(seq2seq)
- 10.10 束搜索
- 10.11 注意力機制
- 10.12 機器翻譯