這裡介紹如何在 Keras 的程式中查詢深度學習模型參數的總數量。
Keras 可用來快速搭建各種深度學習模型,但是在嘗試各種模型的過程中,我們也時常會需要了解模型的結構與參數的數量,方便調整模型。
這裡我們以手寫數字辨識的例子來作為示範,這是建立模型的程式碼(準備資料的部分以及最後訓練模型的部分拿掉了)。
import tensorflow as tf num_classes = 10 img_rows, img_cols = 28, 28 input_shape = (img_rows, img_cols, 1) model = tf.contrib.keras.models.Sequential() model.add(tf.contrib.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape)) model.add(tf.contrib.keras.layers.Conv2D(64, (3, 3), activation='relu')) model.add(tf.contrib.keras.layers.MaxPooling2D(pool_size=(2, 2))) model.add(tf.contrib.keras.layers.Dropout(0.25)) model.add(tf.contrib.keras.layers.Flatten()) model.add(tf.contrib.keras.layers.Dense(128, activation='relu')) model.add(tf.contrib.keras.layers.Dropout(0.5)) model.add(tf.contrib.keras.layers.Dense(num_classes, activation='softmax'))
若要查詢深度學習模型參數的數量,可在模型建立好之後,呼叫 Keras 本身所提供的 count_params
來取得所有的參數總量:
# 模型建立完成後,統計參數總量 print("Total Parameters:%d" % model.count_params())
Total Parameters:1199882
以這個手寫數字辨識的例子來說,整個模型中包含了 1,199,882 個參數,也就是說在使用 SGD 找最佳解時,是在 1,199,882 維的空間中尋找,所以控制這個參數總量是很重要的。
另外一個更好用的函數是 summary
,它可以輸出整個模型的摘要資訊,包含簡單的結構表與參數統計,這個資訊可以讓整個模型一目瞭然:
# 輸出模型摘要資訊
model.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_1 (Conv2D) (None, 26, 26, 32) 320 _________________________________________________________________ conv2d_2 (Conv2D) (None, 24, 24, 64) 18496 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64) 0 _________________________________________________________________ dropout_1 (Dropout) (None, 12, 12, 64) 0 _________________________________________________________________ flatten_1 (Flatten) (None, 9216) 0 _________________________________________________________________ dense_1 (Dense) (None, 128) 1179776 _________________________________________________________________ dropout_2 (Dropout) (None, 128) 0 _________________________________________________________________ dense_2 (Dense) (None, 10) 1290 ================================================================= Total params: 1,199,882 Trainable params: 1,199,882 Non-trainable params: 0 _________________________________________________________________
參考資料:StackOverflow