Transformers can efficiently learn in-context from example demonstrations. Most existing theoretical analyses study the in-context learning (ICL) ability of transformers for linear function classes, where it is typically shown that the minimizer of the pretraining loss implements one step of gradient descent on the least squares objective. However, this simplified linear setting arguably does not demonstrate the statistical efficiency of ICL, since the pretrained transformer does not outperform directly solving linear regression on the test prompt. In this paper, we study ICL of a nonlinear function class via a transformer with a nonlinear MLP layer: given a class of \textit{single-index} target functions $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$, where the index features $\boldsymbol{\beta}\in\mathbb{R}^d$ are drawn from an $r$-dimensional subspace, we show that a nonlinear transformer optimized by gradient descent (with a pretraining sample complexity that depends on the \textit{information exponent} of the link function $\sigma_*$) learns $f_*$ in-context with a prompt length that depends only on the dimension $r$ of the distribution of target functions; in contrast, any algorithm that directly learns $f_*$ on the test prompt incurs a statistical complexity that scales with the ambient dimension $d$. Our result highlights the adaptivity of the pretrained transformer to low-dimensional structures of the function class, which enables sample-efficient ICL that outperforms estimators with access only to the in-context data.
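To make the problem setup concrete, the following is a minimal sketch (not taken from the paper) of the single-index in-context data-generating process described above. It assumes details the abstract does not specify: Gaussian inputs $\boldsymbol{x}\sim\mathcal{N}(0,I_d)$, a fixed orthonormal basis for the $r$-dimensional subspace, and index features drawn uniformly on the unit sphere inside that subspace; the function and parameter names (`sample_prompt`, the example link, the chosen dimensions) are illustrative only.

```python
# Hypothetical sketch of the single-index ICL data-generating process.
# Assumptions beyond the abstract: Gaussian inputs, orthonormal subspace
# basis, and beta uniform on the unit sphere within the subspace.
import numpy as np

def sample_prompt(d, r, n, sigma_star, rng):
    """Draw one prompt: n demonstrations (x_i, f_*(x_i)) plus a held-out query."""
    U, _ = np.linalg.qr(rng.standard_normal((d, r)))   # orthonormal basis of the r-dim subspace
    z = rng.standard_normal(r)
    beta = U @ (z / np.linalg.norm(z))                 # index feature beta lies in the subspace
    X = rng.standard_normal((n + 1, d))                # n demonstrations + 1 query, x ~ N(0, I_d)
    y = sigma_star(X @ beta)                           # single-index labels y_i = sigma_*(<x_i, beta>)
    return X[:n], y[:n], X[n], y[n]                    # (demos, labels, query input, query target)

rng = np.random.default_rng(0)
sigma_star = lambda t: t**2 + t                        # example link with information exponent 1
X_ctx, y_ctx, x_query, y_query = sample_prompt(d=64, r=4, n=32, sigma_star=sigma_star, rng=rng)
```

A pretrained transformer would receive the demonstrations `(X_ctx, y_ctx)` together with `x_query` as a prompt and be evaluated on predicting `y_query`; the abstract's claim is that the required prompt length $n$ scales with $r$ rather than with the ambient dimension $d$ used here.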