Denoising diffusion probabilistic models (DDPMs) have recently been applied to image segmentation by generating segmentation masks conditioned on images, but these applications have mainly been limited to 2D networks and have not exploited the potential benefits of a 3D formulation. In this work, DDPMs are used for 3D multiclass image segmentation for the first time. We make three key contributions, all of which focus on aligning the training strategy with the evaluation methodology and on improving efficiency. First, the model predicts segmentation masks instead of sampled noise and is optimised directly via Dice loss. Second, the mask predicted at the previous time step is recycled to generate noise-corrupted masks, which reduces information leakage. Finally, the diffusion process during training is reduced to five steps, the same number as in evaluation. In studies on two large multiclass data sets (prostate MR and abdominal CT), we demonstrate significantly improved performance compared to existing DDPMs and reach performance competitive with non-diffusion segmentation models based on U-net within the same compute budget. The JAX-based diffusion framework has been released at https://github.com/mathpluscode/ImgX-DiffSeg.
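The two ingredients named in the first contribution, a noise-corrupted mask and a Dice loss computed on a predicted mask, can be illustrated with a minimal sketch. It is written in plain NumPy rather than with the released JAX framework, and the function names `soft_dice_loss` and `corrupt_mask`, the toy volume size, and the value of `alpha_bar_t` are all assumptions made for illustration, not the authors' implementation. The corruption step uses the standard DDPM forward formula; the recycling of previous predictions and the five-step schedule are not reproduced here.

```python
import numpy as np


def soft_dice_loss(probs, one_hot, eps=1e-5):
    """Mean (1 - Dice) over classes for one 3D volume.

    probs, one_hot: arrays of shape (D, H, W, C).
    """
    axes = (0, 1, 2)  # sum over the three spatial axes, keep the class axis
    intersection = np.sum(probs * one_hot, axis=axes)
    denominator = np.sum(probs, axis=axes) + np.sum(one_hot, axis=axes)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(np.mean(1.0 - dice))


def corrupt_mask(x0, alpha_bar_t, rng):
    """Standard DDPM forward step: x_t = sqrt(a) * x0 + sqrt(1 - a) * noise."""
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * noise


rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(4, 8, 8))  # toy 3-class label volume (assumed size)
one_hot = np.eye(3)[labels]                  # one-hot mask, shape (4, 8, 8, 3)

# Noise-corrupted mask that a segmentation network would receive with the image.
x_t = corrupt_mask(one_hot, alpha_bar_t=0.5, rng=rng)

# Dice loss of a perfect prediction versus an uninformative uniform prediction.
perfect_loss = soft_dice_loss(one_hot, one_hot)
uniform_loss = soft_dice_loss(np.full(one_hot.shape, 1.0 / 3.0), one_hot)
```

In this sketch a network would take the image and `x_t` as input and output class probabilities, which are then scored against the ground-truth mask with `soft_dice_loss` instead of regressing the sampled noise.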