Recently, the Vision Transformer (ViT) has achieved promising performance in image recognition and has gradually become a powerful backbone for various vision tasks. To satisfy the sequential input of the Transformer, the tail of ViT first splits each image into a fixed-length sequence of visual tokens. The subsequent self-attention layers then model the global relationships between tokens to produce useful representations for downstream tasks. Empirically, representing an image with more tokens leads to better performance, yet the computational cost of self-attention grows quadratically with the number of tokens, which can seriously hurt the efficiency of ViT inference. To reduce computation, a few pruning methods progressively prune uninformative tokens inside the Transformer encoder, while leaving the number of tokens entering the Transformer untouched. In fact, feeding fewer tokens into the Transformer encoder can directly reduce the subsequent computational cost. In this spirit, we propose a Multi-Tailed Vision Transformer (MT-ViT) in this paper. MT-ViT adopts multiple tails to produce visual sequences of different lengths for the subsequent Transformer encoder. A tail predictor is introduced to decide, for each image, which tail is the most efficient one that still yields an accurate prediction. Both modules are optimized end-to-end with the Gumbel-Softmax trick. Experiments on ImageNet-1K demonstrate that MT-ViT achieves a significant reduction in FLOPs with no degradation in accuracy and outperforms the compared methods in both accuracy and FLOPs.
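The token-count trade-off and the Gumbel-Softmax tail selection described above can be illustrated with a small pure-Python sketch. This is not the MT-ViT implementation. It assumes, purely for illustration, that each tail is a non-overlapping patch splitter with its own patch size. The patch sizes (32 and 16), the 224-pixel input, the embedding width of 384, the predictor logits, and the helper names `num_tokens`, `attention_flops`, `gumbel_softmax`, and `one_hot_argmax` are all hypothetical, and the cost formula is only a rough per-layer estimate.

```python
import math
import random

def num_tokens(image_size: int, patch_size: int) -> int:
    """Tokens produced by a (hypothetical) tail that splits a square image into non-overlapping patches."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    return (image_size // patch_size) ** 2

def attention_flops(n_tokens: int, dim: int) -> int:
    """Rough multiply-accumulate count for one self-attention layer:
    Q/K/V/output projections (4*N*d^2) plus QK^T and attn@V (2*N^2*d)."""
    return 4 * n_tokens * dim * dim + 2 * n_tokens * n_tokens * dim

def gumbel_softmax(logits, tau=1.0, rng=random):
    """Relaxed one-hot sample over tails: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both log() calls finite
        noisy.append((logit - math.log(-math.log(u))) / tau)  # add Gumbel(0, 1) noise
    peak = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in noisy]
    total = sum(exps)
    return [e / total for e in exps]

def one_hot_argmax(weights):
    """Hard tail selection. With a straight-through estimator in an autodiff framework,
    this one-hot vector would be used in the forward pass while gradients flow through
    the soft weights; plain Python only shows the forward computation."""
    best = max(range(len(weights)), key=weights.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(weights))]

if __name__ == "__main__":
    image_size, dim = 224, 384          # hypothetical input size and embedding width
    for patch in (32, 16):              # hypothetical coarse (cheap) and fine (accurate) tails
        n = num_tokens(image_size, patch)
        print(f"patch {patch}: {n} tokens, ~{attention_flops(n, dim):,} MACs per attention layer")
    logits = [2.0, 0.5]                 # hypothetical tail-predictor output for one image
    soft = gumbel_softmax(logits, tau=0.5, rng=random.Random(0))
    print("soft weights:", soft, "hard choice:", one_hot_argmax(soft))
```

Going from 49 to 196 tokens multiplies the quadratic 2·N²·d term by 16 but the linear projection term only by 4, which is why shortening the sequence before the encoder pays off. Pairing the relaxed sample with a hard one-hot choice (straight-through) is one common way to keep a per-image discrete decision while still training the predictor end-to-end.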