這裡介紹如何在 Keras 的程式中查詢深度學習模型參數的總數量。

Keras 可用來快速搭建各種深度學習模型,但是在嘗試各種模型的過程中,我們也時常會需要了解模型的結構與參數的數量,方便調整模型。


這裡我們以手寫數字辨識的例子來作為示範,這是建立模型的程式碼(準備資料的部分以及最後訓練模型的部分拿掉了)。

import tensorflow as tf
num_classes = 10
img_rows, img_cols = 28, 28
input_shape = (img_rows, img_cols, 1)

model = tf.contrib.keras.models.Sequential()
model.add(tf.contrib.keras.layers.Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(tf.contrib.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.contrib.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.contrib.keras.layers.Dropout(0.25))
model.add(tf.contrib.keras.layers.Flatten())
model.add(tf.contrib.keras.layers.Dense(128, activation='relu'))
model.add(tf.contrib.keras.layers.Dropout(0.5))
model.add(tf.contrib.keras.layers.Dense(num_classes, activation='softmax'))

若要查詢深度學習模型參數的數量,可在模型建立好之後,呼叫 Keras 本身所提供的 count_params 來取得所有的參數總量:

# 模型建立完成後,統計參數總量
print("Total Parameters:%d" % model.count_params())
Total Parameters:1199882

以這個手寫數字辨識的例子來說,整個模型中包含了 1,199,882 個參數,也就是說在使用 SGD 找最佳解時,是在 1,199,882 維的空間中尋找,所以控制這個參數總量是很重要的。

另外一個更好用的函數是 summary,它可以輸出整個模型的摘要資訊,包含簡單的結構表與參數統計,這個資訊可以讓整個模型一目瞭然:

# 輸出模型摘要資訊
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 26, 26, 32)        320
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 12, 12, 64)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 9216)              0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               1179776
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
_________________________________________________________________

參考資料:StackOverflow