Deep neural networks perform well on classification tasks where data streams are i.i.d. and labeled data is abundant. Challenges emerge when the training data stream is non-stationary, as in continual learning. One powerful approach to this challenge is to pre-train large encoders on large volumes of readily available data and then tune them for a specific task. Given a new task, however, updating the weights of these encoders is challenging: a large number of weights must be fine-tuned, and as a result the encoders forget information about previous tasks. In the present work, we propose a model architecture that addresses this issue, building upon a discrete bottleneck containing pairs of separate and learnable key-value codes. Our paradigm is to encode, process the representation via a discrete bottleneck, and decode. The input is fed to the pre-trained encoder, the output of the encoder is used to select the nearest keys, and the corresponding values are fed to the decoder to solve the current task. The model can only fetch and re-use a sparse subset of these key-value pairs during inference, enabling localized and context-dependent model updates. We theoretically investigate the ability of the discrete key-value bottleneck to minimize the effect of learning under distribution shifts and show that it reduces the complexity of the hypothesis class. We empirically verify the proposed method under challenging class-incremental learning scenarios and show that the proposed model, without any task boundaries, reduces catastrophic forgetting across a wide variety of pre-trained models, outperforming relevant baselines on this task.
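The encode, bottleneck, decode pipeline described above can be illustrated with a minimal sketch. This is not the authors' implementation. The names `KeyValueBottleneck`, `encode`, `decode`, and `predict` are hypothetical; the encoder is an identity stand-in for a frozen pre-trained model; the decoder is a plain argmax; each input selects a single nearest key; and the keys are held fixed for simplicity, although the abstract describes them as learnable. The sketch only shows why updates stay local: learning a new class changes the value attached to the selected key and leaves all other values untouched.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class KeyValueBottleneck:
    """Toy discrete bottleneck: one codebook of (key, value) pairs."""

    def __init__(self, keys, value_dim):
        # Assumption: keys are fixed here; only the values are updated.
        self.keys = [list(k) for k in keys]
        self.values = [[0.0] * value_dim for _ in keys]

    def nearest(self, z):
        """Index of the key closest to the encoded input z."""
        return min(range(len(self.keys)),
                   key=lambda i: squared_distance(z, self.keys[i]))

    def fetch(self, z):
        """Return the value paired with the nearest key."""
        return self.values[self.nearest(z)]

    def update(self, z, target, lr=0.5):
        """Move only the selected value toward the target (local update)."""
        i = self.nearest(z)
        value = self.values[i]
        for d in range(len(value)):
            value[d] += lr * (target[d] - value[d])
        return i


def encode(x):
    # Hypothetical stand-in for a frozen pre-trained encoder: identity map.
    return list(x)


def decode(value):
    # Hypothetical stand-in decoder: index of the largest entry.
    return max(range(len(value)), key=lambda d: value[d])


def predict(bottleneck, x):
    return decode(bottleneck.fetch(encode(x)))


bottleneck = KeyValueBottleneck(keys=[[0.0, 0.0], [10.0, 10.0]], value_dim=2)

# First, only class 0 is seen; its inputs fall near the first key.
for x in ([0.5, -0.2], [-0.3, 0.4]):
    bottleneck.update(encode(x), target=[1.0, 0.0])
values_after_task1 = [list(v) for v in bottleneck.values]

# Later, only class 1 is seen; its inputs fall near the second key.
for x in ([9.6, 10.1], [10.3, 9.8]):
    bottleneck.update(encode(x), target=[0.0, 1.0])

# The value learned for class 0 was never selected, so it is unchanged.
print(bottleneck.values[0] == values_after_task1[0])
print(predict(bottleneck, [0.1, 0.1]), predict(bottleneck, [9.9, 9.9]))
```

In this toy setting the second phase of training never selects the first key, so the prediction for inputs near that key is preserved without any signal marking the task boundary. A fine-tuned dense network offers no such guarantee, since every weight may change on every update.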