Recently, denoising diffusion probabilistic models (DDPMs) have been applied to image segmentation by generating segmentation masks conditioned on images, but these applications have mainly been limited to 2D networks and have not exploited the potential benefits of a 3D formulation. In this work, we studied DDPM-based segmentation models for 3D multiclass segmentation on two large multiclass data sets (prostate MR and abdominal CT). We observed that the discrepancy between the training and test procedures led to inferior performance for existing DDPM methods. To mitigate this inconsistency, we proposed a recycling method that generates corrupted masks from the model's prediction at a previous time step instead of from the ground truth. The proposed method achieved statistically significant improvements over existing DDPMs, independent of several other techniques for reducing the train-test discrepancy, including predicting the mask, using a Dice loss, and reducing the number of diffusion time steps during training. The diffusion models also performed competitively with, and produced segmentations visually similar to, a non-diffusion-based U-net within the same compute budget. The JAX-based diffusion framework has been released at https://github.com/mathpluscode/ImgX-DiffSeg.
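To make the recycling idea concrete, here is a minimal JAX sketch of a single training-loss computation. It rests on several assumptions that the abstract does not state: a Gaussian DDPM over one-hot masks with a linear beta schedule, a model that predicts the clean mask x_0 directly, "previous time step" interpreted as step t + 1 of the reverse process, and a stop-gradient on the recycled prediction. The names `toy_model`, `recycling_loss`, `dice_loss`, `NUM_STEPS`, and the parameter keys are hypothetical. `toy_model` is a tiny stand-in for a 3D U-net, and none of this code is the ImgX-DiffSeg implementation.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 100  # hypothetical number of diffusion steps T

# Linear beta schedule (an assumption; the abstract does not specify one).
betas = jnp.linspace(1e-4, 0.02, NUM_STEPS)
alphas_cumprod = jnp.cumprod(1.0 - betas)


def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I)."""
    abar = alphas_cumprod[t]
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * noise


def toy_model(params, image, x_t, t):
    """Hypothetical stand-in for a 3D U-net: predicts per-voxel class probabilities x_0."""
    logits = (
        params["w_img"] * image[..., None]  # (D, H, W, 1) * (C,) -> (D, H, W, C)
        + params["w_mask"] * x_t            # noisy mask input, (D, H, W, C)
        + params["w_t"] * (t / NUM_STEPS)   # crude time conditioning, (C,)
        + params["b"]
    )
    return jax.nn.softmax(logits, axis=-1)


def dice_loss(pred, target, eps=1e-6):
    """1 - mean soft Dice over classes; spatial axes are all but the last."""
    axes = tuple(range(pred.ndim - 1))
    inter = jnp.sum(pred * target, axis=axes)
    denom = jnp.sum(pred, axis=axes) + jnp.sum(target, axis=axes)
    return 1.0 - jnp.mean((2.0 * inter + eps) / (denom + eps))


def recycling_loss(params, image, mask, key):
    key_t, key_n1, key_n2 = jax.random.split(key, 3)
    # Draw t in [0, T - 2] so that t + 1 is a valid earlier reverse-process step.
    t = jax.random.randint(key_t, (), 0, NUM_STEPS - 1)
    # 1) Corrupt the ground truth at step t + 1 and let the model predict x_0.
    x_prev = q_sample(mask, t + 1, jax.random.normal(key_n1, mask.shape))
    x0_prev = jax.lax.stop_gradient(toy_model(params, image, x_prev, t + 1))
    # 2) Recycle: corrupt the model's own prediction, not the ground truth, at step t.
    x_t = q_sample(x0_prev, t, jax.random.normal(key_n2, mask.shape))
    # 3) Predict the mask from the recycled x_t and supervise with the ground truth.
    return dice_loss(toy_model(params, image, x_t, t), mask)


# Usage on a random 8x8x8 volume with 3 classes.
num_classes = 3
k_img, k_lbl, k_loss = jax.random.split(jax.random.PRNGKey(0), 3)
image = jax.random.normal(k_img, (8, 8, 8))
mask = jax.nn.one_hot(jax.random.randint(k_lbl, (8, 8, 8), 0, num_classes), num_classes)
params = {
    "w_img": jnp.ones(num_classes),
    "w_mask": jnp.array(1.0),
    "w_t": jnp.zeros(num_classes),
    "b": jnp.zeros(num_classes),
}
loss, grads = jax.value_and_grad(recycling_loss)(params, image, mask, k_loss)
```

The design point this sketch illustrates is that at test time the model only ever sees noisy masks derived from its own earlier predictions. Corrupting the model's prediction, rather than the ground truth, makes the training inputs resemble those inputs, while the loss is still computed against the ground-truth mask.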