Variational inference is an approximation framework for Bayesian inference that seeks to improve quantified uncertainty in predictions by optimizing a simplified distribution over parameters to stand in for the full posterior. Capturing model variations that remain consistent with training data enables more robust predictions by reducing parameter sensitivity. This work introduces a fixed-point optimization for variational inference that applies when every feasible log density can be expressed as a linear combination of functions from a given basis. In such cases, the optimizer is a fixed point of projective integral updates. When the basis spans univariate quadratics in each parameter, feasible densities are Gaussian and the projective integral updates yield quasi-Newton variational Bayes (QNVB). Other bases and updates are also possible. Because these updates require high-dimensional integration, this work first proposes an efficient quasirandom quadrature sequence for mean-field distributions. Each iterate of the sequence contains two evaluation points that together integrate all univariate quadratics exactly and, if the mean-field factors are symmetric, all univariate cubics as well. More importantly, averaging results over short subsequences achieves periodic exactness on a much larger space of multivariate quadratics. The corresponding variational updates require 4 loss evaluations with standard (not second-order) backpropagation to eliminate error terms from over half of all multivariate quadratic basis functions. This integration technique is motivated by stochastic blocked mean-field quadratures, which are proposed first and may be useful in other contexts. A PyTorch implementation of QNVB allows better control over model uncertainty during training than competing methods. Experiments demonstrate superior generalizability across multiple learning problems and architectures.
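Two ingredients named in the abstract, a two-point mean-field quadrature and a Gaussian fixed-point update driven only by first-order gradients, can be illustrated with a small sketch. It assumes a mean-field Gaussian q = N(mu, diag(sigma^2)), evaluation points mu ± sigma∘s for a random sign vector s (a simple antithetic rule, not necessarily the quasirandom sequence proposed in the paper), and a loss L equal to the negative log unnormalized posterior. The function names `antithetic_points` and `gaussian_fixed_point_step` are hypothetical, and estimating curvature through Stein's lemma is a swapped-in technique. This is not the paper's QNVB implementation, whose update uses 4 loss evaluations.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def antithetic_points(mu: Sequence[float], sigma: Sequence[float],
                      rng: random.Random) -> Tuple[Vector, Vector, Vector]:
    """Return a sign vector s and the two points mu + sigma*s, mu - sigma*s.

    Every coordinate sits exactly one standard deviation from its mean, so the
    pair average of any univariate quadratic matches its Gaussian expectation,
    and odd univariate terms cancel (exact when the factors are symmetric).
    """
    s = [rng.choice((-1.0, 1.0)) for _ in mu]
    plus = [m + sd * si for m, sd, si in zip(mu, sigma, s)]
    minus = [m - sd * si for m, sd, si in zip(mu, sigma, s)]
    return s, plus, minus


def gaussian_fixed_point_step(grad: Callable[[Vector], Vector],
                              mu: Vector, sigma: Vector,
                              rng: random.Random) -> Tuple[Vector, Vector]:
    """One hypothetical quasi-Newton update of q = N(mu, diag(sigma^2)).

    Only first-order gradients at the two antithetic points are used:
      E_q[dL/dx_i]       ~ (g+_i + g-_i) / 2
      E_q[d^2L/dx_i^2]   ~ s_i * (g+_i - g-_i) / (2 * sigma_i)   (Stein's lemma)
    Stationarity of E_q[L] - entropy(q) gives precision_i = E_q[d^2L/dx_i^2]
    and a Newton-like step for the mean.
    """
    s, plus, minus = antithetic_points(mu, sigma, rng)
    g_plus, g_minus = grad(plus), grad(minus)
    new_mu: Vector = []
    new_sigma: Vector = []
    for i in range(len(mu)):
        mean_grad = 0.5 * (g_plus[i] + g_minus[i])
        curvature = s[i] * (g_plus[i] - g_minus[i]) / (2.0 * sigma[i])
        if curvature <= 0.0:
            # This sketch has no damping or safeguard for non-convex regions.
            raise ValueError("non-positive curvature estimate")
        new_mu.append(mu[i] - mean_grad / curvature)
        new_sigma.append(1.0 / math.sqrt(curvature))
    return new_mu, new_sigma
```

For a separable quadratic loss L(x) = ½ Σ a_i (x_i − c_i)², both two-point estimates are exact, so a single step lands on mu = c and sigma_i = a_i^(−1/2), which is the exact Gaussian posterior. For non-quadratic losses the estimates are noisy and the curvature can come out non-positive, so a practical version would need averaging over iterations and safeguards.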