We present FIT, a transformer-based architecture with efficient self-attention and adaptive computation. Unlike the original transformer, which operates on a single sequence of data tokens, FIT divides the data tokens into groups, each forming a shorter sequence. We employ two types of transformer layers: local layers operate on the data tokens within each group, while global layers operate on a smaller set of introduced latent tokens. These layers, which comprise the same self-attention and feed-forward sublayers as standard transformers, are interleaved, and cross-attention is used to exchange information between data and latent tokens within the same group. The attention complexity is $O(n^2)$ locally within each group of size $n$, but can reach $O(L^{4/3})$ globally for a sequence of length $L$. Efficiency can be further improved by relying more on the global layers, which perform adaptive computation using the smaller set of latent tokens. FIT is a versatile architecture and can function as an encoder, a diffusion decoder, or an autoregressive decoder. We provide initial evidence of its effectiveness in high-resolution image understanding and generation tasks. Notably, FIT shows potential for end-to-end training on gigabit-scale data, such as $6400\times6400$ images, even without specific optimizations or model parallelism.
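The interleaving of local layers, global layers, and cross-attention described above can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. It assumes single-head attention with no learned projections, no feed-forward sublayers, and no normalization, and the ordering of the four steps is one plausible choice rather than something the abstract specifies. The names `attend` and `fit_block`, and the shapes `t` (number of groups), `n` (data tokens per group), `m` (latent tokens per group), and `d` (feature width), are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(q, kv):
    """Scaled dot-product attention, batched over the leading axis.

    Simplification (assumption): queries, keys, and values are the raw
    tokens, with no learned projection matrices and a single head.
    """
    d = q.shape[-1]
    scores = q @ np.swapaxes(kv, -1, -2) / np.sqrt(d)
    return softmax(scores, axis=-1) @ kv


def fit_block(x, z):
    """One illustrative block over grouped data tokens and latent tokens.

    x: data tokens, shape (t, n, d), i.e. t groups of n tokens each
    z: latent tokens, shape (t, m, d), i.e. m latents per group
    """
    t, n, d = x.shape
    m = z.shape[1]

    # Local layer: self-attention among data tokens of the same group only.
    x = x + attend(x, x)

    # Cross-attention: each group's latents read from that group's data.
    z = z + attend(z, x)

    # Global layer: self-attention over all t * m latents at once.
    z_all = z.reshape(1, t * m, d)
    z_all = z_all + attend(z_all, z_all)
    z = z_all.reshape(t, m, d)

    # Cross-attention: each group's data tokens read back from its latents.
    x = x + attend(x, z)
    return x, z
```

In this sketch the only step that mixes information across groups is the global layer, and it touches `t * m` latent tokens rather than all `t * n` data tokens, which is where the savings over full self-attention come from.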
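The $O(L^{4/3})$ figure can be made concrete with a small cost model. The sketch below rests on assumptions not stated in the abstract: each group carries a fixed number `m` of latent tokens, local attention costs $n^2$ per group (so $L n$ in total), and global attention over all latents costs $(mL/n)^2$. Under those assumptions the total $L n + (mL/n)^2$ is minimized when $n$ grows like $L^{1/3}$, which gives $O(L^{4/3})$. The function names are hypothetical.

```python
def attention_cost(L, n, m=1):
    """Count pairwise attention scores for L data tokens in groups of size n.

    Assumes n divides L and that every group has m latent tokens
    (a modelling assumption, not a value taken from the paper).
    Returns (local_cost, global_cost).
    """
    t = L // n                    # number of groups
    local_cost = t * n * n        # n^2 per group, L * n in total
    global_cost = (t * m) ** 2    # full attention over all t * m latents
    return local_cost, global_cost


def best_group_size(L, m=1):
    """Group size (among divisors of L) that minimizes local + global cost."""
    divisors = [n for n in range(1, L + 1) if L % n == 0]
    return min(divisors, key=lambda n: sum(attention_cost(L, n, m)))


if __name__ == "__main__":
    L = 4096
    n = best_group_size(L)
    local_cost, global_cost = attention_cost(L, n)
    print("best group size:", n)                    # 16, which equals L ** (1/3)
    print("grouped cost:", local_cost + global_cost)  # 131072, i.e. 2 * L ** (4/3)
    print("full attention cost:", L * L)            # 16777216
```

For $L = 4096$ and one latent per group, the grouped scheme needs 131,072 pairwise scores against 16,777,216 for full self-attention, a 128-fold reduction in this simplified model.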