Boundary Aware PoolNet(2):BASNet模型与代码介绍

Boundary Aware PoolNet = PoolNet + BASNet,即使用BASNet中的Deep Supervision策略和Hybrid Loss改进PoolNet。

为理解Boundary Aware PoolNet,我们并不需要学习整个BASNet,只需要了解其中的Deep Supervision策略和Hybrid Loss即可。

本文将简单介绍BASNet的模型结构,重点介绍其Deep SupervisionHybrid Loss的理论和代码实现。

相关文章汇总:

BASNet

传送门

BASNet结构

img

如上图所示,BASNet模型包括Predict Module和Refine Module。

  • Predict Module

    a U-Net-like densely supervised Encoder-Decoder network,作用是predict saliency map from input images。

    其实这个Encoder-Decoder结构和FPN(特征金字塔网络)没什么区别吧。

  • Refine Module

    refines the resulting saliency map of the prediction module by learning the residuals between the saliency map and the ground truth。

基于上述的2个Module,BASNet使用Deep Supervision(上图中的Sup1-8)和Hybrid Loss进行模型训练。

代码

Predict Module的代码在文件./model/BASNet.py中类BASNet中,Refine Module的代码在文件./model/BASNet.py中类RefUnet中。

Deep Supervision

直白来讲,Deep Supervision即使用神经网络中多个层的Loss之和进行梯度下降。

如前文中BASNet结构图所示,BASNet作者计算了Predict Module中的7层和Refine Module中的最后1层的Loss并进行求和,然后进行梯度下降,以此实现Deep Supervision。在计算边路输出时,需要进行上采样和卷积使得边路输出的尺寸、通道数与输入相同。

Deep Supervision的代码在文件./model/BASNet.py的类BASNet的函数forward()中,可知类BASNetforward()时返回了8个边路输出,后继计算这8层的Hybrid Loss并求和进行梯度下降。

Hybrid Loss

直白来讲,Hybrid Loss即在计算损失时使用BCE Loss、SSIM Loss、IOU Loss这3个损失之和而非只使用BCE损失函数。

Hybrid Loss的代码在文件./basnet_trin.py中的函数muti_bce_loss_fusion()中。该函数的输入为BASNet的8个边路输出和输入对应的标注,该函数使用函数bce_ssim_loss()计算1个边路输出与标注的3种Loss之和。


Github(github.com):@chouxianyu

Github Pages(github.io):@臭咸鱼

知乎(zhihu.com):@臭咸鱼

博客园(cnblogs.com):@臭咸鱼

B站(bilibili.com):@绝版臭咸鱼

微信公众号:@臭咸鱼

转载请注明出处,欢迎讨论和交流!