Deep neural networks perform well on classification tasks where data streams are i.i.d. and labeled data is abundant. Challenges emerge with non-stationary training data streams such as continual learning. One powerful approach that has addressed this challenge involves pre-training of large encoders on volumes of readily available data, followed by task-specific tuning. Given a new task, however, updating the weights of these encoders is challenging as a large number of weights needs to be fine-tuned, and as a result, they forget information about the previous tasks. In the present work, we propose a model architecture to address this issue, building upon a discrete bottleneck containing pairs of separate and learnable key-value codes. Our paradigm will be to encode; process the representation via a discrete bottleneck; and decode. Here, the input is fed to the pre-trained encoder, the output of the encoder is used to select the nearest keys, and the corresponding values are fed to the decoder to solve the current task. The model can only fetch and re-use a sparse number of these key-value pairs during inference, enabling localized and context-dependent model updates. We theoretically investigate the ability of the discrete key-value bottleneck to minimize the effect of learning under distribution shifts and show that it reduces the complexity of the hypothesis class. We empirically verify the proposed method under challenging class-incremental learning scenarios and show that the proposed model - without any task boundaries - reduces catastrophic forgetting across a wide variety of pre-trained models, outperforming relevant baselines on this task.
翻译:深度神经网络在数据流独立同分布且标注数据充足的分类任务上表现良好。然而,当面临非平稳训练数据流(如持续学习)时,挑战随之出现。应对这一挑战的有效方法之一,是在大量易得数据上预训练大型编码器,随后针对特定任务进行微调。但对于新任务而言,更新这些编码器的权重颇具挑战,因为大量权重需要微调,进而导致模型遗忘先前任务的信息。本文提出一种模型架构来解决该问题,该架构基于包含独立可学习键-值对离散瓶颈。我们的范式是:编码,通过离散瓶颈处理表征,再解码。具体而言,输入被馈送至预训练编码器,编码器输出用于选择最近键,而对应的值则被馈送至解码器以解决当前任务。在推理过程中,模型仅能获取并重用少量稀疏的键-值对,从而实现局部化、上下文相关的模型更新。我们从理论上研究了离散键-值瓶颈在分布偏移下最小化学习影响的能力,表明其降低了假设类的复杂度。我们在具有挑战性的类增量学习场景下实证验证了所提方法,并证明该模型——无需任何任务边界——能减少多种预训练模型上的灾难性遗忘,在此任务上优于相关基线方法。