封面《殻ノ少女》

前言

最近读了 UNETR 的论文,需要了解这个模型,因此就记录一下学习笔记。UNETR 是一种将 UNet 和 Transformer 结合的网络结构,可以用于图像分割任务。

结构

传统 CNN 网络在分割领域有了很大的成功,但是 CNN 因其结构原因在全局特征信息和长范围空间依赖性上表现较差,这对分割结果有很大的影响。

Transformer 因其 attention 机制,在全局特征方面有很好的表现,因此 UNETR 提出了一种新的网络结构,将 UNet 和 Transformer 结合,如下图所示。
其相对其他以往的分割网络有一下特征

  • 直接为 3D 分割量身定制,直接使用体积数据
  • 直接将 transformer 作为网络的编码器,而非使用其作为网络中的 attention 层。所有 ViT 中间层的输出大小都是一样的
  • 不依赖 CNN backbone 生成输入序列,而是直接使用 tokenized patch

UNETR

在 backbone 方面完全采用了 ViT 的结构,其结构与 Transformer 一样,patch embedding 采用了卷积网络结构,patch size 和 stride size 都为 16,并将大小为 (N,C,H,W,D)(N,C,H,W,D) 的图像转换为 (N,L,C)(N,L,C),其中 L=H4×W4×D4L=\frac{H}{4} × \frac{W}{4} × \frac{D}{4}

在特征金字塔的路径上先将 ViT 中 3、6、9、12 层的输出从 (N,L,C)(N,L,C) reshape 成 (N,C,H16,W16,D16)(N,C,\frac{H}{16},\frac{W}{16},\frac{D}{16}) 再进行反卷积上采样,对于下方的特征则是直接反卷积上采样,两者特征直接拼接再卷积提取特征。

损失函数

损失函数方面采用 soff dice loss + CE loss,其定义如下

L(G,Y)=12JΣj=1JΣi=1IGi,jYi,jΣi=1IGi,j+Σi=1IYi,j1IΣi=1IΣj=1JGi,jlog(Yi,j) L(G,Y) = 1 - \frac{2}{J} \Sigma_{j=1}^J \frac{\Sigma_{i=1}^I G_{i,j}Y_{i,j} }{\Sigma_{i=1}^I G_{i,j} + \Sigma_{i=1}^I Y_{i,j}} - \frac{1}{I} \Sigma_{i=1}^I \Sigma_{j=1}^J G_{i,j}log(Y_{i,j})

其中 II 是体素数量,JJ 是类别数量,GG 是 ground truth,YY 是预测值。

后记

笔者在 ISLES2022 数据集上以输入大小 (2,96,96,96),优化器为 adamw,损失函数为 dice loss 的情况下训练了 UNETR,是目前各个网络这效果最好的,其 dice 系数在 79.9%。但是由于使用了 transformer 其计算量和显存开销依然不可忽视。

参考文献

UNETR

ISLES