BASNet论文详解

[toc]

基本信息

  • 论文名称

    BASNet: Boundary-Aware Salient Object Detection

  • 作者

    Xuebin Qin等

  • 发表时间

    2019年

  • 来源

    CVPR2019

主要收获

  • 知识

    • 本文混合损失函数中的3种loss

    • 本文核心是混合损失函数以及refine module,其中refine module借用了ResNet中的思路:$S_{refined}=S_{coarse}+S_{residual}$。

    • 如图3所示,coarse saliency map存在2方面的问题:①the blurry and noisy boundaries(边界不准确、不清晰), ②the unevenly predicted regional probabilities(同类区域中像素概率不均匀)

    • RRM_LC和RRM_MS等模块比较shallow,所以很难捕捉到可用于refinement的high level information。

    • 评估指标

      • PR Curve

        PR Curve是1种评估预测所得显著性图的标准方式。1张显著性图的precision和recall通过比较二值化的显著性图及其ground truth计算。1个二值化threshold对应的1对precision和recall是数据集中所有显著性图的平均precision和recall。threshold从0到1变化,可以得到1个precision-recall pair序列,画出来就是PR Curve。

      • F-measure

        F-measure可以全面地衡量precision和recall,其基于1对precision和recall进行计算。在进行算法比较时,一般采用最大的F-measure进行比较。

      • MAE

        MAE指显著性图与其ground truth的average absolute per-pixel difference。模型对于1个数据集的MAE为所有显著性图的MAE的平均值

  • 一些未知的东西

    • U-Net
    • SegNet
    • dilation convolution
    • 感受野的计算方法

Abstract

  • 背景:多数前人工作只关注region accuracy而非boundary quality。
  • 本文提出:a predict-refine architecture: BASNet and a new hybrid loss for Boundary-Aware Salient object detection
    • BASNet
      • a densely supervised Encoder-Decoder network, in charge of saliency prediction
      • a residual refinement module(RRM), in charge of saliency map refinement
    • The hybrid loss
      • 3级层次结构:pixel-level, patch-level, map-level
      • 方式:混合3种loss,Binary Cross Entropy (BCE), Structural SIMilarity (SSIM) and Intersection-over-Union (IoU) losses
  • 效果:effectively segment the salient object regions and accurately predict the fine structures with clear boundaries
    • 精度:在6个数据集上,ofregional and boundary evaluation measures超过SOTA。
    • 速度:over 25 fps on a single GPU
  • Code:https://github.com/NathanUA/BASNet

1. Inroducion

  • 背景

    近期FCN被用于显著性目标检测,性能优于传统算法,但their predicted saliency maps are still defective in fine structures and/or boundaries,如图1(c)和(d)所示。

  • two main challenges in accurate salient object detection

    1. network:networks that aggregate multi-level deep features are needed
    2. loss:models trained with CE loss usually have low confidence in differentiating boundary pixels, leading to blurry boundaries. Other losses such as Intersection over Union (IoU) loss, F-measure loss and Dice-score loss are not specifically designed for capturing fine structures
  • BASNet: the prediction module and RRM

    To capture both global (coarse) and local (fine) contexts,a new predict-refine network is proposed。

    • the prediction module

      a U-Net-like densely supervised Encoder-Decoder network,transfers the input image to a probability map

    • RRM

      a novel residual refinement module,refines the predicted map by learning the residuals between the coarse saliency map and ground truth

  • the hybrid loss

    To obtain high confidence saliency map and clear boundary, we propose a hybrid loss that combines Binary Cross Entropy (BCE), Structural SIMilarity (SSIM) and IoU losses, which are expected to learn from ground truth information in pixel-, patch- and map- level, respectively.

    Rather than using explicit boundary losses (NLDF+ , C2S ), we implicitly inject the goal of accurate boundary prediction in the hybrid loss, contemplating that it may help reduce spurious error from cross propagating the information learned on the boundary and the other regions on the image。

    毕竟任务不是边缘检测,所以不能只检测边缘。

2. Related Works

  • Traditional Methods

    早期方法根据1个预定义的基于手工特征计算的saliency measure搜索像素以检测显著目标。Borji等人提供了1个全面的综述。

  • Patch-wise Deep Methods

    受深度卷积神经网络推动图片分类发展的启发,早期的deep方法基于从单/多尺度提取的local image patches将image pixels和super pixels分类为显著/非显著。这些方法通常生成coarse outputs因为在全连接层中spatial information会丢失

  • FCN-based Methods

    相比于Traditional Methods和Patch-wise Deep Methods,基于FCN的方法取得了重要进展,因为FCN is able to capture richer spatial and multiscale information

    UCF、Amulet、NLDF+、DSS+、HED、RAS、LFR、BMPM。

  • Deep Recurrent and attention Methods

    PAGRN、RADF+、RFCNPiCANetR

  • Coarse to Fine Deep Methods

    为捕捉finer structures和more accurate boundaries,学者提出了大量refinement策略。

    SRM、R3Net+、DGRL

