[Neural Networks] (15) Xception: Code Reproduction and Network Analysis, with Complete TensorFlow Code

Date: 2019-04-27 20:18:20

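Hello everyone. Today I will show how to build the Xception neural network model with TensorFlow.

Earlier chapters covered a number of lightweight convolutional neural network models. If you are interested, see: /dgvv4/category_11517910.html

Xception balances accuracy against model size. As the figure below shows, with computational cost on the horizontal axis and accuracy on the vertical axis, Xception sits in the top tier for accuracy while its computational cost is still low enough to count as a lightweight model.

Xception relies on the same depthwise separable convolution used in MobileNetV1, so I recommend studying MobileNetV1 first: /dgvv4/article/details/123415708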

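1. Depthwise separable convolution

To make Xception easier to follow, let us first briefly review depthwise separable convolution.

In a standard convolution, each kernel processes all channels at once: a kernel has as many channels as the input feature map, and each kernel produces one output feature map.

A depthwise separable convolution can be understood as a depthwise convolution followed by a pointwise convolution.

The depthwise convolution handles only the spatial information along height and width; the pointwise convolution handles only the cross-channel information. Splitting the work this way greatly reduces the number of parameters and improves computational efficiency.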
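To make the saving concrete, here is a minimal sketch that counts weights, assuming no bias terms. A standard k x k convolution needs k·k·C_in·C_out weights, while the separable version needs k·k·C_in for the depthwise step plus C_in·C_out for the 1x1 pointwise step. The function names are illustrative; the 728-channel example is chosen because that width appears in the Xception middle flow.

```python
def standard_conv_params(k, c_in, c_out):
    """Weights of a standard k x k convolution (no bias)."""
    return k * k * c_in * c_out


def separable_conv_params(k, c_in, c_out):
    """Weights of a depthwise k x k convolution plus a 1 x 1 pointwise convolution (no bias)."""
    depthwise = k * k * c_in      # one k x k kernel per input channel
    pointwise = c_in * c_out      # one 1 x 1 kernel per output channel
    return depthwise + pointwise


std = standard_conv_params(3, 728, 728)
sep = separable_conv_params(3, 728, 728)
print("standard :", std)
print("separable:", sep)
print("ratio    :", round(std / sep, 2))
```

The separable count for this case, 536,536, is the same figure that model.summary() reports for each 728-channel separable layer later in this article.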

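Depthwise convolution: each kernel processes exactly one channel, so there are as many kernels as there are input channels. The per-channel outputs are stacked together, so the input and output feature maps have the same number of channels.

Because processing only the spatial dimensions discards cross-channel information, a pointwise convolution is needed to bring that information back.

Pointwise convolution: a 1x1 convolution that works across the channel dimension. The number of 1x1 kernels determines the number of output feature maps.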
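The two steps can be sketched in plain Python on nested lists. This is only an illustration of the idea, not how TensorFlow implements it: it assumes stride 1, 'valid' padding and an H x W x C layout, and the helper names depthwise_conv and pointwise_conv are made up for this example.

```python
def depthwise_conv(x, kernels):
    """x: H x W x C nested lists; kernels: C kernels of size k x k.
    Stride 1, 'valid' padding. Each kernel only ever reads its own channel."""
    h, w, c = len(x), len(x[0]), len(x[0][0])
    k = len(kernels[0])
    out = []
    for i in range(h - k + 1):
        row = []
        for j in range(w - k + 1):
            pixel = []
            for ch in range(c):
                s = 0.0
                for di in range(k):
                    for dj in range(k):
                        s += x[i + di][j + dj][ch] * kernels[ch][di][dj]
                pixel.append(s)
            row.append(pixel)
        out.append(row)
    return out


def pointwise_conv(x, weights):
    """weights: C_out lists of C_in values. A 1 x 1 convolution that mixes
    the channels of every pixel independently."""
    return [[[sum(wc * v for wc, v in zip(w_out, pixel)) for w_out in weights]
             for pixel in row]
            for row in x]


# 3 x 3 input with 2 channels: channel 0 is all ones, channel 1 is all twos.
x = [[[1.0, 2.0] for _ in range(3)] for _ in range(3)]
kernels = [[[1.0] * 3 for _ in range(3)],   # sums channel 0 over the window
           [[0.5] * 3 for _ in range(3)]]   # half-sum of channel 1
dw = depthwise_conv(x, kernels)              # shape 1 x 1 x 2
pw = pointwise_conv(dw, [[1.0, 1.0], [1.0, -1.0], [0.0, 2.0]])  # shape 1 x 1 x 3
print(dw)
print(pw)
```

The depthwise output keeps two channels, matching the input, while the three pointwise kernels produce three output channels.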

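2. From Inception to Xception

Next, let us trace how the core module evolved from Inception to Xception, to get a better understanding of the Xception structure.

InceptionV1 is a stack of nine BottleNeck-Inception modules, as shown in the figure below.

2.1 Inception module

How the Inception module works: the input feature map is fed into four branches that each process it differently, and the four resulting feature maps are stacked along the channel dimension and passed to the next layer.

By factorizing and decoupling as much as possible, convolutions at different scales capture information at different levels and granularities.

2.2 BottleNeck module

As the outputs of successive Inception modules are stacked, the number of channels keeps growing. To keep the computation and parameter count from exploding, a 1x1 convolution is added before the 3x3 and 5x5 convolutions to reduce dimensionality, which limits the number of feature maps and cuts both parameters and computation. In the figure, the left side is the Inception module and the right side is the BottleNeck module.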
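A rough count shows why the 1x1 reduction pays off. The sketch below uses illustrative channel numbers that are not taken from this article: a 192-channel input, a 5x5 branch reduced to 16 channels before producing 32, and branch widths of 64, 128, 32 and 32. Biases are ignored.

```python
def conv_params(k, c_in, c_out):
    """Weights of a k x k convolution (no bias)."""
    return k * k * c_in * c_out


c_in = 192        # input channels (illustrative)
reduced = 16      # channels after the 1x1 reduction (illustrative)
out_5x5 = 32      # channels produced by the 5x5 branch (illustrative)

# 5x5 convolution applied directly to the input
direct = conv_params(5, c_in, out_5x5)
# 1x1 reduction first, then the 5x5 convolution on the reduced feature map
bottleneck = conv_params(1, c_in, reduced) + conv_params(5, reduced, out_5x5)

# The four branch outputs are stacked along the channel axis,
# so the module's output width is the sum of the branch widths.
branch_channels = {"1x1": 64, "3x3": 128, "5x5": out_5x5, "pool_proj": 32}
out_channels = sum(branch_channels.values())

print("5x5 direct     :", direct)
print("5x5 bottleneck :", bottleneck)
print("output channels:", out_channels)
```

With these numbers the bottlenecked branch needs roughly a tenth of the weights of the direct 5x5 convolution.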

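2.3 How the Inception network was improved

(1) InceptionV3 first improved the BottleNeck module by factorizing the 5x5 convolution into two 3x3 convolutions. Two stacked 3x3 layers have the same receptive field as a single 5x5 layer, with fewer parameters and one more nonlinearity, which improves the model's expressive power.

(2) The 1x1 convolution after the pooling layer is replaced with a 3x3 convolution.

(3) The first layer uses only 1x1 convolutions and the second layer uses only 3x3 convolutions.

(4) The input first passes through a single 1x1 convolution to produce a feature map, and the three branches that follow all operate on that feature map.

(5) A grouped convolution is applied to the feature map produced by the 1x1 convolution: different kernels process different channels, and the groups are independent of one another.

(6) The Xception module takes the depthwise separable idea to its limit: pointwise convolution first, then depthwise convolution, with each 3x3 kernel processing a single channel. The order of the pointwise and depthwise convolutions makes little difference.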
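Two of the steps above can be checked with simple arithmetic. The sketch below, with made-up helper names and an illustrative width of 64 channels, first compares one 5x5 layer with two stacked 3x3 layers (step 1), then shows how the weight count of a 3x3 grouped convolution shrinks as the number of groups grows until every channel has its own kernel (steps 5 and 6). Stride 1 and no bias are assumed.

```python
def stacked_receptive_field(kernel_sizes):
    """Receptive field of stride-1 convolutions applied one after another."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1
    return rf


def grouped_conv_params(k, c_in, c_out, groups):
    """Weights of a k x k convolution whose channels are split into
    independent groups (no bias). groups=1 is a standard convolution."""
    if c_in % groups or c_out % groups:
        raise ValueError("channel counts must be divisible by groups")
    return k * k * (c_in // groups) * c_out


c = 64  # illustrative channel count

# Step (1): same receptive field, fewer weights
rf_5x5 = stacked_receptive_field([5])
rf_two_3x3 = stacked_receptive_field([3, 3])
params_5x5 = grouped_conv_params(5, c, c, groups=1)
params_two_3x3 = 2 * grouped_conv_params(3, c, c, groups=1)
print(rf_5x5, rf_two_3x3, params_5x5, params_two_3x3)

# Steps (5) and (6): more groups, fewer weights; groups == channels is depthwise
by_groups = {g: grouped_conv_params(3, c, c, g) for g in (1, 4, 16, 64)}
print(by_groups)
```

When the number of groups equals the number of channels, the count drops to k·k·C, which is exactly the depthwise convolution from section 1.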

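3. Code reproduction

3.1 Network structure diagram

The Xception architecture given in the paper is shown in the figure below.

3.2 Building the convolution modules

(1) Standard convolution block

A standard convolution block consists of convolution + batch normalization + activation.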

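from tensorflow.keras import layers

# (1) Standard convolution block
def conv_block(input_tensor, filters, kernel_size, stride):
    # Standard convolution + batch normalization + activation
    x = layers.Conv2D(filters=filters,          # number of output feature maps
                      kernel_size=kernel_size,  # kernel size
                      strides=stride,           # stride
                      padding='same',           # stride=1 keeps the size; stride=2 halves height and width (rounded up)
                      use_bias=False)(input_tensor)  # no bias needed when BN follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU()(x)                # ReLU activation
    return x  # output feature map of the standard convolution block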
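The padding comment in conv_block can be checked by hand. With 'same' padding the output size is ceil(n / stride), whereas 'valid' padding gives floor((n - k) / stride) + 1. The sketch below uses two small helper functions written for this example, not TensorFlow calls, and traces the spatial size through the stride sequence used by the network in this article, assuming a 299 x 299 input.

```python
import math


def same_out(n, stride):
    """Output size along one axis with padding='same'."""
    return math.ceil(n / stride)


def valid_out(n, kernel, stride):
    """Output size along one axis with padding='valid'."""
    return (n - kernel) // stride + 1


print(same_out(299, 2))       # first conv_block with padding='same'
print(valid_out(299, 3, 2))   # the same layer if it used padding='valid'

# Strides of the size-changing stages: two conv blocks, three pooling steps,
# the middle flow (stride 1) and the final pooling step.
sizes = [299]
for stride in (2, 1, 2, 2, 2, 1, 2):
    sizes.append(same_out(sizes[-1], stride))
print(sizes)
```

With 'same' padding the first layer yields 150 x 150 feature maps, which is what model.summary() reports; 149 x 149 would require 'valid' padding.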

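(2) Residual block

Following the structure diagram, build a residual unit made of two depthwise separable convolutions + max pooling + a shortcut branch.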

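# (2) Depthwise separable convolution block
def sep_conv_block(input_tensor, filters, kernel_size):
    # Activation first
    x = layers.ReLU()(input_tensor)
    # Depthwise separable convolution (depthwise + pointwise in a single layer)
    x = layers.SeparableConv2D(filters=filters,          # number of pointwise kernels = output feature maps
                               kernel_size=kernel_size,  # depthwise kernel size
                               strides=1,                # depthwise stride
                               padding='same',           # spatial size stays unchanged
                               use_bias=False)(x)        # no bias
    # Note: no BatchNormalization follows here, unlike the paper,
    # where every separable convolution is followed by one.
    return x  # output feature map

# (3) One residual unit
def res_block(input_tensor, filters):
    # 1. Shortcut branch: 1x1 convolution with stride 2
    residual = layers.Conv2D(filters,                  # number of output channels
                             kernel_size=(1, 1),       # kernel size
                             strides=2)(input_tensor)  # halve the size so it matches the main branch
    residual = layers.BatchNormalization()(residual)   # batch normalization
    # 2. Main branch: two separable convolutions, then max pooling
    x = sep_conv_block(input_tensor, filters, kernel_size=(3, 3))
    x = sep_conv_block(x, filters, kernel_size=(3, 3))
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
    # 3. Add the shortcut to the main branch (residual connection)
    output = layers.Add()([residual, x])
    return output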
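layers.Add() only works when both branches produce the same shape, which is why the shortcut needs its own strided 1x1 convolution. A minimal sketch of that check, using a hypothetical helper res_block_shapes that mirrors the padding and stride settings in res_block above (the shortcut convolution uses Keras's default 'valid' padding):

```python
import math


def res_block_shapes(size, filters):
    """Return (main, shortcut) output shapes as (spatial size, channels)."""
    # Main branch: two stride-1 'same' separable convolutions keep the size,
    # then a stride-2 'same' max pooling gives ceil(size / 2).
    main = (math.ceil(size / 2), filters)
    # Shortcut: 1x1 convolution, stride 2, 'valid' padding.
    shortcut = ((size - 1) // 2 + 1, filters)
    return main, shortcut


for size, filters in ((75, 256), (38, 728)):
    main, shortcut = res_block_shapes(size, filters)
    print(size, "->", main, shortcut, "match" if main == shortcut else "MISMATCH")
```

For a 1x1 kernel, 'valid' and 'same' padding give the same output size, so the two branches agree for odd and even input sizes alike.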

3.3 Complete code

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers

# (1) Standard convolution block
def conv_block(input_tensor, filters, kernel_size, stride):
    # Standard convolution + batch normalization + activation
    x = layers.Conv2D(filters=filters,          # number of output feature maps
                      kernel_size=kernel_size,  # kernel size
                      strides=stride,           # stride
                      padding='same',           # stride=1 keeps the size; stride=2 halves height and width (rounded up)
                      use_bias=False)(input_tensor)  # no bias needed when BN follows
    x = layers.BatchNormalization()(x)  # batch normalization
    x = layers.ReLU()(x)                # ReLU activation
    return x  # output feature map of the standard convolution block

# (2) Depthwise separable convolution block
def sep_conv_block(input_tensor, filters, kernel_size):
    # Activation first
    x = layers.ReLU()(input_tensor)
    # Depthwise separable convolution (depthwise + pointwise in a single layer)
    x = layers.SeparableConv2D(filters=filters,          # number of pointwise kernels = output feature maps
                               kernel_size=kernel_size,  # depthwise kernel size
                               strides=1,                # depthwise stride
                               padding='same',           # spatial size stays unchanged
                               use_bias=False)(x)        # no bias
    # Note: no BatchNormalization follows here, unlike the paper,
    # where every separable convolution is followed by one.
    return x  # output feature map

# (3) One residual unit
def res_block(input_tensor, filters):
    # 1. Shortcut branch: 1x1 convolution with stride 2
    residual = layers.Conv2D(filters,                  # number of output channels
                             kernel_size=(1, 1),       # kernel size
                             strides=2)(input_tensor)  # halve the size so it matches the main branch
    residual = layers.BatchNormalization()(residual)   # batch normalization
    # 2. Main branch: two separable convolutions, then max pooling
    x = sep_conv_block(input_tensor, filters, kernel_size=(3, 3))
    x = sep_conv_block(x, filters, kernel_size=(3, 3))
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
    # 3. Add the shortcut to the main branch (residual connection)
    output = layers.Add()([residual, x])
    return output

# (4) Middle flow
def middle_flow(x, filters):
    # The block is repeated 8 times
    for _ in range(8):
        # Shortcut branch
        residual = x
        # Three depthwise separable convolution blocks
        x = sep_conv_block(x, filters, kernel_size=(3, 3))
        x = sep_conv_block(x, filters, kernel_size=(3, 3))
        x = sep_conv_block(x, filters, kernel_size=(3, 3))
        # Add the shortcut
        x = layers.Add()([residual, x])
    return x

# (5) Backbone network
def xception(input_shape, classes):
    # Build the input
    inputs = keras.Input(shape=input_shape)
    # [299,299,3] ==> [150,150,32]
    x = conv_block(inputs, filters=32, kernel_size=(3, 3), stride=2)  # standard convolution block
    # [150,150,32] ==> [150,150,64]
    x = conv_block(x, filters=64, kernel_size=(3, 3), stride=1)

    # [150,150,64] ==> [75,75,128]
    # Shortcut branch.
    # NOTE: as published, this branch is built but never added back to x, so it
    # does not appear in the model summary in section 3.4. To follow the paper,
    # add `x = layers.Add()([residual, x])` after the max pooling below.
    residual = layers.Conv2D(filters=128, kernel_size=(1, 1), strides=2,
                             padding='same', use_bias=False)(x)
    residual = layers.BatchNormalization()(residual)
    # Main branch [150,150,64] ==> [150,150,128]
    x = layers.SeparableConv2D(128, kernel_size=(3, 3), strides=1,
                               padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    # [150,150,128] ==> [150,150,128]
    x = sep_conv_block(x, filters=128, kernel_size=(3, 3))
    # [150,150,128] ==> [75,75,128]
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)

    # [75,75,128] ==> [38,38,256]
    x = res_block(x, filters=256)
    # [38,38,256] ==> [19,19,728]
    x = res_block(x, filters=728)

    # [19,19,728] ==> [19,19,728]
    x = middle_flow(x, filters=728)

    # Shortcut branch [19,19,728] ==> [10,10,1024]
    residual = layers.Conv2D(filters=1024, kernel_size=(1, 1), strides=2,
                             use_bias=False, padding='same')(x)
    residual = layers.BatchNormalization()(residual)  # batch normalization
    # Main branch [19,19,728] ==> [19,19,728]
    x = sep_conv_block(x, filters=728, kernel_size=(3, 3))
    # [19,19,728] ==> [19,19,1024]
    x = sep_conv_block(x, filters=1024, kernel_size=(3, 3))
    # [19,19,1024] ==> [10,10,1024]
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
    # Add the shortcut [10,10,1024]
    x = layers.Add()([residual, x])

    # [10,10,1024] ==> [10,10,1536]
    x = layers.SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # [10,10,1536] ==> [10,10,2048]
    x = layers.SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # [10,10,2048] ==> [None,2048]
    x = layers.GlobalAveragePooling2D()(x)
    # [None,2048] ==> [None,classes]
    outputs = layers.Dense(classes)(x)  # logits layer, no softmax applied

    # Build the model
    model = Model(inputs, outputs)
    return model

# (6) Build the model and inspect it
if __name__ == '__main__':
    model = xception(input_shape=[299, 299, 3], classes=1000)
    model.summary()  # print the network structure

3.4 Inspecting the network architecture

Calling model.summary() prints the network structure. The model has about 22.8 million parameters.

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 299, 299, 3) 0
conv2d (Conv2D)                 (None, 150, 150, 32) 864         input_1[0][0]
batch_normalization (BatchNorma (None, 150, 150, 32) 128         conv2d[0][0]
re_lu (ReLU)                    (None, 150, 150, 32) 0           batch_normalization[0][0]
conv2d_1 (Conv2D)               (None, 150, 150, 64) 18432       re_lu[0][0]
batch_normalization_1 (BatchNor (None, 150, 150, 64) 256         conv2d_1[0][0]
re_lu_1 (ReLU)                  (None, 150, 150, 64) 0           batch_normalization_1[0][0]
separable_conv2d (SeparableConv (None, 150, 150, 128 8768        re_lu_1[0][0]
batch_normalization_3 (BatchNor (None, 150, 150, 128 512         separable_conv2d[0][0]
re_lu_2 (ReLU)                  (None, 150, 150, 128 0           batch_normalization_3[0][0]
separable_conv2d_1 (SeparableCo (None, 150, 150, 128 17536       re_lu_2[0][0]
max_pooling2d (MaxPooling2D)    (None, 75, 75, 128)  0           separable_conv2d_1[0][0]
re_lu_3 (ReLU)                  (None, 75, 75, 128)  0           max_pooling2d[0][0]
separable_conv2d_2 (SeparableCo (None, 75, 75, 256)  33920       re_lu_3[0][0]
re_lu_4 (ReLU)                  (None, 75, 75, 256)  0           separable_conv2d_2[0][0]
conv2d_3 (Conv2D)               (None, 38, 38, 256)  33024       max_pooling2d[0][0]
separable_conv2d_3 (SeparableCo (None, 75, 75, 256)  67840       re_lu_4[0][0]
batch_normalization_4 (BatchNor (None, 38, 38, 256)  1024        conv2d_3[0][0]
max_pooling2d_1 (MaxPooling2D)  (None, 38, 38, 256)  0           separable_conv2d_3[0][0]
add (Add)                       (None, 38, 38, 256)  0           batch_normalization_4[0][0]
                                                                 max_pooling2d_1[0][0]
re_lu_5 (ReLU)                  (None, 38, 38, 256)  0           add[0][0]
separable_conv2d_4 (SeparableCo (None, 38, 38, 728)  188672      re_lu_5[0][0]
re_lu_6 (ReLU)                  (None, 38, 38, 728)  0           separable_conv2d_4[0][0]
conv2d_4 (Conv2D)               (None, 19, 19, 728)  187096      add[0][0]
separable_conv2d_5 (SeparableCo (None, 38, 38, 728)  536536      re_lu_6[0][0]
batch_normalization_5 (BatchNor (None, 19, 19, 728)  2912        conv2d_4[0][0]
max_pooling2d_2 (MaxPooling2D)  (None, 19, 19, 728)  0           separable_conv2d_5[0][0]
add_1 (Add)                     (None, 19, 19, 728)  0           batch_normalization_5[0][0]
                                                                 max_pooling2d_2[0][0]
re_lu_7 (ReLU)                  (None, 19, 19, 728)  0           add_1[0][0]
separable_conv2d_6 (SeparableCo (None, 19, 19, 728)  536536      re_lu_7[0][0]
re_lu_8 (ReLU)                  (None, 19, 19, 728)  0           separable_conv2d_6[0][0]
separable_conv2d_7 (SeparableCo (None, 19, 19, 728)  536536      re_lu_8[0][0]
re_lu_9 (ReLU)                  (None, 19, 19, 728)  0           separable_conv2d_7[0][0]
separable_conv2d_8 (SeparableCo (None, 19, 19, 728)  536536      re_lu_9[0][0]
add_2 (Add)                     (None, 19, 19, 728)  0           add_1[0][0]
                                                                 separable_conv2d_8[0][0]
re_lu_10 (ReLU)                 (None, 19, 19, 728)  0           add_2[0][0]
separable_conv2d_9 (SeparableCo (None, 19, 19, 728)  536536      re_lu_10[0][0]
re_lu_11 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_9[0][0]
separable_conv2d_10 (SeparableC (None, 19, 19, 728)  536536      re_lu_11[0][0]
re_lu_12 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_10[0][0]
separable_conv2d_11 (SeparableC (None, 19, 19, 728)  536536      re_lu_12[0][0]
add_3 (Add)                     (None, 19, 19, 728)  0           add_2[0][0]
                                                                 separable_conv2d_11[0][0]
re_lu_13 (ReLU)                 (None, 19, 19, 728)  0           add_3[0][0]
separable_conv2d_12 (SeparableC (None, 19, 19, 728)  536536      re_lu_13[0][0]
re_lu_14 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_12[0][0]
separable_conv2d_13 (SeparableC (None, 19, 19, 728)  536536      re_lu_14[0][0]
re_lu_15 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_13[0][0]
separable_conv2d_14 (SeparableC (None, 19, 19, 728)  536536      re_lu_15[0][0]
add_4 (Add)                     (None, 19, 19, 728)  0           add_3[0][0]
                                                                 separable_conv2d_14[0][0]
re_lu_16 (ReLU)                 (None, 19, 19, 728)  0           add_4[0][0]
separable_conv2d_15 (SeparableC (None, 19, 19, 728)  536536      re_lu_16[0][0]
re_lu_17 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_15[0][0]
separable_conv2d_16 (SeparableC (None, 19, 19, 728)  536536      re_lu_17[0][0]
re_lu_18 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_16[0][0]
separable_conv2d_17 (SeparableC (None, 19, 19, 728)  536536      re_lu_18[0][0]
add_5 (Add)                     (None, 19, 19, 728)  0           add_4[0][0]
                                                                 separable_conv2d_17[0][0]
re_lu_19 (ReLU)                 (None, 19, 19, 728)  0           add_5[0][0]
separable_conv2d_18 (SeparableC (None, 19, 19, 728)  536536      re_lu_19[0][0]
re_lu_20 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_18[0][0]
separable_conv2d_19 (SeparableC (None, 19, 19, 728)  536536      re_lu_20[0][0]
re_lu_21 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_19[0][0]
separable_conv2d_20 (SeparableC (None, 19, 19, 728)  536536      re_lu_21[0][0]
add_6 (Add)                     (None, 19, 19, 728)  0           add_5[0][0]
                                                                 separable_conv2d_20[0][0]
re_lu_22 (ReLU)                 (None, 19, 19, 728)  0           add_6[0][0]
separable_conv2d_21 (SeparableC (None, 19, 19, 728)  536536      re_lu_22[0][0]
re_lu_23 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_21[0][0]
separable_conv2d_22 (SeparableC (None, 19, 19, 728)  536536      re_lu_23[0][0]
re_lu_24 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_22[0][0]
separable_conv2d_23 (SeparableC (None, 19, 19, 728)  536536      re_lu_24[0][0]
add_7 (Add)                     (None, 19, 19, 728)  0           add_6[0][0]
                                                                 separable_conv2d_23[0][0]
re_lu_25 (ReLU)                 (None, 19, 19, 728)  0           add_7[0][0]
separable_conv2d_24 (SeparableC (None, 19, 19, 728)  536536      re_lu_25[0][0]
re_lu_26 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_24[0][0]
separable_conv2d_25 (SeparableC (None, 19, 19, 728)  536536      re_lu_26[0][0]
re_lu_27 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_25[0][0]
separable_conv2d_26 (SeparableC (None, 19, 19, 728)  536536      re_lu_27[0][0]
add_8 (Add)                     (None, 19, 19, 728)  0           add_7[0][0]
                                                                 separable_conv2d_26[0][0]
re_lu_28 (ReLU)                 (None, 19, 19, 728)  0           add_8[0][0]
separable_conv2d_27 (SeparableC (None, 19, 19, 728)  536536      re_lu_28[0][0]
re_lu_29 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_27[0][0]
separable_conv2d_28 (SeparableC (None, 19, 19, 728)  536536      re_lu_29[0][0]
re_lu_30 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_28[0][0]
separable_conv2d_29 (SeparableC (None, 19, 19, 728)  536536      re_lu_30[0][0]
add_9 (Add)                     (None, 19, 19, 728)  0           add_8[0][0]
                                                                 separable_conv2d_29[0][0]
re_lu_31 (ReLU)                 (None, 19, 19, 728)  0           add_9[0][0]
separable_conv2d_30 (SeparableC (None, 19, 19, 728)  536536      re_lu_31[0][0]
re_lu_32 (ReLU)                 (None, 19, 19, 728)  0           separable_conv2d_30[0][0]
conv2d_5 (Conv2D)               (None, 10, 10, 1024) 745472      add_9[0][0]
separable_conv2d_31 (SeparableC (None, 19, 19, 1024) 752024      re_lu_32[0][0]
batch_normalization_6 (BatchNor (None, 10, 10, 1024) 4096        conv2d_5[0][0]
max_pooling2d_3 (MaxPooling2D)  (None, 10, 10, 1024) 0           separable_conv2d_31[0][0]
add_10 (Add)                    (None, 10, 10, 1024) 0           batch_normalization_6[0][0]
                                                                 max_pooling2d_3[0][0]
separable_conv2d_32 (SeparableC (None, 10, 10, 1536) 1582080     add_10[0][0]
batch_normalization_7 (BatchNor (None, 10, 10, 1536) 6144        separable_conv2d_32[0][0]
re_lu_33 (ReLU)                 (None, 10, 10, 1536) 0           batch_normalization_7[0][0]
separable_conv2d_33 (SeparableC (None, 10, 10, 2048) 3159552     re_lu_33[0][0]
batch_normalization_8 (BatchNor (None, 10, 10, 2048) 8192        separable_conv2d_33[0][0]
re_lu_34 (ReLU)                 (None, 10, 10, 2048) 0           batch_normalization_8[0][0]
global_average_pooling2d (Globa (None, 2048)         0           re_lu_34[0][0]
dense (Dense)                   (None, 1000)         2049000     global_average_pooling2d[0][0]
==================================================================================================
Total params: 22,817,480
Trainable params: 22,805,848
Non-trainable params: 11,632
__________________________________________________________________________________________________
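The totals in the summary can be reproduced by hand from the layer list. The sketch below is an independent check, not part of the original code: it assumes a k x k convolution has k·k·C_in·C_out weights (plus C_out if it has a bias), a separable convolution has k·k·C_in + C_in·C_out, and a BatchNormalization layer has 4 parameters per channel, 2 of which (moving mean and variance) are non-trainable. The helper names conv and sep_conv are made up for this example.

```python
def conv(k, c_in, c_out, bias=False):
    """Parameters of a standard k x k convolution."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def sep_conv(k, c_in, c_out):
    """Parameters of a depthwise separable convolution without bias."""
    return k * k * c_in + c_in * c_out


# Channel widths of the eight BatchNormalization layers listed in the summary
bn_channels = [32, 64, 128, 256, 728, 1024, 1536, 2048]

weights = [
    conv(3, 3, 32), conv(3, 32, 64),                        # two standard conv blocks
    sep_conv(3, 64, 128), sep_conv(3, 128, 128),            # first separable stage (no shortcut conv listed)
    conv(1, 128, 256, bias=True),                           # res_block(256) shortcut, bias enabled
    sep_conv(3, 128, 256), sep_conv(3, 256, 256),
    conv(1, 256, 728, bias=True),                           # res_block(728) shortcut, bias enabled
    sep_conv(3, 256, 728), sep_conv(3, 728, 728),
    8 * 3 * sep_conv(3, 728, 728),                          # middle flow: 8 repeats of 3 layers
    conv(1, 728, 1024),                                     # exit-flow shortcut, no bias
    sep_conv(3, 728, 728), sep_conv(3, 728, 1024),
    sep_conv(3, 1024, 1536), sep_conv(3, 1536, 2048),
    2048 * 1000 + 1000,                                     # Dense layer with bias
]

total = sum(weights) + 4 * sum(bn_channels)
non_trainable = 2 * sum(bn_channels)
print("total        :", total)
print("trainable    :", total - non_trainable)
print("non-trainable:", non_trainable)
```

The middle flow alone accounts for 24 x 536,536 = 12,876,864 parameters, more than half of the model.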
