<ruby id="bdb3f"></ruby>

    <p id="bdb3f"><cite id="bdb3f"></cite></p>

      <p id="bdb3f"><cite id="bdb3f"><th id="bdb3f"></th></cite></p><p id="bdb3f"></p>
        <p id="bdb3f"><cite id="bdb3f"></cite></p>

          <pre id="bdb3f"></pre>
          <pre id="bdb3f"><del id="bdb3f"><thead id="bdb3f"></thead></del></pre>

          <ruby id="bdb3f"><mark id="bdb3f"></mark></ruby><ruby id="bdb3f"></ruby>
          <pre id="bdb3f"><pre id="bdb3f"><mark id="bdb3f"></mark></pre></pre><output id="bdb3f"></output><p id="bdb3f"></p><p id="bdb3f"></p>

          <pre id="bdb3f"><del id="bdb3f"><progress id="bdb3f"></progress></del></pre>

                <ruby id="bdb3f"></ruby>

                ThinkChat2.0新版上線,更智能更精彩,支持會話、畫圖、視頻、閱讀、搜索等,送10W Token,即刻開啟你的AI之旅 廣告
                # 如何在 Keras 中檢查深度學習模型 > 原文: [https://machinelearningmastery.com/check-point-deep-learning-models-keras/](https://machinelearningmastery.com/check-point-deep-learning-models-keras/) 深度學習模型在訓練時可能需要花費數小時,數天甚至數周。 如果意外停止運行,則可能會丟失大量成果。 在這篇文章中,您將了解如何使用Python中的keras庫在模型訓練期間檢查您的深度學習模型。 讓我們開始吧。 * **2017 年 3 月更新**:更新了 Keras 2.0.2,TensorFlow 1.0.1 和 Theano 0.9.0 的示例。 * **更新 March / 2018** :添加了備用鏈接以下載數據集,因為原始圖像已被刪除。 ![How to Check-Point Deep Learning Models in Keras](https://img.kancloud.cn/b7/82/b7827003fe430de1cd822883ab69fb42_640x480.png) 照片由 [saragoldsmith](https://www.flickr.com/photos/saragoldsmith/2353051153/) 提供,并保留其所屬權利。 ## 檢驗點神經網絡模型 [應用程序檢查點](https://en.wikipedia.org/wiki/Application_checkpointing)是一種容錯技術,適用于長時間運行的進程。 這是一種在系統出現故障時采用系統狀態快照的方法,如果出現問題,任務并非全部丟失,檢查點可以直接使用,或者從中斷處開始,用作程序重新運行的起點。 在訓練深度學習模型時,檢查點是模型的權重參數,這些權重可用于按原樣進行預測,或用作持續訓練的基礎。 Keras 庫通過回調 API 提供[檢查點功能。](http://keras.io/callbacks/#modelcheckpoint) ModelCheckpoint 回調類允許您定義檢查模型權重的位置,文件應如何命名以及在何種情況下創建模型的檢查點。 API 允許您指定要監控的度量標準,例如訓練或驗證數據集的損失或準確性,您可以指定是否在最大化或最小化分數時尋求改進,最后,用于存儲權重的文件名可以包含諸如迭代數量或度量的變量。 然后,在模型上調用`fit()`函數時,可以將 ModelCheckpoint 傳遞給訓練過程。 注意,您可能需要安裝 [h5py 庫](http://www.h5py.org/)以輸出 HDF5 格式的網絡權重。 ## 檢查點神經網絡模型改進 檢查點的良好用途是每次在訓練期間觀察到性能提升時輸出模型權重參數。 下面的例子為皮馬印第安人糖尿病二元分類問題創建了一個小型神經網絡。該示例假設 _pima-indians-diabetes.csv_ 文件位于您的工作目錄中。 您可以從此處下載數據集: * [皮馬印第安人糖尿病數據集](https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv) 該示例使用 33%的數據作為驗證集。 只有在驗證數據集`(monitor ='val_acc'和 mode ='max')`的分類準確性有所提高時,才會設置檢驗點以保存網絡權重參數。權重參數存儲在一個文件中,該.hdf5文件的文件名為當前精度值(格式化輸出為:`權重改進 - {val_acc = .2f} .hdf5`)。 ```py # 當驗證集的精度有所提高時,需要保存當前的權重參數 from keras.models import Sequential from keras.layers import Dense from keras.callbacks import ModelCheckpoint import matplotlib.pyplot as plt import numpy # 固定隨機種子再現性 seed = 7 numpy.random.seed(seed) # 加載數據集 dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",") # 將數據集劃分為輸入變量X和輸出變量Y X = dataset[:,0:8] Y = dataset[:,8] # 創建模型 model = Sequential() model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu')) model.add(Dense(8, kernel_initializer='uniform', activation='relu')) model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid')) # 編譯模型 model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # 檢查點 filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5" checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max') callbacks_list = [checkpoint] # 擬合模型 model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0) ``` 運行該示例將生成以下輸出(為簡潔表示,只顯示其一部分結果): ```py ... Epoch 00134: val_acc did not improve Epoch 00135: val_acc did not improve Epoch 00136: val_acc did not improve Epoch 00137: val_acc did not improve Epoch 00138: val_acc did not improve Epoch 00139: val_acc did not improve Epoch 00140: val_acc improved from 0.83465 to 0.83858, saving model to weights-improvement-140-0.84.hdf5 Epoch 00141: val_acc did not improve Epoch 00142: val_acc did not improve Epoch 00143: val_acc did not improve Epoch 00144: val_acc did not improve Epoch 00145: val_acc did not improve Epoch 00146: val_acc improved from 0.83858 to 0.84252, saving model to weights-improvement-146-0.84.hdf5 Epoch 00147: val_acc did not improve Epoch 00148: val_acc improved from 0.84252 to 0.84252, saving model to weights-improvement-148-0.84.hdf5 Epoch 00149: val_acc did not improve ``` 您將在工作目錄中看到許多文件,其中包含 HDF5 格式的網絡權重。例如: ```py ... weights-improvement-53-0.76.hdf5 weights-improvement-71-0.76.hdf5 weights-improvement-77-0.78.hdf5 weights-improvement-99-0.78.hdf5 ``` 這是一個非常簡單的檢查點策略,如果驗證準確度在訓練時期上下移動,則可能會創建大量不必要的檢查點文件,然而,它將確保您發現模型運行期間的最佳快照。 ## 僅限檢查點最佳神經網絡模型 更簡單的檢查點策略是當且僅當驗證準確度提高時將模型權重保存到同一文件中。 這可以使用上面相同的代碼輕松完成,并將輸出文件名更改為固定的字符串(不包括分數或迭代信息)。 在這種情況下,只有當驗證數據集上模型的分類精度提高到當前最佳時,模型權重才會被寫入文件`weights.best.hdf5`. ```py #當驗證模型精度最高時,保存權重 from keras.models import Sequential from keras.layers import Dense from keras.callbacks import ModelCheckpoint import matplotlib.pyplot as plt import numpy #固定隨機種子再現性 seed = 7 numpy.random.seed(seed) # 加載數據集 dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",") # 將數據集劃分為輸入變量和輸出變量 X = dataset[:,0:8] Y = dataset[:,8] # 創建模型 model = Sequential() model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu')) model.add(Dense(8, kernel_initializer='uniform', activation='relu')) model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid')) # 編譯模型 model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # 檢查點 filepath="weights.best.hdf5" checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max') callbacks_list = [checkpoint] # 擬合模型 model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0) ``` 運行此示例提供以下輸出(為簡潔表示,只顯示其一部分結果): ```py ... Epoch 00139: val_acc improved from 0.79134 to 0.79134, saving model to weights.best.hdf5 Epoch 00140: val_acc did not improve Epoch 00141: val_acc did not improve Epoch 00142: val_acc did not improve Epoch 00143: val_acc did not improve Epoch 00144: val_acc improved from 0.79134 to 0.79528, saving model to weights.best.hdf5 Epoch 00145: val_acc improved from 0.79528 to 0.79528, saving model to weights.best.hdf5 Epoch 00146: val_acc did not improve Epoch 00147: val_acc did not improve Epoch 00148: val_acc did not improve Epoch 00149: val_acc did not improve ``` 您應該可以在本地目錄中看到權重文件。 ```py weights.best.hdf5 ``` 這是在您的實驗中能夠始終使用的一個方便的檢查點策略,它將確保為運行保存最佳模型,以便您以后使用,這個策略避免了您在訓練時需要包含代碼以手動跟蹤和序列化最佳模型。 ## 加載一個檢查點的神經網絡模型 現在您已經了解了如何在訓練期間檢查您的深度學習模型,您現在需要了解如何加載和使用檢查點模型。 檢查點僅包括模型權重,假設您了解網絡結構,這些模型權重也可以序列化為 JSON 或 YAML 格式的文件。 在下面的示例中,模型結構是已知的,最佳權重從上一個實驗加載,存儲在 weights.best.hdf5 文件的工作目錄中。 然后使用該模型對整個數據集進行預測。 ```py # 怎樣從一個檢查點加載和使用權重參數 from keras.models import Sequential from keras.layers import Dense from keras.callbacks import ModelCheckpoint import matplotlib.pyplot as plt import numpy # 固定隨機種子再現性 numpy.random.seed(seed) # 創建模型 model = Sequential() model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu')) model.add(Dense(8, kernel_initializer='uniform', activation='relu')) model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid')) # 加載權重參數 model.load_weights("weights.best.hdf5") # 編譯模型(需要做出預測) model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) print("Created model and loaded weights from file") # 加載數據集 dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",") # 將數據集劃分為輸入變量和輸出變量 X = dataset[:,0:8] Y = dataset[:,8] # 在整個數據集上使用加載的權重參數評估模型性能 scores = model.evaluate(X, Y, verbose=0) print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100)) ``` 運行該示例將生成以下輸出: ```py Created model and loaded weights from file acc: 77.73% ``` ## 摘要 在這篇文章中,您已經了解了檢查點在深度學習模型中長時間訓練中的重要性。 您學習了兩個檢查點策略,您可以在下一個深度學習項目中使用它們: 1. 檢查點模型改進。 2. 唯一的檢查點最佳模型。 您還學習了如何加載檢查點模型并進行預測。 如果您對深度學習模型檢查點或此篇文章有任何疑問,在評論中提出您的問題,我會盡力回答。
                  <ruby id="bdb3f"></ruby>

                  <p id="bdb3f"><cite id="bdb3f"></cite></p>

                    <p id="bdb3f"><cite id="bdb3f"><th id="bdb3f"></th></cite></p><p id="bdb3f"></p>
                      <p id="bdb3f"><cite id="bdb3f"></cite></p>

                        <pre id="bdb3f"></pre>
                        <pre id="bdb3f"><del id="bdb3f"><thead id="bdb3f"></thead></del></pre>

                        <ruby id="bdb3f"><mark id="bdb3f"></mark></ruby><ruby id="bdb3f"></ruby>
                        <pre id="bdb3f"><pre id="bdb3f"><mark id="bdb3f"></mark></pre></pre><output id="bdb3f"></output><p id="bdb3f"></p><p id="bdb3f"></p>

                        <pre id="bdb3f"><del id="bdb3f"><progress id="bdb3f"></progress></del></pre>

                              <ruby id="bdb3f"></ruby>

                              哎呀哎呀视频在线观看