3. BASNet

3.1. Overview of Network Architecture

BASNet包括2个Module:Predict Module和Refine Module,如图2所示。

  • Predict Module

    a U-Net-like densely supervised Encoder-Decoder network,作用是predict saliency map from input images

  • Refine Module

    Residual Refinement Module,refines the resulting saliency map of the prediction module by learning the residuals between the saliency map and the ground truth。

3.2. Predict Module

  • 整体

    • 受U-Net和SegNet启发,predict module是1个Encoder-Decoder网络,这种结构可以同时捕捉high level global contexts and low level details
    • 受HED启发,为减少过拟合,通过ground truth对decoder每个stage的最后1层进行监督,如图2所示。
  • encoder

    encoder部分包含1个输入卷积层6个由basic res-blocks组成的stage,其中输入卷积层和前4个stage是修改过的ResNet34

    改动为本文中的输入卷积层有64个步长为1的3×3卷积核而非步长为2的7×7卷积核,并且在输入卷积层之后没有pooling,这使得在前几层获得更大尺寸的特征图但也减小了整体的感受野。

    为获得和ResNet34相同的感受野,本文在ResNet34又加了2个stage,每个stage均由size为2的不重叠maxpooling层及3个512filter的basic res-block组成。

  • bridge

    为进一步捕捉global infomation,本文在encoder和decoder之间添加了1个bridge

    该bridge包括3个512核的dilation为2的3×3卷积层,其中每个卷积层后都有1个BN层和ReLU。

  • decoder

    decoder几乎和encoder完全对称。decoder的每个stage包含3个卷积层,每个卷积层后有1个BN层和ReLU。

    每个stage的输入是其前1个stage的输出和其对应的encoder中的stage的输出的concatenate结果。

    为得到side-output saliency maps,bridge和decoder中每个stage的输出被进行处理,处理过程为:1个3×3卷积、上采样、sigmoid。因此输入1张图片,本文的predict module在训练过程中输出7个saliency maps,其中最后1个saliency map的accuracy最高,所以其作为predict module的最终输出传入refine module。

3.3. Refine Module

  • Refinement Module通常被定义为1个residual block,其通过学习saliency map和ground truth之间的residual$S_{residual}$来refine预测得到的coarse saliency map$S_{coarse}$,公式为$S_{refined}=S_{coarse}+S_{residual}$。
  • 如图3所示,coarse包含2方面的含义:①the blurry and noisy boundaries, ②the unevenly predicted regional probabilities,模型预测所得的coarse saliency map中这2方面的问题都会有。
  • RRM_LC(residual refinement module based on local context)起初被提出是用于boundary refinement,因为其感受野较小,Islam和Deng等人iteratively或recurrently在不同尺度上使用它refine saliency maps。Wang等人使用了PPM(pyramid pooling module),其中3个尺度的pyramid pooling features被concatenate。为避免池化操作导致细节损失,RRM_MS使用kernel size和dilation不同的卷积层捕捉multi-scale contexts。然而这些模块是shallow的,所以很难捕捉到可用于refinement的high level information
  • 本文的RRM和predict module结构相似但简单很多,其包含1个输入层、1个encoder、1个bridge、1个decoder和1个输出层。encoder和decoder都包含4个stage每个stage只包含1个64核的3×3卷积层,每个卷积层后面都有1个BN层和1个ReLU。bridge和1个stage结构相同,也包含1个64核的3×3卷积层(后面跟着1个BN层和1个ReLU)。encoder中下采样时使用maxpooling,decoder中上采样时使用bilinear interpolation。RRM的输出即本文整个模型最终的输出。

3.4. Hybrid Loss

  • 训练中的Loss是各个side output的loss的加权和,每个side output的loss公式都是1个hybrid loss(3种loss之和)。本文模型对8个side output进行深度监督,其中7个side output来自Predict Module、1个side output来自Refine Module。

    • BCE Loss:在二分类和分割任务中,二值交叉熵损失函数(Binary Cross Entropy Loss,BCE Loss)是最常用的损失函数。公式略
    • SSIM Loss:结构相似损失(Structural Similarity Loss,SSIM Loss)在提出时被用于图像质量评价。它可以捕捉一张图片中的结构信息,因此本文将其集成于混合损失函数中,以获取显著目标标注中的结构信息。公式略
    • IoU Loss:交并比损失(Intersection over Union Loss,IoU Loss)在提出时被用来衡量2个集合的相似性,后来被作为目标检测和分割的标准评估指标。最近,它也被用在了显著性目标检测的训练中。公式略
  • 本文阐述了3种loss的作用,图5中的热力图展示了每个像素的loss随训练过程的变化,3列分别代表训练过程中的不同阶段,3行分别是不同的loss。

    • BCE Loss是pixel-level的measure,它并不考虑周围像素的label并且foreground像素和background像素的权重相同,有助于所有像素的收敛

    • SSIM Loss是patch-level的measure,它考虑每个像素的local neighborhood,对边界具有更高的权重(即使预测前景的概率相同,但边界附近的loss比前景中心的loss更高)。

      在训练过程的初始阶段,边界周围像素的loss是最大的(见图5第2行),这帮助集中于边界附近像素的收敛。随着训练过程,foreground的SSIM Loss减小而background的loss成为主导项。但只有当background的预测非常接近ground truth(0)时 background的loss才会起作用,这时loss会从1急速下降到0,因为通常只有在训练晚期BCE loss平滑(flat)时background的预测才会到0。SSIM Loss保证有足够的梯度使得网络继续学习。因为预测被push到0,background的预测看起来会更clean。

    • IoU Loss是map-level的measure,但是本文为了阐述所以根据式6画出了每个像素的IoU Loss。

      随着foreground的预测越来越接近1,foreground的loss最终变成0。

      把3个loss混合起来,利用BCE使每个像素都有smooth gradient,利用IoU给予foreground更多注意力,通过SSIM基于图像结构使得边界的loss更大

