We present FIT, a transformer-based architecture with efficient self-attention and adaptive computation. Unlike the original transformer, which operates on a single sequence of data tokens, FIT divides the data tokens into groups, each a shorter sequence of tokens. We employ two types of transformer layers: local layers operate on the data tokens within each group, while global layers operate on a smaller set of introduced latent tokens. These layers, which comprise the same self-attention and feed-forward sublayers as standard transformers, are interleaved, and cross-attention exchanges information between the data tokens and latent tokens of the same group. The attention complexity is $O(n^2)$ locally within each group of size $n$, and over the full sequence of length $L$ it can be reduced to $O(L^{4/3})$ with a suitable choice of group size. Efficiency can be improved further by relying more on the global layers, which perform adaptive computation over the smaller set of latent tokens. FIT is a versatile architecture that can serve as an encoder, a diffusion decoder, or an autoregressive decoder. We provide initial evidence of its effectiveness on high-resolution image understanding and generation tasks. Notably, FIT shows potential for end-to-end training on gigabit-scale data, such as 6400$\times$6400 images (160K tokens after patch tokenization), within 16GB of memory, without requiring specific optimizations or model parallelism.
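The interleaving of local and global layers can be illustrated with a small NumPy sketch. This is a minimal illustration under stated assumptions, not FIT's reference implementation: it uses single-head attention without learned projections, omits feed-forward layers and normalization, uses one local and one global layer per block, and picks a local → read → global → write ordering purely for illustration. The names `fit_block` and `attention_pairs` and the `(t, m, d)` latent shape are hypothetical. The `attention_pairs` helper counts scored query-key pairs to show where the $O(L^{4/3})$ figure comes from: with $t = L/n$ groups and a constant number $m$ of latents per group, local attention costs $t n^2 = Ln$ and global attention costs $(tm)^2 = m^2 L^2 / n^2$, so choosing $n \approx L^{1/3}$ makes both terms $O(L^{4/3})$.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, kv):
    # Single-head scaled dot-product attention over the last two axes.
    # Simplification (assumption): no learned Q/K/V projections; keys double as values.
    scores = q @ np.swapaxes(kv, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ kv

def fit_block(x, latents, group_size):
    """Hypothetical interleaved step on data tokens x of shape (L, d).

    latents has shape (t, m, d): m latent tokens for each of t = L // group_size groups.
    Feed-forward sublayers and normalization are omitted for brevity.
    """
    L, d = x.shape
    assert L % group_size == 0, "sketch assumes L is divisible by the group size"
    t = L // group_size
    groups = x.reshape(t, group_size, d)               # (t, n, d)
    # Local layer: self-attention inside each group only, cost ~ t * n^2.
    groups = groups + attention(groups, groups)
    # Cross-attention: each group's latents read from that group's data tokens.
    latents = latents + attention(latents, groups)     # (t, m, d)
    # Global layer: self-attention over all t * m latents, cost ~ (t * m)^2.
    flat = latents.reshape(1, t * latents.shape[1], d)
    latents = (flat + attention(flat, flat)).reshape(latents.shape)
    # Cross-attention: data tokens read back from their own group's latents.
    groups = groups + attention(groups, latents)
    return groups.reshape(L, d), latents

def attention_pairs(L, n, m):
    # Query-key pairs scored by fit_block: local + global + two cross-attentions.
    t = L // n
    return t * n * n + (t * m) ** 2 + 2 * t * n * m

rng = np.random.default_rng(0)
x = rng.standard_normal((512, 16))
lat = rng.standard_normal((64, 2, 16))
y, lat = fit_block(x, lat, group_size=8)
print(y.shape, lat.shape)  # (512, 16) (64, 2, 16)

# With L = k**3 and group size n = k, compare the pair count with full attention (L**2).
for k in (8, 16, 32):
    L = k ** 3
    print(L, attention_pairs(L, k, 1), L ** 2)
```

In the loop, each doubling of $k$ multiplies $L$ by 8; the FIT pair count grows by a factor approaching $8^{4/3} = 16$, while full attention grows by 64. The sketch keeps $m$ fixed per group, which is what makes the global term scale as $L^{4/3}$ rather than $L^2$.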