When machine learning models are trained continually on a sequence of tasks, they are liable to forget what they learned on previous tasks -- a phenomenon known as catastrophic forgetting. Proposed solutions to catastrophic forgetting tend to involve storing information about past tasks, meaning that memory usage is a chief consideration in determining their practicality. This paper proposes a memory-efficient solution to catastrophic forgetting, improving upon an established algorithm known as orthogonal gradient descent (OGD). OGD utilizes prior model gradients to find weight updates that preserve performance on prior datapoints. However, since the memory cost of storing prior model gradients grows with the runtime of the algorithm, OGD is ill-suited to continual learning over arbitrarily long time horizons. To address this problem, this paper proposes SketchOGD. SketchOGD employs an online sketching algorithm to compress model gradients as they are encountered into a matrix of a fixed, user-determined size. In contrast to existing memory-efficient variants of OGD, SketchOGD runs online without the need for advance knowledge of the total number of tasks, is simple to implement, and is more amenable to analysis. We provide theoretical guarantees on the approximation error of the relevant sketches under a novel metric suited to the downstream task of OGD. Experimentally, we find that SketchOGD tends to outperform current state-of-the-art variants of OGD given a fixed memory budget.
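To make the memory argument concrete, the following is a minimal sketch of the idea, not the paper's implementation. The abstract does not specify which sketch SketchOGD uses, so this example assumes a standard randomized range sketch $Y = G\Omega$, where $G$ holds one past gradient per column and $\Omega$ is a Gaussian test matrix generated one row at a time. The names `GradientSketch` and `project_orthogonal` are hypothetical. Exact OGD would store all $n$ gradients of dimension $p$, whereas the sketch keeps only a $p \times k$ matrix for a user-chosen $k$.

```python
import numpy as np


def project_orthogonal(update, basis):
    """Remove from `update` its component in the span of `basis`.

    `basis` must have orthonormal columns. This is the OGD-style step:
    the returned update is orthogonal to every direction in the basis.
    """
    return update - basis @ (basis.T @ update)


class GradientSketch:
    """Fixed-size randomized sketch Y = G @ Omega (an illustrative assumption).

    G (p x n) has one column per gradient seen so far and is never stored.
    Omega (n x k) is Gaussian and is drawn one row per incoming gradient,
    so memory stays at p x k no matter how many gradients arrive.
    """

    def __init__(self, num_params, sketch_size, seed=0):
        self.rng = np.random.default_rng(seed)
        self.Y = np.zeros((num_params, sketch_size))

    def add(self, grad):
        # A new column g of G pairs with a fresh row omega of Omega,
        # so the product G @ Omega grows by the rank-one term g omega^T.
        omega = self.rng.standard_normal(self.Y.shape[1])
        self.Y += np.outer(grad, omega)

    def basis(self, tol=1e-10):
        # Orthonormal basis for range(Y), which approximates range(G).
        # Directions with negligible singular values are dropped, so that
        # fewer than k gradients do not yield spurious basis vectors.
        u, s, _ = np.linalg.svd(self.Y, full_matrices=False)
        return u[:, s > tol * s.max()]


# Usage: compress a stream of past-task gradients, then project a new update.
rng = np.random.default_rng(42)
sketch = GradientSketch(num_params=100, sketch_size=8)
for _ in range(500):                      # 500 gradients, memory stays 100 x 8
    sketch.add(rng.standard_normal(100))
new_update = rng.standard_normal(100)
safe_update = project_orthogonal(new_update, sketch.basis())
```

When fewer than `sketch_size` gradients have been added, the sketch spans exactly the same subspace as the stored gradients (with probability one), so the projection matches exact OGD. Beyond that point the projection is only approximate, which is the kind of error the paper's theoretical guarantees address.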