這裡介紹如何使用 Keras 儲存與載入訓練好的模型或參數,以利重複使用或部署產品。
訓練一個實際的類神經網路模型會需要非常大量的運算,所以在模型訓練完之後,最好可以把訓練好的模型參數儲存下來,這樣之後在使用時就可以省去重新訓練的時間。
我們以簡單的手寫數字辨識 CNN 模型為範例,示範如何把訓練好的模型儲存起來。在模型訓練完之後,若要儲存整個模型,只要呼叫 save
函數,並指定 HDF5 的檔案名稱即可:
# [略] # 訓練模型 model.fit(x_train, y_train, batch_size=128 * 2, epochs=12, verbose=1, validation_data=(x_test, y_test)) # 將模型儲存至 HDF5 檔案中 model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
在模型儲存至 HDF5 檔案之後,未來要使用時就可以呼叫 keras.models.load_model
直接載入之前訓練好的模型:
# 準備 x_test 與 y_test 資料 ... [略] # 從 HDF5 檔案中載入模型 model = tf.contrib.keras.models.load_model('my_model.h5') # 驗證模型 score = model.evaluate(x_test, y_test, verbose=0) # 輸出結果 print('Test loss:', score[0]) print('Test accuracy:', score[1])
從 HDF5 載入的模型,使用起來跟原本的模型是一模一樣的:
Test loss: 0.0300953536768 Test accuracy: 0.9897
如果只要將模型儲存起來,不儲存其中的參數,可以使用 to_json
或 to_yaml
將模型轉為 JSON 或 YAML 的文字資料,在自己儲存至檔案中:
# 將模型匯出至 JSON(不含參數) json_string = model.to_json() # 將模型匯出至 YAML(不含參數) yaml_string = model.to_yaml()
若要從 JSON 或 YAML 重建模型,可以使用 model_from_json
或 model_from_yaml
:
# 從 JSON 資料重建模型 model = tf.contrib.keras.models.model_from_json(json_string) # 從 YAML 資料重建模型 model = tf.contrib.keras.models.model_from_yaml(yaml_string)
若只想要儲存模型的參數(也就是 weights),不包含模型本身,可以使用 save_weights
:
# 將參數儲存至 HDF5 檔案(不含模型) model.save_weights('my_model_weights.h5')
若要載入參數,可使用 load_weights
:
# 從 HDF5 檔案載入參數(不含模型) model.load_weights('my_model_weights.h5')
若要將儲存的參數載入至不同的模型中使用(模型不同,但有相同網路層,例如 fine-tuning 或 transfer-learning),可以加上 by_name
參數:
# 載入參數至不同的模型中使用 model.load_weights('my_model_weights.h5', by_name = True)
以下是一個簡單的範例,假設原始的模型如下,在原始模型訓練好之後,我們將這個模型的參數儲存下來:
# 原始模型 model = Sequential() model.add(Dense(2, input_dim=3, name='dense_1')) model.add(Dense(3, name='dense_2')) # [略] # 儲存參數 model.save_weights("weight_1.h5")
接著我們又建立另外一個新的模型,而這個新的模型與舊的模型之間有部份的網路層是相同的,在將參數載入至新模型時,只有那些相同的網路層參數會受影響,其餘的參數則不會改變:
# 新建模型 model = Sequential() # 相同網路層,會載入參數 model.add(Dense(2, input_dim=3, name='dense_1')) # 不同網路層,不會載入參數 model.add(Dense(10, name='new_dense')) # 載入參數,只會影響 dense_1 那一層 model.load_weights(fname, by_name = True)
若在模型中有包含自訂的網路層、類別或函數等,可在載入時加入 custom_objects
自訂物件參數,使其正常載入:
# 假設模型中有包含一個自訂的 AttentionLayer 類別實體 model = tf.contrib.keras.models.load_model('my_model.h5', custom_objects = {'AttentionLayer': AttentionLayer})
亦可使用 CustomObjectScope
來載入自訂的類別實體:
# 亦可使用 custom object scope 來載入自訂的類別實體 with CustomObjectScope({'AttentionLayer': AttentionLayer}): model = load_model('my_model.h5')
自訂物件參數的用法,在 load_model
、model_from_json
與 model_from_yaml
中都相同:
# 從 JSON 資料中載入 model = model_from_json(json_string, custom_objects={'AttentionLayer': AttentionLayer})
若在儲存或載入 HDF5 檔案時出現這樣的錯誤:
ImportError: `save_model` requires h5py.
代表系統上少裝了 h5py
,用 pip
裝一下即可:
pip install h5py
參考資料:Keras