# 重新訓練現有的 CNN 模型
從頭開始訓練新的圖像識別需要大量的時間和計算能力。如果我們可以采用先前訓練的網絡并使用我們的圖像重新訓練它,它可以節省我們的計算時間。對于此秘籍,我們將展示如何使用預先訓練的 TensorFlow 圖像識別模型并對其進行微調以處理不同的圖像集。
## 做好準備
其思想是從卷積層重用先前模型的權重和結構,并重新訓練網絡頂部的完全連接層。
TensorFlow 在現有 CNN 模型的基礎上創建了一個關于訓練的教程(請參閱下一節中的第一個要點)。在本文中,我們將說明如何對 CIFAR-10 使用相同的方法。我們將采用的 CNN 網絡使用一種非常流行的架構,稱為 Inception。 Inception CNN 模型由 Google 創建,在許多圖像識別基準測試中表現非常出色。有關詳細信息,請參閱“另請參閱”部分的第二個要點中的紙張參考。
我們將介紹的主要 Python 腳本顯示如何下載 CIFAR-10 圖像數據并自動分離,標記和保存圖像到每個訓練和測試文件夾中的十個類。之后,我們將重申如何在我們的圖像上訓練網絡。
## 操作步驟
執行以下步驟:
1. 我們首先加載必要的庫來下載,解壓縮和保存 CIFAR-10 圖像:
```py
import os
import tarfile
import _pickle as cPickle
import numpy as np
import urllib.request
import scipy.misc
from imageio import imwrite
```
1. 我們現在聲明 CIFAR-10 數據鏈接并創建我們將存儲數據的臨時目錄。我們還將在以后保存圖像時聲明要引用的十個類別:
```py
cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
data_dir = 'temp'
if not os.path.isdir(data_dir):
os.makedirs(data_dir)
objects = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
```
1. 現在我們將下載 CIFAR-10 `.tar`數據文件,并解壓該文件:
```py
target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz')
if not os.path.isfile(target_file):
print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)')
print('This may take a few minutes, please wait.')
filename, headers = urllib.request.urlretrieve(cifar_link, target_file)
# Extract into memory
tar = tarfile.open(target_file)
tar.extractall(path=data_dir)
tar.close()
```
1. 我們現在為訓練創建必要的文件夾結構。臨時目錄將有兩個文件夾,`train_dir`和`validation_dir`。在每個文件夾中,我們將為每個類別創建 10 個子文件夾:
```py
# Create train image folders
train_folder = 'train_dir'
if not os.path.isdir(os.path.join(data_dir, train_folder)):
for i in range(10):
folder = os.path.join(data_dir, train_folder, objects[i])
os.makedirs(folder)
# Create test image folders
test_folder = 'validation_dir'
if not os.path.isdir(os.path.join(data_dir, test_folder)):
for i in range(10):
folder = os.path.join(data_dir, test_folder, objects[i])
os.makedirs(folder)
```
1. 為了保存圖像,我們將創建一個從內存加載它們并將它們存儲在圖像字典中的函數:
```py
def load_batch_from_file(file):
file_conn = open(file, 'rb')
image_dictionary = cPickle.load(file_conn, encoding='latin1')
file_conn.close()
return(image_dictionary)
```
1. 使用前面的字典,我們將使用以下函數將每個文件保存在正確的位置:
```py
def save_images_from_dict(image_dict, folder='data_dir'):
for ix, label in enumerate(image_dict['labels']):
folder_path = os.path.join(data_dir, folder, objects[label])
filename = image_dict['filenames'][ix]
#Transform image data
image_array = image_dict['data'][ix]
image_array.resize([3, 32, 32])
# Save image
output_location = os.path.join(folder_path, filename)
imwrite(output_location,image_array.transpose())
```
1. 使用上述函數,我們可以遍歷下載的數據文件并將每個圖像保存到正確的位置:
```py
data_location = os.path.join(data_dir, 'cifar-10-batches-py')
train_names = ['data_batch_' + str(x) for x in range(1,6)]
test_names = ['test_batch']
# Sort train images
for file in train_names:
print('Saving images from file: {}'.format(file))
file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)
image_dict = load_batch_from_file(file_location)
save_images_from_dict(image_dict, folder=train_folder)
# Sort test images
for file in test_names:
print('Saving images from file: {}'.format(file))
file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)
image_dict = load_batch_from_file(file_location)
save_images_from_dict(image_dict, folder=test_folder)
```
1. 我們腳本的最后一部分創建了圖像標簽文件,這是我們需要的最后一條信息。這個文件讓我們將輸出解釋為標簽而不是數字索引:
```py
cifar_labels_file = os.path.join(data_dir,'cifar10_labels.txt')
print('Writing labels file, {}'.format(cifar_labels_file))
with open(cifar_labels_file, 'w') as labels_file:
for item in objects:
labels_file.write("{}n".format(item))
```
1. 當前面的腳本運行時,它將下載圖像并將它們分類到 TensorFlow 再訓練教程所期望的正確文件夾結構中。完成后,我們只需按照教程進行操作即可。首先,我們應該克隆教程倉庫:
```py
git clone https://github.com/tensorflow/models/tree/master/research/inception
```
1. 為了使用先前訓練的模型,我們必須下載網絡權重并將其應用于我們的模型。為此,您必須訪問該站點: [https://github.com/tensorflow/models/tree/master/research/slim](https://github.com/tensorflow/models/tree/master/research/slim) ,并按照說明下載并安裝 cifar10 模型架構和權重。您還將最終下載包含下面描述的構建,訓練和測試腳本的數據目錄。
> 對于此步驟,我們導航到 research / inception / inception 目錄,然后執行以下命令,`--train_directory`,`--validation_directory`,`--output_directory`和`--labels_file`的路徑指向相對路徑或完整路徑創建的目錄結構。
1. 現在我們將圖像放在正確的文件夾結構中,我們必須將它們變成`TFRecords`對象。我們通過運行以下命令來完成此操作:
```py
me@computer:~$ python3 data/build_image_data.py
--train_directory="temp/train_dir/"
--validation_directory="temp/validation_dir"
--output_directory="temp/" --labels_file="temp/cifar10_labels.txt"
```
1. 現在我們將使用`bazel`訓練模型,將參數設置為`true`。該腳本每 10 代輸出一次損失。我們可以隨時終止此過程,模型輸出將在`temp/training_results`文件夾中。我們可以從此文件夾加載模型以進行評估:
```py
me@computer:~$ bazel-bin/inception/flowers_train
--train_dir="temp/training_results" --data_dir="temp/data_dir"
--pretrained_model_checkpoint_path="model.ckpt-157585"
--fine_tune=True --initial_learning_rate=0.001
--input_queue_memory_factor=1
```
1. 這應該導致輸出類似于以下內容:
```py
2018-06-02 11:10:10.557012: step 1290, loss = 2.02 (1.2 examples/sec; 23.771 sec/batch)
...
```
## 工作原理
關于預訓練 CNN 上的訓練的官方 TensorFlow 教程需要設置一個文件夾;我們從 CIFAR-10 數據創建的設置。然后我們將數據轉換為所需的`TFRecords`格式并開始訓練模型。請記住,我們正在微調模型并重新訓練頂部的完全連接的層以適合我們的 10 類數據。
## 另見
* 官方 Tensorflow Inception-v3 教程: [https://www.tensorflow.org/tutoriaimg/image_recognition](https://www.tensorflow.org/tutoriaimg/image_recognition)
* Googlenet Inception-v3 文件: [https://arxiv.org/abs/1512.00567](https://arxiv.org/abs/1512.00567)
- TensorFlow 入門
- 介紹
- TensorFlow 如何工作
- 聲明變量和張量
- 使用占位符和變量
- 使用矩陣
- 聲明操作符
- 實現激活函數
- 使用數據源
- 其他資源
- TensorFlow 的方式
- 介紹
- 計算圖中的操作
- 對嵌套操作分層
- 使用多個層
- 實現損失函數
- 實現反向傳播
- 使用批量和隨機訓練
- 把所有東西結合在一起
- 評估模型
- 線性回歸
- 介紹
- 使用矩陣逆方法
- 實現分解方法
- 學習 TensorFlow 線性回歸方法
- 理解線性回歸中的損失函數
- 實現 deming 回歸
- 實現套索和嶺回歸
- 實現彈性網絡回歸
- 實現邏輯回歸
- 支持向量機
- 介紹
- 使用線性 SVM
- 簡化為線性回歸
- 在 TensorFlow 中使用內核
- 實現非線性 SVM
- 實現多類 SVM
- 最近鄰方法
- 介紹
- 使用最近鄰
- 使用基于文本的距離
- 使用混合距離函數的計算
- 使用地址匹配的示例
- 使用最近鄰進行圖像識別
- 神經網絡
- 介紹
- 實現操作門
- 使用門和激活函數
- 實現單層神經網絡
- 實現不同的層
- 使用多層神經網絡
- 改進線性模型的預測
- 學習玩井字棋
- 自然語言處理
- 介紹
- 使用詞袋嵌入
- 實現 TF-IDF
- 使用 Skip-Gram 嵌入
- 使用 CBOW 嵌入
- 使用 word2vec 進行預測
- 使用 doc2vec 進行情緒分析
- 卷積神經網絡
- 介紹
- 實現簡單的 CNN
- 實現先進的 CNN
- 重新訓練現有的 CNN 模型
- 應用 StyleNet 和 NeuralStyle 項目
- 實現 DeepDream
- 循環神經網絡
- 介紹
- 為垃圾郵件預測實現 RNN
- 實現 LSTM 模型
- 堆疊多個 LSTM 層
- 創建序列到序列模型
- 訓練 Siamese RNN 相似性度量
- 將 TensorFlow 投入生產
- 介紹
- 實現單元測試
- 使用多個執行程序
- 并行化 TensorFlow
- 將 TensorFlow 投入生產
- 生產環境 TensorFlow 的一個例子
- 使用 TensorFlow 服務
- 更多 TensorFlow
- 介紹
- 可視化 TensorBoard 中的圖
- 使用遺傳算法
- 使用 k 均值聚類
- 求解常微分方程組
- 使用隨機森林
- 使用 TensorFlow 和 Keras