The recently proposed Forgetting Transformer (FoX) incorporates a forget gate into softmax attention and has shown consistently better or on-par performance compared to the standard RoPE-based Transformer. Notably, many attention heads in FoX tend to forget quickly, causing their output at each timestep to rely primarily on the local context. Based on this observation, we propose Adaptive Computation Pruning (ACP) for FoX, a method that dynamically prunes computations involving input-output dependencies that are strongly decayed by the forget gate. This is achieved using a dynamically set pruning threshold that ensures that the pruned attention weights remain negligible. We apply ACP to language model pretraining with FoX and show it consistently reduces the number of FLOPs in softmax attention by around 70% across different model sizes and context lengths, resulting in a roughly 10% to 35% improvement in training throughput. Furthermore, longer context lengths yield greater computational savings. All these speed improvements are achieved without any performance degradation. We also perform several analyses to provide deeper insights into our method, such as examining the pruning patterns and analyzing the distribution of FLOP savings across different attention heads. Our code is available at https://github.com/zhixuan-lin/arctic-fox.
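To make the pruning criterion concrete, below is a minimal, non-authoritative PyTorch sketch of the idea, assuming FoX-style attention logits of the form q_i·k_j/√d + D_ij, where D_ij is the accumulated log forget gate between positions j and i. Entries whose logit falls far enough below the row-wise maximum are dropped, so their softmax weights are provably negligible. The function name `pruned_forgetting_attention` and the margin `delta` are illustrative assumptions, not taken from the released code; an actual kernel (such as the FlashAttention-based one in the repository) would make this decision per block from the decay term alone, before computing the query-key products, so the skipped FLOPs are genuinely saved rather than masked after the fact.

```python
import torch
import torch.nn.functional as F


def pruned_forgetting_attention(q, k, v, log_f, delta=-20.0):
    """Causal softmax attention with a cumulative forget-gate bias, where
    entries whose logit falls more than |delta| below the row maximum are
    masked out; their softmax weights are then at most exp(delta)
    (~2e-9 for delta = -20) and hence negligible.

    q, k, v: (T, d) tensors for a single head; log_f: (T,) log forget
    gate values in (-inf, 0].
    """
    T, d = q.shape
    scores = (q @ k.T) / d ** 0.5                     # raw attention logits, (T, T)
    cum = torch.cumsum(log_f, dim=0)                  # prefix sums of log forget gates
    decay = cum[:, None] - cum[None, :]               # D_ij = sum_{l=j+1}^{i} log f_l
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    masked = torch.full_like(scores, torch.finfo(scores.dtype).min)
    logits = torch.where(causal, scores + decay, masked)
    # Dynamically set threshold: drop entries far below the row-wise maximum
    # logit so the pruned attention weights stay negligible. A real kernel
    # would decide this per block from the decay term alone, before computing
    # q @ k.T, so the pruned computation is actually skipped.
    row_max = logits.max(dim=-1, keepdim=True).values
    logits = torch.where(logits >= row_max + delta, logits, masked)
    return F.softmax(logits, dim=-1) @ v


# Example usage: 8 timesteps, 4 dims, strongly decaying forget gates.
q, k, v = (torch.randn(8, 4) for _ in range(3))
log_f = torch.log(torch.full((8,), 0.5))              # each gate halves past contributions
out = pruned_forgetting_attention(q, k, v, log_f)
print(out.shape)                                      # torch.Size([8, 4])
```

With strongly decaying gates, most entries far below the diagonal fall under the threshold, which mirrors the observation that fast-forgetting heads attend almost entirely to local context and is why the savings grow with context length.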