# 6.7 門控循環單元(GRU)
上一節介紹了循環神經網絡中的梯度計算方法。我們發現,當時間步數較大或者時間步較小時,循環神經網絡的梯度較容易出現衰減或爆炸。雖然裁剪梯度可以應對梯度爆炸,但無法解決梯度衰減的問題。通常由于這個原因,循環神經網絡在實際中較難捕捉時間序列中時間步距離較大的依賴關系。
門控循環神經網絡(gated recurrent neural network)的提出,正是為了更好地捕捉時間序列中時間步距離較大的依賴關系。它通過可以學習的門來控制信息的流動。其中,門控循環單元(gated recurrent unit,GRU)是一種常用的門控循環神經網絡 [1, 2]。另一種常用的門控循環神經網絡則將在下一節中介紹。
## 6.7.1 門控循環單元
下面將介紹門控循環單元的設計。它引入了重置門(reset gate)和更新門(update gate)的概念,從而修改了循環神經網絡中隱藏狀態的計算方式。
### 6.7.1.1 重置門和更新門
如圖6.4所示,門控循環單元中的重置門和更新門的輸入均為當前時間步輸入`$ \boldsymbol{X}_t $`與上一時間步隱藏狀態`$ \boldsymbol{H}_{t-1} $`,輸出由激活函數為sigmoid函數的全連接層計算得到。
:-: 
<div align=center>圖6.4 門控循環單元中重置門和更新門的計算</div>
具體來說,假設隱藏單元個數為`$ h $`,給定時間步`$ t $`的小批量輸入`$ \boldsymbol{X}_t \in \mathbb{R}^{n \times d} $`(樣本數為`$ n $`,輸入個數為`$ d $`)和上一時間步隱藏狀態`$ \boldsymbol{H}_{t-1} \in \mathbb{R}^{n \times h} $`。重置門`$ \boldsymbol{R}_t \in \mathbb{R}^{n \times h} $`和更新門`$ \boldsymbol{Z}_t \in \mathbb{R}^{n \times h} $`的計算如下:
```[tex]
\begin{aligned}
\boldsymbol{R}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xr} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hr} + \boldsymbol{b}_r),\\
\boldsymbol{Z}_t = \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xz} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hz} + \boldsymbol{b}_z),
\end{aligned}
```
其中`$ \boldsymbol{W}_{xr}, \boldsymbol{W}_{xz} \in \mathbb{R}^{d \times h} $`和`$ \boldsymbol{W}_{hr}, \boldsymbol{W}_{hz} \in \mathbb{R}^{h \times h} $`是權重參數,`$ \boldsymbol{b}_r, \boldsymbol{b}_z \in \mathbb{R}^{1 \times h} $`是偏差參數。3.8節(多層感知機)節中介紹過,sigmoid函數可以將元素的值變換到0和1之間。因此,重置門`$ \boldsymbol{R}_t $`和更新門`$ \boldsymbol{Z}_t $`中每個元素的值域都是`$ [0, 1] $`。
### 6.7.1.2 候選隱藏狀態
接下來,門控循環單元將計算候選隱藏狀態來輔助稍后的隱藏狀態計算。如圖6.5所示,我們將當前時間步重置門的輸出與上一時間步隱藏狀態做按元素乘法(符號為`$ \odot $`)。如果重置門中元素值接近0,那么意味著重置對應隱藏狀態元素為0,即丟棄上一時間步的隱藏狀態。如果元素值接近1,那么表示保留上一時間步的隱藏狀態。然后,將按元素乘法的結果與當前時間步的輸入連結,再通過含激活函數tanh的全連接層計算出候選隱藏狀態,其所有元素的值域為`$ [-1, 1] $`。
:-: 
<div align=center>圖6.5 門控循環單元中候選隱藏狀態的計算</div>
具體來說,時間步`$ t $`的候選隱藏狀態`$ \tilde{\boldsymbol{H}}_t \in \mathbb{R}^{n \times h} $`的計算為
```[tex]
\tilde{\boldsymbol{H}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}_{xh} + \left(\boldsymbol{R}_t \odot \boldsymbol{H}_{t-1}\right) \boldsymbol{W}_{hh} + \boldsymbol{b}_h),
```
其中`$ \boldsymbol{W}_{xh} \in \mathbb{R}^{d \times h} $`和`$ \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} $`是權重參數,`$ \boldsymbol{b}_h \in \mathbb{R}^{1 \times h} $`是偏差參數。從上面這個公式可以看出,重置門控制了上一時間步的隱藏狀態如何流入當前時間步的候選隱藏狀態。而上一時間步的隱藏狀態可能包含了時間序列截至上一時間步的全部歷史信息。因此,重置門可以用來丟棄與預測無關的歷史信息。
### 6.7.1.3 隱藏狀態
最后,時間步`$ t $`的隱藏狀態`$ \boldsymbol{H}_t \in \mathbb{R}^{n \times h} $`的計算使用當前時間步的更新門`$ \boldsymbol{Z}_t $`來對上一時間步的隱藏狀態`$ \boldsymbol{H}_{t-1} $`和當前時間步的候選隱藏狀態`$ \tilde{\boldsymbol{H}}_t $`做組合:
```[tex]
\boldsymbol{H}_t = \boldsymbol{Z}_t \odot \boldsymbol{H}_{t-1} + (1 - \boldsymbol{Z}_t) \odot \tilde{\boldsymbol{H}}_t.
```
:-: 
<div align=center>圖6.6 門控循環單元中隱藏狀態的計算</div>
值得注意的是,更新門可以控制隱藏狀態應該如何被包含當前時間步信息的候選隱藏狀態所更新,如圖6.6所示。假設更新門在時間步`$ t ' $`到`$ t $`(`$ t' < t $`)之間一直近似1。那么,在時間步`$ t' $`到`$ t $`之間的輸入信息幾乎沒有流入時間步$t$的隱藏狀態`$ \boldsymbol{H}_t $`。實際上,這可以看作是較早時刻的隱藏狀態`$ \boldsymbol{H}_{t'-1} $`一直通過時間保存并傳遞至當前時間步`$ t $`。這個設計可以應對循環神經網絡中的梯度衰減問題,并更好地捕捉時間序列中時間步距離較大的依賴關系。
我們對門控循環單元的設計稍作總結:
* 重置門有助于捕捉時間序列里短期的依賴關系;
* 更新門有助于捕捉時間序列里長期的依賴關系。
## 6.7.2 讀取數據集
為了實現并展示門控循環單元,下面依然使用周杰倫歌詞數據集來訓練模型作詞。這里除門控循環單元以外的實現已在6.2節(循環神經網絡)中介紹過。以下為讀取數據集部分。
``` python
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
```
## 6.7.3 從零開始實現
我們先介紹如何從零開始實現門控循環單元。
### 6.7.3.1 初始化模型參數
下面的代碼對模型參數進行初始化。超參數`num_hiddens`定義了隱藏單元的個數。
``` python
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)
def get_params():
def _one(shape):
ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
return torch.nn.Parameter(ts, requires_grad=True)
def _three():
return (_one((num_inputs, num_hiddens)),
_one((num_hiddens, num_hiddens)),
torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))
W_xz, W_hz, b_z = _three() # 更新門參數
W_xr, W_hr, b_r = _three() # 重置門參數
W_xh, W_hh, b_h = _three() # 候選隱藏狀態參數
# 輸出層參數
W_hq = _one((num_hiddens, num_outputs))
b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])
```
### 6.7.3.2 定義模型
下面的代碼定義隱藏狀態初始化函數`init_gru_state`。同6.4節(循環神經網絡的從零開始實現)中定義的`init_rnn_state`函數一樣,它返回由一個形狀為(批量大小, 隱藏單元個數)的值為0的`Tensor`組成的元組。
``` python
def init_gru_state(batch_size, num_hiddens, device):
return (torch.zeros((batch_size, num_hiddens), device=device), )
```
下面根據門控循環單元的計算表達式定義模型。
``` python
def gru(inputs, state, params):
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state
outputs = []
for X in inputs:
Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)
R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)
H_tilda = torch.tanh(torch.matmul(X, W_xh) + R * torch.matmul(H, W_hh) + b_h)
H = Z * H + (1 - Z) * H_tilda
Y = torch.matmul(H, W_hq) + b_q
outputs.append(Y)
return outputs, (H,)
```
### 6.7.3.3 訓練模型并創作歌詞
我們在訓練模型時只使用相鄰采樣。設置好超參數后,我們將訓練模型并根據前綴“分開”和“不分開”分別創作長度為50個字符的一段歌詞。
``` python
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分開', '不分開']
```
我們每過40個迭代周期便根據當前訓練的模型創作一段歌詞。
```python
d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
vocab_size, device, corpus_indices, idx_to_char,
char_to_idx, False, num_epochs, num_steps, lr,
clipping_theta, batch_size, pred_period, pred_len,
prefixes)
```
輸出:
```
epoch 40, perplexity 149.477598, time 1.08 sec
- 分開 我不不你 我想你你的愛我 你不你的讓我 你不你的讓我 你不你的讓我 你不你的讓我 你不你的讓我 你
- 不分開 我想你你的讓我 你不你的讓我 你不你的讓我 你不你的讓我 你不你的讓我 你不你的讓我 你不你的讓我
epoch 80, perplexity 31.689210, time 1.10 sec
- 分開 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
- 不分開 我想要你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不
epoch 120, perplexity 4.866115, time 1.08 sec
- 分開 我想要這樣牽著你的手不放開 愛過 讓我來的肩膀 一起好酒 你來了這節秋 后知后覺 我該好好生活 我
- 不分開 你已經不了我不要 我不要再想你 我不要再想你 我不要再想你 不知不覺 我跟了這節奏 后知后覺 又過
epoch 160, perplexity 1.442282, time 1.51 sec
- 分開 我一定好生憂 唱著歌 一直走 我想就這樣牽著你的手不放開 愛可不可以簡簡單單沒有傷害 你 靠著我的
- 不分開 你已經離開我 不知不覺 我跟了這節奏 后知后覺 又過了一個秋 后知后覺 我該好好生活 我該好好生活
```
## 6.7.4 簡潔實現
在PyTorch中我們直接調用`nn`模塊中的`GRU`類即可。
``` python
lr = 1e-2 # 注意調整學習率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
corpus_indices, idx_to_char, char_to_idx,
num_epochs, num_steps, lr, clipping_theta,
batch_size, pred_period, pred_len, prefixes)
```
輸出:
```
epoch 40, perplexity 1.022157, time 1.02 sec
- 分開手牽手 一步兩步三步四步望著天 看星星 一顆兩顆三顆四顆 連成線背著背默默許下心愿 看遠方的星是否聽
- 不分開暴風圈來不及逃 我不能再想 我不能再想 我不 我不 我不能 愛情走的太快就像龍卷風 不能承受我已無處
epoch 80, perplexity 1.014535, time 1.04 sec
- 分開始想像 爸和媽當年的模樣 說著一口吳儂軟語的姑娘緩緩走過外灘 消失的 舊時光 一九四三 在回憶 的路
- 不分開始愛像 不知不覺 你已經離開我 不知不覺 我跟了這節奏 后知后覺 又過了一個秋 后知后覺 我該好好
epoch 120, perplexity 1.147843, time 1.04 sec
- 分開都靠我 你拿著球不投 又不會掩護我 選你這種隊友 瞎透了我 說你說 分數怎么停留 所有回憶對著我進攻
- 不分開球我有多煩惱多 牧草有沒有危險 一場夢 我面對我 甩開球我滿腔的怒火 我想揍你已經很久 別想躲 說你
epoch 160, perplexity 1.018370, time 1.05 sec
- 分開愛上你 那場悲劇 是你完美演出的一場戲 寧愿心碎哭泣 再狠狠忘記 你愛過我的證據 讓晶瑩的淚滴 閃爍
- 不分開始 擔心今天的你過得好不好 整個畫面是你 想你想的睡不著 嘴嘟嘟那可愛的模樣 還有在你身上香香的味道
```
## 小結
* 門控循環神經網絡可以更好地捕捉時間序列中時間步距離較大的依賴關系。
* 門控循環單元引入了門的概念,從而修改了循環神經網絡中隱藏狀態的計算方式。它包括重置門、更新門、候選隱藏狀態和隱藏狀態。
* 重置門有助于捕捉時間序列里短期的依賴關系。
* 更新門有助于捕捉時間序列里長期的依賴關系。
## 參考文獻
[1] Cho, K., Van Merri?nboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.
[2] Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
-----------
> 注:除代碼外本節與原書此節基本相同,[原書傳送門](https://zh.d2l.ai/chapter_recurrent-neural-networks/gru.html)
- Home
- Introduce
- 1.深度學習簡介
- 深度學習簡介
- 2.預備知識
- 2.1環境配置
- 2.2數據操作
- 2.3自動求梯度
- 3.深度學習基礎
- 3.1 線性回歸
- 3.2 線性回歸的從零開始實現
- 3.3 線性回歸的簡潔實現
- 3.4 softmax回歸
- 3.5 圖像分類數據集(Fashion-MINST)
- 3.6 softmax回歸的從零開始實現
- 3.7 softmax回歸的簡潔實現
- 3.8 多層感知機
- 3.9 多層感知機的從零開始實現
- 3.10 多層感知機的簡潔實現
- 3.11 模型選擇、反向傳播和計算圖
- 3.12 權重衰減
- 3.13 丟棄法
- 3.14 正向傳播、反向傳播和計算圖
- 3.15 數值穩定性和模型初始化
- 3.16 實戰kaggle比賽:房價預測
- 4 深度學習計算
- 4.1 模型構造
- 4.2 模型參數的訪問、初始化和共享
- 4.3 模型參數的延后初始化
- 4.4 自定義層
- 4.5 讀取和存儲
- 4.6 GPU計算
- 5 卷積神經網絡
- 5.1 二維卷積層
- 5.2 填充和步幅
- 5.3 多輸入通道和多輸出通道
- 5.4 池化層
- 5.5 卷積神經網絡(LeNet)
- 5.6 深度卷積神經網絡(AlexNet)
- 5.7 使用重復元素的網絡(VGG)
- 5.8 網絡中的網絡(NiN)
- 5.9 含并行連結的網絡(GoogLeNet)
- 5.10 批量歸一化
- 5.11 殘差網絡(ResNet)
- 5.12 稠密連接網絡(DenseNet)
- 6 循環神經網絡
- 6.1 語言模型
- 6.2 循環神經網絡
- 6.3 語言模型數據集(周杰倫專輯歌詞)
- 6.4 循環神經網絡的從零開始實現
- 6.5 循環神經網絡的簡單實現
- 6.6 通過時間反向傳播
- 6.7 門控循環單元(GRU)
- 6.8 長短期記憶(LSTM)
- 6.9 深度循環神經網絡
- 6.10 雙向循環神經網絡
- 7 優化算法
- 7.1 優化與深度學習
- 7.2 梯度下降和隨機梯度下降
- 7.3 小批量隨機梯度下降
- 7.4 動量法
- 7.5 AdaGrad算法
- 7.6 RMSProp算法
- 7.7 AdaDelta
- 7.8 Adam算法
- 8 計算性能
- 8.1 命令式和符號式混合編程
- 8.2 異步計算
- 8.3 自動并行計算
- 8.4 多GPU計算
- 9 計算機視覺
- 9.1 圖像增廣
- 9.2 微調
- 9.3 目標檢測和邊界框
- 9.4 錨框
- 10 自然語言處理
- 10.1 詞嵌入(word2vec)
- 10.2 近似訓練
- 10.3 word2vec實現
- 10.4 子詞嵌入(fastText)
- 10.5 全局向量的詞嵌入(Glove)
- 10.6 求近義詞和類比詞
- 10.7 文本情感分類:使用循環神經網絡
- 10.8 文本情感分類:使用卷積網絡
- 10.9 編碼器--解碼器(seq2seq)
- 10.10 束搜索
- 10.11 注意力機制
- 10.12 機器翻譯