
Code Walkthrough: Retinex Low-Light Image Enhancement (Deep Retinex Decomposition for Low-Light Enhancement)

Posted: 2021-07-21 14:45:24


Today's post is a code walkthrough of a low-light enhancement paper from BMVC. In my view the network is fairly lightweight and still achieves decent results. Without further ado, here are the links:

Paper: /abs/1808.04560

Code: /weichen582/RetinexNet

The paper builds on Retinex theory; if you are not familiar with it, see: /lz0499/article/details/81154937

The overall architecture consists of two networks: DecomNet and RelightNet. DecomNet decomposes an image into a reflectance component and an illumination component; RelightNet adjusts the illumination component, which is then recombined with the reflectance component to reconstruct the corrected image. See the architecture figure in the paper.

The authors also mention that the reflectance component is denoised alongside the RelightNet adjustment, but I could not find that step explicitly in the code; if you know where it happens, please leave a comment.
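Putting the two networks together, the inference flow looks roughly like this. This is a minimal sketch based on my reading of the training code; the final product R_low * I_delta_3 matches the relight loss shown later, and layer_num=5 matches the comment in DecomNet below:

    # decompose the low-light input, adjust its illumination, then recombine
    R_low, I_low = DecomNet(input_low, layer_num=5)
    I_delta = RelightNet(I_low, R_low)
    I_delta_3 = concat([I_delta, I_delta, I_delta])   # broadcast the 1-channel illumination to 3 channels
    output_S = R_low * I_delta_3                      # the enhanced image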

Let's start with the construction of DecomNet. It is a fully convolutional network; details are in my code comments.

def DecomNet(input_im, layer_num, channel=64, kernel_size=3):  # decomposition network
    input_max = tf.reduce_max(input_im, axis=3, keepdims=True)
    # stack the max over the RGB channels (a brightness estimate) onto the image,
    # giving 4 channels, which mirrors the 4-channel output of 'recon_layer'
    input_im = concat([input_max, input_im])
    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):  # every tensor created here is scoped under 'DecomNet'
        conv = tf.layers.conv2d(input_im, channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction")
        for idx in range(layer_num):  # layer_num (5 here) ReLU-activated convolutions
            conv = tf.layers.conv2d(conv, channel, kernel_size, padding='same', activation=tf.nn.relu, name='activated_layer_%d' % idx)
        conv = tf.layers.conv2d(conv, 4, kernel_size, padding='same', activation=None, name='recon_layer')  # project back to 4 channels for the split below
        # split the result into reflectance and illumination
        R = tf.sigmoid(conv[:, :, :, 0:3])  # reflectance (determined by the object itself), reflects color consistency, needs 3 channels
        L = tf.sigmoid(conv[:, :, :, 3:4])  # illumination, 1 channel suffices (essentially a brightness map)
    return R, L
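As a quick sanity check of the two outputs, here is a sketch of mine, assuming TensorFlow 1.x; the concat helper mirrors the channel-axis tf.concat wrapper used throughout the repo:

    import numpy as np
    import tensorflow as tf

    def concat(layers):
        return tf.concat(layers, axis=3)  # concatenate along the channel axis

    x = tf.placeholder(tf.float32, [None, None, None, 3])
    R, L = DecomNet(x, layer_num=5)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        r, l = sess.run([R, L], feed_dict={x: np.random.rand(1, 64, 64, 3)})
        print(r.shape, l.shape)  # (1, 64, 64, 3) and (1, 64, 64, 1): reflectance and illumination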

Next, the construction of RelightNet.

def RelightNet(input_L, input_R, channel=64, kernel_size=3):  # restoration (illumination adjustment) network
    input_im = concat([input_R, input_L])
    with tf.variable_scope('RelightNet'):
        # three downsampling stages (stride-2 convolutions)
        conv0 = tf.layers.conv2d(input_im, channel, kernel_size, padding='same', activation=None)
        conv1 = tf.layers.conv2d(conv0, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
        conv2 = tf.layers.conv2d(conv1, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
        conv3 = tf.layers.conv2d(conv2, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
        # three upsampling stages (nearest-neighbor interpolation) with skip connections
        up1 = tf.image.resize_nearest_neighbor(conv3, (tf.shape(conv2)[1], tf.shape(conv2)[2]))
        deconv1 = tf.layers.conv2d(up1, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv2
        up2 = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(conv1)[1], tf.shape(conv1)[2]))
        deconv2 = tf.layers.conv2d(up2, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv1
        up3 = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(conv0)[1], tf.shape(conv0)[2]))
        deconv3 = tf.layers.conv2d(up3, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv0
        # multi-scale feature fusion: restore the illumination at several scales
        deconv1_resize = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
        deconv2_resize = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
        feature_gather = concat([deconv1_resize, deconv2_resize, deconv3])
        feature_fusion = tf.layers.conv2d(feature_gather, channel, 1, padding='same', activation=None)
        output = tf.layers.conv2d(feature_fusion, filters=1, kernel_size=3, padding='same', activation=None)
    return output  # a single-channel image: the adjusted illumination map

Now for the key part: building the loss functions. First, the DecomNet losses:

# placeholders, fed at run time
self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')

[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)

I_low_3 = concat([I_low, I_low, I_low])          # broadcast the 1-channel illumination to 3 channels
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])  # I_delta is the RelightNet output (defined elsewhere in the full code)

# reconstruction: the product of the two components should match the input (Retinex theory),
# keeping the decomposition faithful
self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))  # same idea
# cross reconstruction: encourages R_high == R_low (reflectance consistency)
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))  # same idea
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))  # same idea
# smooth the illumination map while preserving edges, adapting to the image content (implemented with gradients)
self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)

self.loss_Decom = self.recon_loss_low + self.recon_loss_high \
    + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high \
    + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high \
    + 0.01 * self.equal_R_loss

As you can see, the authors designed quite a few loss terms, as my comments indicate. The smoothness constraint is, to my mind, the cleverest part. The illumination component is generally the low-frequency part of the image and should therefore be relatively smooth, while the reflectance component carries the object's features and should be rich in detail. Yet at image edges even the illumination component must not be over-smoothed, or features would be lost as with Gaussian blurring. The smoothness loss handles this adaptively.
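Written out, the structure-aware smoothness term weights the illumination gradients by an exponentially decaying function of the reflectance gradients, matching the smooth() implementation below:

    Ismooth = mean( |∇x I| * exp(-10 * avg|∇x R|) + |∇y I| * exp(-10 * avg|∇y R|) )

Here avg denotes the 3x3 average pooling in ave_gradient. Where the reflectance is flat (small |∇R|), the weight approaches 1 and illumination gradients are strongly penalized; at reflectance edges the weight decays toward 0, so the illumination map is allowed to change sharply there.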

The smoothness and gradient functions are defined as follows:

def gradient(self, input_tensor, direction):
    self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
    self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])
    if direction == "x":
        kernel = self.smooth_kernel_x
    elif direction == "y":
        kernel = self.smooth_kernel_y
    return tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

def ave_gradient(self, input_tensor, direction):
    return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

def smooth(self, input_I, input_R):
    # weight the gradients of the illumination map I to get the adaptive behavior:
    # the smaller the gradient of the reflectance map R, the larger the weight,
    # which pushes the illumination gradient down, i.e. toward smoothness
    input_R = tf.image.rgb_to_grayscale(input_R)
    return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) +
                          self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

Then the RelightNet losses:

# recombining the adjusted illumination with the low-light reflectance should approximate the normal-light image
self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)
self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

This part is straightforward: a single reconstruction loss plus a smoothness loss.

Next is the data augmentation, which consists of rotations and flips.

def data_augmentation(image, mode):
    if mode == 0:
        # original
        return image
    elif mode == 1:
        # flip up and down
        return np.flipud(image)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        return np.rot90(image)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        image = np.rot90(image)
        return np.flipud(image)
    elif mode == 4:
        # rotate 180 degrees
        return np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        image = np.rot90(image, k=2)
        return np.flipud(image)
    elif mode == 6:
        # rotate 270 degrees
        return np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        image = np.rot90(image, k=3)
        return np.flipud(image)
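During training, a mode would typically be drawn at random for each patch. A minimal usage sketch (the uniform draw here is my own illustration, not necessarily exactly how the repo samples it):

    mode = np.random.randint(0, 8)              # pick one of the 8 augmentation modes uniformly
    patch_aug = data_augmentation(patch, mode)  # apply it to a training patch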

One more thing: the code reads the entire dataset into memory before processing, which is not realistic on a personal computer, so it is better to write a generator. I won't paste my code here; message me if you need it.
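For reference, here is a minimal generator sketch of my own (not the author's withheld version; the directory layout, PNG format, and Pillow dependency are all assumptions):

    import glob
    import numpy as np
    from PIL import Image

    def paired_image_generator(low_dir, high_dir, batch_size=16):
        # lazily yields batches of (low-light, normal-light) image pairs
        low_paths = sorted(glob.glob(low_dir + '/*.png'))
        high_paths = sorted(glob.glob(high_dir + '/*.png'))
        assert len(low_paths) == len(high_paths)
        while True:
            idx = np.random.permutation(len(low_paths))  # reshuffle every epoch
            for start in range(0, len(idx) - batch_size + 1, batch_size):
                batch = idx[start:start + batch_size]
                # assumes all images share one size (or were pre-cropped to fixed patches) so np.stack works
                low = [np.asarray(Image.open(low_paths[i]), dtype=np.float32) / 255.0 for i in batch]
                high = [np.asarray(Image.open(high_paths[i]), dtype=np.float32) / 255.0 for i in batch]
                yield np.stack(low), np.stack(high)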

Finally, the drawbacks: the colors in the processed images are somewhat distorted. This is mainly because Decom-Net cannot make the reflectance components decomposed from the low-light and normal-light images perfectly consistent.

Overall, though, it is still quite good.

