We present Token-UNet, adopting the TokenLearner and TokenFuser modules to encase Transformers into UNets. While Transformers have enabled global interactions among input elements in medical imaging, current computational challenges hinder their deployment on common hardware. Models like (Swin)UNETR adapt the UNet architecture by incorporating (Swin)Transformer encoders, which process tokens that each represent small subvolumes ($8^3$ voxels) of the input. The Transformer attention mechanism scales quadratically with the number of tokens, which is tied to the cubic scaling of 3D input resolution. This work reconsiders the role of convolution and attention, introducing Token-UNets, a family of 3D segmentation models that can operate in constrained computational environments and time frames. To mitigate computational demands, our approach maintains the convolutional encoder of UNet-like models, and applies TokenLearner to 3D feature maps. This module pools a preset number of tokens from local and global structures. Our results show this tokenization effectively encodes task-relevant information, yielding naturally interpretable attention maps. The memory footprint, computation times at inference, and parameter counts of our heaviest model are reduced to 33\%, 10\%, and 35\% of the SwinUNETR values, with better average performance (86.75\% $\pm 0.19\%$ Dice score for SwinUNETR vs our 87.21\% $\pm 0.35\%$). This work opens the way to more efficient trainings in contexts with limited computational resources, such as 3D medical imaging. Easing model optimization, fine-tuning, and transfer-learning in limited hardware settings can accelerate and diversify the development of approaches, for the benefit of the research community.
翻译:我们提出Token-UNet,采用TokenLearner与TokenFuser模块将Transformer封装到UNet中。尽管Transformer已在医学成像中实现了输入元素间的全局交互,但当前的计算挑战阻碍了其在常见硬件上的部署。诸如(Swin)UNETR等模型通过引入(Swin)Transformer编码器来改进UNet架构,这些编码器处理的每个令牌代表输入的小子体积($8^3$体素)。Transformer注意力机制的计算复杂度随令牌数量呈二次方增长,而令牌数量又与3D输入分辨率的三次方缩放相关。本研究重新思考了卷积与注意力的作用,提出了Token-UNet系列——一类能在受限计算环境与时间框架下运行的3D分割模型。为降低计算需求,我们的方法保留了类UNet模型的卷积编码器,并将TokenLearner应用于3D特征图。该模块从局部与全局结构中池化预设数量的令牌。实验结果表明,这种令牌化方法能有效编码任务相关信息,并生成天然可解释的注意力图。我们最重模型的内存占用、推理计算时间及参数量分别降至SwinUNETR的33%、10%和35%,同时获得更优的平均性能(SwinUNETR的Dice分数为86.75% $\pm 0.19%$,我们的模型为87.21% $\pm 0.35%$)。这项工作为在计算资源受限的场景(如3D医学成像)中实现更高效的训练开辟了道路。在有限硬件设置下简化模型优化、微调与迁移学习,可加速并丰富方法开发,惠及研究社区。