In this work, we study rapid, step-wise improvements of the loss in transformers when confronted with multi-step decision tasks. We find that transformers struggle to learn the intermediate tasks, whereas CNNs have no such issue on the tasks we studied. When transformers do learn the intermediate task, they do so rapidly and unexpectedly, after both training and validation loss have saturated for hundreds of epochs. We call these rapid improvements Eureka-moments, since the transformer appears to suddenly learn a previously incomprehensible task. Similar leaps in performance have become known as Grokking. In contrast to Grokking, for Eureka-moments both the validation and the training loss saturate before rapidly improving. We trace the problem back to the Softmax function in the self-attention block of transformers and show ways to alleviate it. These fixes improve training speed: the improved models reach 95% of the baseline model's performance in just 20% of the training steps. They are also much more likely to learn the intermediate task, reach higher final accuracy, and are more robust to hyper-parameters.
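As a rough illustration of why the Softmax can matter for gradient flow, the following minimal sketch computes the Jacobian of a temperature-scaled softmax with respect to its input scores. It is not the analysis or the fix from this work: the temperature parameter, the helper names `softmax` and `softmax_jacobian_norm`, and the example scores are all assumptions chosen for illustration. It only shows the standard fact that the Jacobian (diag(p) - p p^T) / T becomes small both when the output is nearly one-hot and when a large temperature flattens it.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of floats, with a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian_norm(logits, temperature=1.0):
    """Frobenius norm of d softmax(z / T) / dz = (diag(p) - p p^T) / T."""
    p = softmax(logits, temperature)
    squared_sum = 0.0
    for i in range(len(p)):
        for j in range(len(p)):
            diag_term = p[i] if i == j else 0.0
            entry = (diag_term - p[i] * p[j]) / temperature
            squared_sum += entry * entry
    return math.sqrt(squared_sum)


# Hypothetical attention scores of one query against four keys.
scores = [2.0, 1.0, 0.5, -1.0]

for t in (0.05, 1.0, 20.0):
    probs = [round(p, 4) for p in softmax(scores, t)]
    print(f"T={t}: probs={probs}, jacobian_norm={softmax_jacobian_norm(scores, t):.3e}")
```

For these particular scores, the Jacobian norm is largest at the intermediate temperature and much smaller at both extremes, so very little gradient reaches the scores when attention is either extremely peaked or strongly flattened.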