Despite its flexibility in learning diverse inductive biases in machine learning programs, meta learning (i.e., learning to learn) has long been recognized to suffer from poor scalability due to its tremendous compute/memory costs, training instability, and a lack of efficient distributed training support. In this work, we focus on making scalable meta learning practical by introducing SAMA, which combines advances in both implicit differentiation algorithms and systems. Specifically, SAMA is designed to flexibly support a broad range of adaptive optimizers at the base level of meta learning programs, while reducing the computational burden by avoiding explicit computation of second-order gradient information and by exploiting efficient distributed training techniques implemented for first-order gradients. Evaluated on multiple large-scale meta learning benchmarks, SAMA achieves up to a 1.7x/4.8x increase in throughput and a 2.0x/3.8x decrease in memory consumption on single-/multi-GPU setups, respectively, compared to baseline meta learning algorithms. Furthermore, we show that SAMA-based data optimization leads to consistent improvements in text classification accuracy with BERT and RoBERTa large language models, and achieves state-of-the-art results in both small- and large-scale data pruning on image classification tasks, demonstrating the practical applicability of scalable meta learning across language and vision domains.
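To make the idea of avoiding explicit second-order computation concrete, the following is a minimal sketch of one common technique: approximating a mixed second-order term with a central finite difference of first-order gradients. The abstract does not say which approximation SAMA uses, so the technique, the toy data-reweighting loss, and all function names here are assumptions chosen for illustration only.

```python
# Illustrative sketch only: approximates a mixed second-order term with two
# first-order gradient evaluations. All names and the toy loss are hypothetical
# and are not taken from the SAMA implementation.

def residual(theta, x, y):
    """Prediction error of a linear model theta . x on one example."""
    return sum(t * xi for t, xi in zip(theta, x)) - y


def grad_wrt_weights(theta, data):
    """d L_base / d lambda_i for L_base = sum_i lambda_i * residual_i ** 2.

    This is a first-order quantity: it only needs the per-example losses.
    """
    return [residual(theta, x, y) ** 2 for x, y in data]


def mixed_term_finite_difference(theta, v, data, eps=1e-3):
    """Approximate (d^2 L_base / d lambda d theta) @ v by central differences."""
    theta_plus = [t + eps * vi for t, vi in zip(theta, v)]
    theta_minus = [t - eps * vi for t, vi in zip(theta, v)]
    g_plus = grad_wrt_weights(theta_plus, data)
    g_minus = grad_wrt_weights(theta_minus, data)
    return [(gp - gm) / (2 * eps) for gp, gm in zip(g_plus, g_minus)]


def mixed_term_exact(theta, v, data):
    """Closed form for this toy loss: 2 * residual_i * (x_i . v)."""
    return [
        2 * residual(theta, x, y) * sum(xi * vi for xi, vi in zip(x, v))
        for x, y in data
    ]


# Toy dataset of (features, target) pairs and base-level parameters.
data = [([1.0, 2.0], 1.0), ([0.5, -1.0], 0.0), ([2.0, 0.0], 3.0)]
theta = [0.3, -0.2]
v = [1.0, 0.5]  # stands in for a meta-loss gradient direction

approx = mixed_term_finite_difference(theta, v, data)
exact = mixed_term_exact(theta, v, data)
```

In this sketch the approximation needs only two evaluations of a first-order gradient, which is why such schemes can, in principle, reuse distributed training machinery built for first-order gradients. For the quadratic toy loss the central difference matches the closed form up to floating-point error; for general losses it is only an approximation whose quality depends on `eps`.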