We introduce Hyper Input Convex Neural Networks (HyCNNs), a novel neural network architecture for learning convex functions. HyCNNs combine the principles of Maxout networks with input convex neural networks (ICNNs) to obtain a network that is always convex in its input, is theoretically capable of leveraging depth, and performs more reliably than ICNNs when trained at scale. Concretely, we prove that HyCNNs require exponentially fewer parameters than ICNNs to approximate quadratic functions up to a given precision. In a series of synthetic experiments, we demonstrate that HyCNNs outperform existing ICNNs and MLPs in predictive performance on convex regression and interpolation tasks. We further apply HyCNNs to learn high-dimensional optimal transport maps, both for synthetic examples and for single-cell RNA sequencing data, where they often outperform ICNN-based neural optimal transport methods and other baselines across a wide range of settings.
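To illustrate why combining Maxout units with the ICNN construction can make a network convex in its input by construction, here is a minimal NumPy sketch of a generic input-convex maxout network. It assumes two standard ingredients: each unit takes a maximum over several affine pieces, and the hidden-to-hidden and output weights are constrained to be nonnegative. This is not the HyCNN architecture itself. The function names `init_params` and `forward`, the layer widths, and the number of pieces are hypothetical choices for illustration.

```python
import numpy as np


def init_params(dim_in, widths, pieces, rng):
    """Random parameters for a hypothetical input-convex maxout network."""
    params = []
    prev = None
    for width in widths:
        layer = {
            # affine "passthrough" from the raw input x (may have any sign)
            "W": rng.normal(size=(pieces, width, dim_in)),
            "b": rng.normal(size=(pieces, width)),
        }
        if prev is not None:
            # hidden-to-hidden weights are nonnegative to preserve convexity
            layer["U"] = np.abs(rng.normal(size=(pieces, width, prev)))
        params.append(layer)
        prev = width
    head = {
        "u": np.abs(rng.normal(size=prev)),  # nonnegative output weights
        "w": rng.normal(size=dim_in),        # affine output term in x
        "c": float(rng.normal()),
    }
    return params, head


def forward(params, head, x):
    """Evaluate the scalar output f(x) for one input vector x of shape (dim_in,)."""
    z = None
    for layer in params:
        pre = layer["W"] @ x + layer["b"]  # shape (pieces, width)
        if z is not None:
            pre = pre + layer["U"] @ z     # nonnegative combination of convex units
        z = pre.max(axis=0)                # maxout: max over the affine pieces
    return float(head["u"] @ z + head["w"] @ x + head["c"])
```

Convexity follows layer by layer. A maximum of affine functions is convex. A nonnegative combination of convex functions plus an affine term is convex. The maximum of convex functions is again convex. A quick numerical sanity check is to evaluate `forward` along the segment between two random inputs and confirm that it never exceeds the chord.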