Group equivariance ensures consistent responses to group transformations of the input, leading to more robust models and better generalization. However, this property can overly constrain a model if the symmetries considered in the group differ from those observed in the data. Common methods address this by determining the appropriate level of symmetry at the dataset level, but they are limited to supervised settings and ignore scenarios in which multiple levels of symmetry co-exist in the same dataset. For instance, pictures of cars and planes exhibit different levels of rotation, yet both are included in the CIFAR-10 dataset. In this paper, we propose a method that detects the level of symmetry of each input without the need for labels. To this end, we derive a necessary and sufficient condition to learn the distribution of symmetries in the data. Using the learned distribution, we generate pseudo-labels that allow us to learn the level of symmetry of each input in a self-supervised manner. We validate the effectiveness of our approach on synthetic datasets with different per-class levels of symmetry, e.g., MNISTMultiple, in which digits are uniformly rotated within a class-dependent interval. We demonstrate that our method can be used for practical applications such as the generation of standardized datasets in which the symmetries are not present, as well as the detection of out-of-distribution symmetries during inference. In both cases, the generalization and robustness of non-equivariant models can be improved. Our code is publicly available at https://github.com/aurban0/ssl-sym.
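To make the MNISTMultiple-style setup concrete, the following is a minimal sketch of how rotation angles could be drawn uniformly from a class-dependent interval. It is not taken from the released code: the function name `sample_rotation_angles`, the table `MAX_ANGLE_BY_CLASS`, the specific angle values, and the choice of an interval symmetric around zero are all assumptions made for illustration.

```python
import random

# Hypothetical per-class maximum rotation angles in degrees.
# These values are illustrative assumptions, not the paper's configuration.
MAX_ANGLE_BY_CLASS = {0: 0.0, 1: 60.0, 2: 120.0, 3: 180.0}


def sample_rotation_angles(labels, max_angle_by_class, seed=0):
    """Draw one angle per label, uniformly from [-max_angle, +max_angle].

    The symmetric interval is an assumption; the source only states that
    digits are uniformly rotated within a class-dependent interval.
    """
    rng = random.Random(seed)  # local generator so results are reproducible
    angles = []
    for label in labels:
        max_angle = max_angle_by_class[label]
        angles.append(rng.uniform(-max_angle, max_angle))
    return angles


labels = [0, 1, 2, 3, 1, 3]
angles = sample_rotation_angles(labels, MAX_ANGLE_BY_CLASS)
```

Each sampled angle would then be applied to the corresponding image with any image-rotation routine. Under this sketch, a class with a maximum angle of zero is never rotated, while a class with a maximum of 180 degrees covers the full circle, so different levels of symmetry co-exist in one dataset.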