分類: 程式設計

Keras 儲存與載入訓練好的模型或參數教學

這裡介紹如何使用 Keras 儲存與載入訓練好的模型或參數,以利重複使用或部署產品。

訓練一個實際的類神經網路模型會需要非常大量的運算,所以在模型訓練完之後,最好可以把訓練好的模型參數儲存下來,這樣之後在使用時就可以省去重新訓練的時間。


在 Keras 中若要儲存與載入訓練好的模型或參數,可以使用其內建模型儲存與載入功能,將模型儲存於 HDF5 或 JSON 檔案中,以下是 Keras 儲存模型的操作方式。

Keras 儲存與載入完整模型與參數

我們以簡單的手寫數字辨識 CNN 模型為範例,示範如何把訓練好的模型儲存起來。在模型訓練完之後,若要儲存整個模型,只要呼叫 save 函數,並指定 HDF5 的檔案名稱即可:

# [略]

# 訓練模型
model.fit(x_train, y_train,
          batch_size=128 * 2,
          epochs=12,
          verbose=1,
          validation_data=(x_test, y_test))

# 將模型儲存至 HDF5 檔案中
model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'

在模型儲存至 HDF5 檔案之後,未來要使用時就可以呼叫 keras.models.load_model 直接載入之前訓練好的模型:

# 準備 x_test 與 y_test 資料 ... [略]

# 從 HDF5 檔案中載入模型
model = tf.contrib.keras.models.load_model('my_model.h5')

# 驗證模型
score = model.evaluate(x_test, y_test, verbose=0)

# 輸出結果
print('Test loss:', score[0])
print('Test accuracy:', score[1])

從 HDF5 載入的模型,使用起來跟原本的模型是一模一樣的:

Test loss: 0.0300953536768
Test accuracy: 0.9897

只儲存與載入模型

如果只要將模型儲存起來,不儲存其中的參數,可以使用 to_jsonto_yaml 將模型轉為 JSON 或 YAML 的文字資料,在自己儲存至檔案中:

# 將模型匯出至 JSON(不含參數)
json_string = model.to_json()

# 將模型匯出至 YAML(不含參數)
yaml_string = model.to_yaml()

若要從 JSON 或 YAML 重建模型,可以使用 model_from_jsonmodel_from_yaml

# 從 JSON 資料重建模型
model = tf.contrib.keras.models.model_from_json(json_string)

# 從 YAML 資料重建模型
model = tf.contrib.keras.models.model_from_yaml(yaml_string)

只儲存與載入參數

若只想要儲存模型的參數(也就是 weights),不包含模型本身,可以使用 save_weights

# 將參數儲存至 HDF5 檔案(不含模型)
model.save_weights('my_model_weights.h5')

若要載入參數,可使用 load_weights

# 從 HDF5 檔案載入參數(不含模型)
model.load_weights('my_model_weights.h5')

若要將儲存的參數載入至不同的模型中使用(模型不同,但有相同網路層,例如 fine-tuning 或 transfer-learning),可以加上 by_name 參數:

# 載入參數至不同的模型中使用
model.load_weights('my_model_weights.h5', by_name = True)

範例

以下是一個簡單的範例,假設原始的模型如下,在原始模型訓練好之後,我們將這個模型的參數儲存下來:

# 原始模型
model = Sequential()
model.add(Dense(2, input_dim=3, name='dense_1'))
model.add(Dense(3, name='dense_2'))

# [略]

# 儲存參數
model.save_weights("weight_1.h5")

接著我們又建立另外一個新的模型,而這個新的模型與舊的模型之間有部份的網路層是相同的,在將參數載入至新模型時,只有那些相同的網路層參數會受影響,其餘的參數則不會改變:

# 新建模型
model = Sequential()

# 相同網路層,會載入參數
model.add(Dense(2, input_dim=3, name='dense_1'))

# 不同網路層,不會載入參數
model.add(Dense(10, name='new_dense'))

# 載入參數,只會影響 dense_1 那一層
model.load_weights(fname, by_name = True)

自訂類別

若在模型中有包含自訂的網路層、類別或函數等,可在載入時加入 custom_objects 自訂物件參數,使其正常載入:

# 假設模型中有包含一個自訂的 AttentionLayer 類別實體
model = tf.contrib.keras.models.load_model('my_model.h5',
  custom_objects = {'AttentionLayer': AttentionLayer})

亦可使用 CustomObjectScope 來載入自訂的類別實體:

# 亦可使用 custom object scope 來載入自訂的類別實體
with CustomObjectScope({'AttentionLayer': AttentionLayer}):
  model = load_model('my_model.h5')

自訂物件參數的用法,在 load_modelmodel_from_jsonmodel_from_yaml 中都相同:

# 從 JSON 資料中載入
model = model_from_json(json_string, custom_objects={'AttentionLayer': AttentionLayer})

常見問題

若在儲存或載入 HDF5 檔案時出現這樣的錯誤:

ImportError: `save_model` requires h5py.

代表系統上少裝了 h5py,用 pip 裝一下即可:

pip install h5py

參考資料:Keras

G. T. Wang

個人使用 Linux 經驗長達十餘年,樂於分享各種自由軟體技術與實作文章。

Share
Published by
G. T. Wang

Recent Posts

光陽 KYMCO GP 125 機車接電發動、更換電瓶記錄

本篇記錄我的光陽 KYMCO ...

2 年 ago

[開箱] YubiKey 5C NFC 實體金鑰

本篇是 YubiKey 5C ...

2 年 ago

[DIY] 自製竹火把

本篇記錄我拿竹子加上過期的苦茶...

2 年 ago