It is well known that the performance of well-trained deep neural networks may degrade significantly when they are applied to data whose distribution is even slightly shifted. Recent studies have shown that perturbing feature statistics (\eg, mean and standard deviation) during training can enhance cross-domain generalization. Existing methods typically construct such perturbations from the feature statistics within a mini-batch, which limits their representation capability. Inspired by the domain generalization objective, we introduce a novel Adversarial Style Augmentation (ASA) method, which explores a broader style space by generating more effective statistics perturbations via adversarial training. Specifically, we first search for the most sensitive direction and intensity of the statistics perturbation by maximizing the task loss. By then updating the model against this adversarial perturbation during training, we let the model explore the worst-case domain and hence improve its generalization performance. To facilitate the application of ASA, we design a simple yet effective module, namely AdvStyle, which instantiates the ASA method in a plug-and-play manner. We validate the efficacy of AdvStyle on cross-domain classification and instance retrieval tasks, where it achieves higher mean accuracy and lower performance fluctuation. In particular, our method significantly outperforms its competitors on the PACS dataset under the single-source generalization setting, \eg, boosting the classification accuracy from 61.2\% to 67.1\% with a ResNet50 backbone. Our code will be available at \url{https://github.com/YBZh/AdvStyle}.
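The two steps described in the abstract (searching for the most sensitive statistics perturbation, then updating the model against it) can be read as a min-max problem over channel-wise feature statistics. The following is only a sketch under assumptions the abstract does not state: $x$ denotes an intermediate feature map with channel-wise mean $\mu(x)$ and standard deviation $\sigma(x)$, $y$ its label, $\mathcal{L}$ the task loss, $\theta$ the model parameters, and $\mu_{\mathrm{adv}}$, $\sigma_{\mathrm{adv}}$ the adversarial statistics. All of these symbols, and the re-normalization form used for $\tilde{x}$, are illustrative choices and not necessarily the paper's exact formulation.

\begin{equation}
\tilde{x} = \sigma_{\mathrm{adv}} \cdot \frac{x - \mu(x)}{\sigma(x)} + \mu_{\mathrm{adv}},
\qquad
\min_{\theta} \; \max_{\mu_{\mathrm{adv}},\, \sigma_{\mathrm{adv}}} \; \mathcal{L}\bigl(\theta;\, \tilde{x},\, y\bigr).
\end{equation}

Under this reading, the inner maximization would be approximated by a few gradient-ascent steps on $(\mu_{\mathrm{adv}}, \sigma_{\mathrm{adv}})$, for instance starting from $(\mu(x), \sigma(x))$, and the outer minimization is the usual parameter update on the perturbed features $\tilde{x}$. Because the perturbation is optimized rather than sampled from mini-batch statistics, it is not restricted to the styles present in the batch.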