4. Experimental Results

4.1. Datasets

在6个常用数据集上对模型进行了评估。具体是哪些数据集、这些数据集有哪些特点请见原文。

4.2. Implementation and Experimental Setup

  • 训练

    • 使用DUTS-TR训练模型,训练前进行离线数据增强(将图片水平翻转)。
    • 将图片resize到256×256并随机裁剪成224×224。
    • encoder的部分参数使用ResNet34的预训练模型进行初始化,其它卷积层通过Xavier初始化。
    • 使用Adam进行训练,超参数为默认值(初始学习率1e-3,betas=(0.9, 0.999),eps=1e-8,weight decay=0)
    • 一直训练到loss收敛,不使用validation set。最终经历了400k iterations,batch size为8,耗时125小时c
  • 测试/推理

    将原图resize到256×256,再将最终得到的显著图resize back到原图大小

  • 软/硬件环境

    训练和测试的软硬件环境一致。

    PyTorch0.4.0,An eight-core PC with an AMD Ryzen 1800x 3.5 GHz CPU (with 32GB RAM) and a GTX 1080ti GPU (with 11GB memory)

    256×256图片的推理耗时为0.04秒。

4.3. Evaluation Metrics

PR Curve、F-measure、MAE、relaxed F-measure of boundary ($relaxF^b_{\beta}$)。对这几个评估指标的具体介绍请见原文。

  • PR Curve

    PR Curve是1种评估预测所得显著性图的标准方式。1张显著性图的precision和recall通过比较二值化的显著性图及其ground truth计算。1个二值化threshold对应的1对precision和recall是数据集中所有显著性图的平均precision和recall。threshold从0到1变化,可以得到1个precision-recall pair序列,画出来就是PR Curve。

  • F-measure

    F-measure可以全面地衡量precision和recall,其基于1对precision和recall进行计算。在进行算法比较时,一般采用最大的F-measure进行比较。

  • MAE

    MAE指显著性图与其ground truth的average absolute per-pixel difference。模型对于1个数据集的MAE为所有显著性图的MAE的平均值

  • relaxed F-measure of boundary

    此处省略,请见原文。

4.4. Ablation Study

这个section验证了本文模型中的每个关键component。消融实验包括architecture ablation和loss ablation2个部分。消融实验是在ECSSD数据集上进行的。

  • architecture ablation

    为证明BASNet的有效性,本文提供了BASNet与其它相关结构的量化对比结果。

    Loss都使用BCE,首先以U-Net作为baseline,然后是Encoder-Decoder Network、Deep Supervision、RRM_LC、RRM_MS、RRM_Ours,结果如表1所示,可见本文提出的BASNet在这5个实验中性能最优

  • loss ablation

    为阐述本文提出的fusion loss的有效性,本文基于BASNet使用不同loss进行了1系列实验。表1中的结果证明本文提出的hybrid loss极大地提升了性能,特别是边界的质量。

    为进一步阐述损失函数对于BASNet预测质量的影响,使用不同Loss对BASNet进行训练的结果如图7所示,很明显可以看出本文提出的混合损失函数达到了最优的质量

4.5. Comparison with State-of-the-arts

和15个SOTA算法进行比较。公平起见,使用原文作者提供的显著性图或者使用原文作者公开的模型。

  • Quantitative evaluation

    图6展示了在5个数据集上的PR曲线和F-measure曲线,表2展示了在6个数据集上的maximum region-based F-measure、MAE、the relaxed boundary Fmeasure。数据提了很多个percent(略)

  • Qualitative evaluation

    图8展示对比了8种算法对不同类型图片的识别效果图,图片类型有images with low contrast、fine structures、large object touching image boundaries、complex object boundaries。

5. Conclusion


Github(github.com):@chouxianyu

Github Pages(github.io):@臭咸鱼

知乎(zhihu.com):@臭咸鱼

博客园(cnblogs.com):@臭咸鱼

B站(bilibili.com):@绝版臭咸鱼

微信公众号:@臭咸鱼

转载请注明出处,欢迎讨论和交流!