# 3.5 圖像分類數據集(Fashion-MNIST)
在介紹softmax回歸的實現前我們先引入一個多類圖像分類數據集。它將在后面的章節中被多次使用,以方便我們觀察比較算法之間在模型精度和計算效率上的區別。圖像分類數據集中最常用的是手寫數字識別數據集MNIST[1]。但大部分模型在MNIST上的分類精度都超過了95%。為了更直觀地觀察算法之間的差異,我們將使用一個圖像內容更加復雜的數據集Fashion-MNIST[2](這個數據集也比較小,只有幾十M,沒有GPU的電腦也能吃得消)。
本節我們將使用torchvision包,它是服務于PyTorch深度學習框架的,主要用來構建計算機視覺模型。torchvision主要由以下幾部分構成:
1. `torchvision.datasets`: 一些加載數據的函數及常用的數據集接口;
2. `torchvision.models`: 包含常用的模型結構(含預訓練模型),例如AlexNet、VGG、ResNet等;
3. `torchvision.transforms`: 常用的圖片變換,例如裁剪、旋轉等;
4. `torchvision.utils`: 其他的一些有用的方法。
## 3.5.1 獲取數據集
首先導入本節需要的包或模塊。
``` python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 為了導入上層目錄的d2lzh_pytorch
import d2lzh_pytorch as d2l
```
下面,我們通過torchvision的`torchvision.datasets`來下載這個數據集。第一次調用時會自動從網上獲取數據。我們通過參數`train`來指定獲取訓練數據集或測試數據集(testing data set)。測試數據集也叫測試集(testing set),只用來評價模型的表現,并不用來訓練模型。
另外我們還指定了參數`transform = transforms.ToTensor()`使所有數據轉換為`Tensor`,如果不進行轉換則返回的是PIL圖片。`transforms.ToTensor()`將尺寸為 (H x W x C) 且數據位于[0, 255]的PIL圖片或者數據類型為`np.uint8`的NumPy數組轉換為尺寸為(C x H x W)且數據類型為`torch.float32`且位于[0.0, 1.0]的`Tensor`。
> 注意: 由于像素值為0到255的整數,所以剛好是uint8所能表示的范圍,包括`transforms.ToTensor()`在內的一些關于圖片的函數就默認輸入的是uint8型,若不是,可能不會報錯但的可能得不到想要的結果。所以,**如果用像素值(0-255整數)表示圖片數據,那么一律將其類型設置成uint8,避免不必要的bug。** 本人就被這點坑過,詳見[我的這個博客2.2.4節](https://tangshusen.me/2018/12/05/kaggle-doodle-reco/)。
``` python
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
```
上面的`mnist_train`和`mnist_test`都是[`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html)的子類,所以我們可以用`len()`來獲取該數據集的大小,還可以用下標來獲取具體的一個樣本。訓練集中和測試集中的每個類別的圖像數分別為6,000和1,000。因為有10個類別,所以訓練集和測試集的樣本數分別為60,000和10,000。
``` python
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
```
輸出:
```
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
```
我們可以通過下標來訪問任意一個樣本:
``` python
feature, label = mnist_train[0]
print(feature.shape, label) # Channel x Height X Width
```
輸出:
```
torch.Size([1, 28, 28]) tensor(9)
```
變量`feature`對應高和寬均為28像素的圖像。由于我們使用了`transforms.ToTensor()`,所以每個像素的數值為[0.0, 1.0]的32位浮點數。需要注意的是,`feature`的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一維是通道數,因為數據集中是灰度圖像,所以通道數為1。后面兩維分別是圖像的高和寬。
Fashion-MNIST中一共包括了10個類別,分別為t-shirt(T恤)、trouser(褲子)、pullover(套衫)、dress(連衣裙)、coat(外套)、sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)。以下函數可以將數值標簽轉成相應的文本標簽。
``` python
# 本函數已保存在d2lzh包中方便以后使用
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
```
下面定義一個可以在一行里畫出多張圖像和對應標簽的函數。
``` python
# 本函數已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
# 這里的_表示我們忽略(不使用)的變量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
```
現在,我們看一下訓練數據集中前9個樣本的圖像內容和文本標簽。
``` python
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
```

## 3.5.2 讀取小批量
我們將在訓練數據集上訓練模型,并將訓練好的模型在測試數據集上評價模型的表現。前面說過,`mnist_train`是`torch.utils.data.Dataset`的子類,所以我們可以將其傳入`torch.utils.data.DataLoader`來創建一個讀取小批量數據樣本的DataLoader實例。
在實踐中,數據讀取經常是訓練的性能瓶頸,特別當模型較簡單或者計算硬件性能較高時。PyTorch的`DataLoader`中一個很方便的功能是允許使用多進程來加速數據讀取。這里我們通過參數`num_workers`來設置4個進程讀取數據。
``` python
batch_size = 256
if sys.platform.startswith('win'):
num_workers = 0 # 0表示不用額外的進程來加速讀取數據
else:
num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
```
我們將獲取并讀取Fashion-MNIST數據集的邏輯封裝在`d2lzh_pytorch.load_data_fashion_mnist`函數中供后面章節調用。該函數將返回`train_iter`和`test_iter`兩個變量。隨著本書內容的不斷深入,我們會進一步改進該函數。它的完整實現將在5.6節中描述。
最后我們查看讀取一遍訓練數據需要的時間。
``` python
start = time.time()
for X, y in train_iter:
continue
print('%.2f sec' % (time.time() - start))
```
輸出:
```
1.57 sec
```
## 小結
* Fashion-MNIST是一個10類服飾分類數據集,之后章節里將使用它來檢驗不同算法的表現。
* 我們將高和寬分別為`$ h $`和`$ w $`像素的圖像的形狀記為`$ h \times w $`或`(h,w)`。
## 參考文獻
[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/
[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.
-----------
> 注:本節除了代碼之外與原書基本相同,[原書傳送門](https://zh.d2l.ai/chapter_deep-learning-basics/fashion-mnist.html)
- Home
- Introduce
- 1.深度學習簡介
- 深度學習簡介
- 2.預備知識
- 2.1環境配置
- 2.2數據操作
- 2.3自動求梯度
- 3.深度學習基礎
- 3.1 線性回歸
- 3.2 線性回歸的從零開始實現
- 3.3 線性回歸的簡潔實現
- 3.4 softmax回歸
- 3.5 圖像分類數據集(Fashion-MINST)
- 3.6 softmax回歸的從零開始實現
- 3.7 softmax回歸的簡潔實現
- 3.8 多層感知機
- 3.9 多層感知機的從零開始實現
- 3.10 多層感知機的簡潔實現
- 3.11 模型選擇、反向傳播和計算圖
- 3.12 權重衰減
- 3.13 丟棄法
- 3.14 正向傳播、反向傳播和計算圖
- 3.15 數值穩定性和模型初始化
- 3.16 實戰kaggle比賽:房價預測
- 4 深度學習計算
- 4.1 模型構造
- 4.2 模型參數的訪問、初始化和共享
- 4.3 模型參數的延后初始化
- 4.4 自定義層
- 4.5 讀取和存儲
- 4.6 GPU計算
- 5 卷積神經網絡
- 5.1 二維卷積層
- 5.2 填充和步幅
- 5.3 多輸入通道和多輸出通道
- 5.4 池化層
- 5.5 卷積神經網絡(LeNet)
- 5.6 深度卷積神經網絡(AlexNet)
- 5.7 使用重復元素的網絡(VGG)
- 5.8 網絡中的網絡(NiN)
- 5.9 含并行連結的網絡(GoogLeNet)
- 5.10 批量歸一化
- 5.11 殘差網絡(ResNet)
- 5.12 稠密連接網絡(DenseNet)
- 6 循環神經網絡
- 6.1 語言模型
- 6.2 循環神經網絡
- 6.3 語言模型數據集(周杰倫專輯歌詞)
- 6.4 循環神經網絡的從零開始實現
- 6.5 循環神經網絡的簡單實現
- 6.6 通過時間反向傳播
- 6.7 門控循環單元(GRU)
- 6.8 長短期記憶(LSTM)
- 6.9 深度循環神經網絡
- 6.10 雙向循環神經網絡
- 7 優化算法
- 7.1 優化與深度學習
- 7.2 梯度下降和隨機梯度下降
- 7.3 小批量隨機梯度下降
- 7.4 動量法
- 7.5 AdaGrad算法
- 7.6 RMSProp算法
- 7.7 AdaDelta
- 7.8 Adam算法
- 8 計算性能
- 8.1 命令式和符號式混合編程
- 8.2 異步計算
- 8.3 自動并行計算
- 8.4 多GPU計算
- 9 計算機視覺
- 9.1 圖像增廣
- 9.2 微調
- 9.3 目標檢測和邊界框
- 9.4 錨框
- 10 自然語言處理
- 10.1 詞嵌入(word2vec)
- 10.2 近似訓練
- 10.3 word2vec實現
- 10.4 子詞嵌入(fastText)
- 10.5 全局向量的詞嵌入(Glove)
- 10.6 求近義詞和類比詞
- 10.7 文本情感分類:使用循環神經網絡
- 10.8 文本情感分類:使用卷積網絡
- 10.9 編碼器--解碼器(seq2seq)
- 10.10 束搜索
- 10.11 注意力機制
- 10.12 機器翻譯