# 4.5 讀取和存儲
到目前為止,我們介紹了如何處理數據以及如何構建、訓練和測試深度學習模型。然而在實際中,我們有時需要把訓練好的模型部署到很多不同的設備。在這種情況下,我們可以把內存中訓練好的模型參數存儲在硬盤上供后續讀取使用。
## 4.5.1 讀寫`Tensor`
我們可以直接使用`save`函數和`load`函數分別存儲和讀取`Tensor`。`save`使用Python的pickle實用程序將對象進行序列化,然后將序列化的對象保存到disk,使用`save`可以保存各種對象,包括模型、張量和字典等。而`laod`使用pickle unpickle工具將pickle的對象文件反序列化為內存。
下面的例子創建了`Tensor`變量`x`,并將其存在文件名同為`x.pt`的文件里。
``` python
import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')
```
然后我們將數據從存儲的文件讀回內存。
``` python
x2 = torch.load('x.pt')
x2
```
輸出:
```
tensor([1., 1., 1.])
```
我們還可以存儲一個`Tensor`列表并讀回內存。
``` python
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list
```
輸出:
```
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
```
存儲并讀取一個從字符串映射到`Tensor`的字典。
``` python
torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy
```
輸出:
```
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
```
## 4.5.2 讀寫模型
### 4.5.2.1 `state_dict`
在PyTorch中,`Module`的可學習參數(即權重和偏差),模塊模型包含在參數中(通過`model.parameters()`訪問)。`state_dict`是一個從參數名稱隱射到參數`Tesnor`的字典對象。
``` python
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
net.state_dict()
```
輸出:
```
OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],
[ 0.2030, -0.2073, -0.0104]])),
('hidden.bias', tensor([-0.3117, -0.4232])),
('output.weight', tensor([[-0.4556, 0.4084]])),
('output.bias', tensor([-0.3573]))])
```
注意,只有具有可學習參數的層(卷積層、線性層等)才有`state_dict`中的條目。優化器(`optim`)也有一個`state_dict`,其中包含關于優化器狀態以及所使用的超參數的信息。
``` python
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
```
輸出:
```
{'param_groups': [{'dampening': 0,
'lr': 0.001,
'momentum': 0.9,
'nesterov': False,
'params': [4736167728, 4736166648, 4736167368, 4736165352],
'weight_decay': 0}],
'state': {}}
```
### 4.5.2.2 保存和加載模型
PyTorch中保存和加載訓練模型有兩種常見的方法:
1. 僅保存和加載模型參數(`state_dict`);
2. 保存和加載整個模型。
#### 1. 保存和加載`state_dict`(推薦方式)
保存:
``` python
torch.save(model.state_dict(), PATH) # 推薦的文件后綴名是pt或pth
```
加載:
``` python
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
```
#### 2. 保存和加載整個模型
保存:
``` python
torch.save(model, PATH)
```
加載:
``` python
model = torch.load(PATH)
```
我們采用推薦的方法一來實驗一下:
``` python
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
```
輸出:
```
tensor([[1],
[1]], dtype=torch.uint8)
```
因為這`net`和`net2`都有同樣的模型參數,那么對同一個輸入`X`的計算結果將會是一樣的。上面的輸出也驗證了這一點。
此外,還有一些其他使用場景,例如GPU與CPU之間的模型保存與讀取、使用多塊GPU的模型的存儲等等,使用的時候可以參考[官方文檔](https://pytorch.org/tutorials/beginner/saving_loading_models.html)。
## 小結
* 通過`save`函數和`load`函數可以很方便地讀寫`Tensor`。
* 通過`save`函數和`load_state_dict`函數可以很方便地讀寫模型的參數。
-----------
> 注:本節與原書此節有一些不同,[原書傳送門](https://zh.d2l.ai/chapter_deep-learning-computation/read-write.html)
- Home
- Introduce
- 1.深度學習簡介
- 深度學習簡介
- 2.預備知識
- 2.1環境配置
- 2.2數據操作
- 2.3自動求梯度
- 3.深度學習基礎
- 3.1 線性回歸
- 3.2 線性回歸的從零開始實現
- 3.3 線性回歸的簡潔實現
- 3.4 softmax回歸
- 3.5 圖像分類數據集(Fashion-MINST)
- 3.6 softmax回歸的從零開始實現
- 3.7 softmax回歸的簡潔實現
- 3.8 多層感知機
- 3.9 多層感知機的從零開始實現
- 3.10 多層感知機的簡潔實現
- 3.11 模型選擇、反向傳播和計算圖
- 3.12 權重衰減
- 3.13 丟棄法
- 3.14 正向傳播、反向傳播和計算圖
- 3.15 數值穩定性和模型初始化
- 3.16 實戰kaggle比賽:房價預測
- 4 深度學習計算
- 4.1 模型構造
- 4.2 模型參數的訪問、初始化和共享
- 4.3 模型參數的延后初始化
- 4.4 自定義層
- 4.5 讀取和存儲
- 4.6 GPU計算
- 5 卷積神經網絡
- 5.1 二維卷積層
- 5.2 填充和步幅
- 5.3 多輸入通道和多輸出通道
- 5.4 池化層
- 5.5 卷積神經網絡(LeNet)
- 5.6 深度卷積神經網絡(AlexNet)
- 5.7 使用重復元素的網絡(VGG)
- 5.8 網絡中的網絡(NiN)
- 5.9 含并行連結的網絡(GoogLeNet)
- 5.10 批量歸一化
- 5.11 殘差網絡(ResNet)
- 5.12 稠密連接網絡(DenseNet)
- 6 循環神經網絡
- 6.1 語言模型
- 6.2 循環神經網絡
- 6.3 語言模型數據集(周杰倫專輯歌詞)
- 6.4 循環神經網絡的從零開始實現
- 6.5 循環神經網絡的簡單實現
- 6.6 通過時間反向傳播
- 6.7 門控循環單元(GRU)
- 6.8 長短期記憶(LSTM)
- 6.9 深度循環神經網絡
- 6.10 雙向循環神經網絡
- 7 優化算法
- 7.1 優化與深度學習
- 7.2 梯度下降和隨機梯度下降
- 7.3 小批量隨機梯度下降
- 7.4 動量法
- 7.5 AdaGrad算法
- 7.6 RMSProp算法
- 7.7 AdaDelta
- 7.8 Adam算法
- 8 計算性能
- 8.1 命令式和符號式混合編程
- 8.2 異步計算
- 8.3 自動并行計算
- 8.4 多GPU計算
- 9 計算機視覺
- 9.1 圖像增廣
- 9.2 微調
- 9.3 目標檢測和邊界框
- 9.4 錨框
- 10 自然語言處理
- 10.1 詞嵌入(word2vec)
- 10.2 近似訓練
- 10.3 word2vec實現
- 10.4 子詞嵌入(fastText)
- 10.5 全局向量的詞嵌入(Glove)
- 10.6 求近義詞和類比詞
- 10.7 文本情感分類:使用循環神經網絡
- 10.8 文本情感分類:使用卷積網絡
- 10.9 編碼器--解碼器(seq2seq)
- 10.10 束搜索
- 10.11 注意力機制
- 10.12 機器翻譯