Convolution models with long filters have demonstrated state-of-the-art reasoning abilities in many long-sequence tasks but lag behind the most optimized Transformers in wall-clock time. A major bottleneck is the Fast Fourier Transform (FFT)--which allows long convolutions to run in $O(N logN)$ time in sequence length $N$ but has poor hardware utilization. In this paper, we study how to optimize the FFT convolution. We find two key bottlenecks: the FFT does not effectively use specialized matrix multiply units, and it incurs expensive I/O between layers of the memory hierarchy. In response, we propose FlashFFTConv. FlashFFTConv uses a matrix decomposition that computes the FFT using matrix multiply units and enables kernel fusion for long sequences, reducing I/O. We also present two sparse convolution algorithms--1) partial convolutions and 2) frequency-sparse convolutions--which can be implemented simply by skipping blocks in the matrix decomposition, enabling further opportunities for memory and compute savings. FlashFFTConv speeds up exact FFT convolutions by up to 7.93$\times$ over PyTorch and achieves up to 4.4$\times$ speedup end-to-end. Given the same compute budget, FlashFFTConv allows Hyena-GPT-s to achieve 2.3 points better perplexity on the PILE and M2-BERT-base to achieve 3.3 points higher GLUE score--matching models with twice the parameter count. FlashFFTConv also achieves 96.1% accuracy on Path-512, a high-resolution vision task where no model had previously achieved better than 50%. Furthermore, partial convolutions enable longer-sequence models--yielding the first DNA model that can process the longest human genes (2.3M base pairs)--and frequency-sparse convolutions speed up pretrained models while maintaining or improving model quality.
翻译:具有长滤波器的卷积模型在许多长序列任务中展现出最先进的推理能力,但在实际计算时间上落后于优化最充分的Transformer。其核心瓶颈在于快速傅里叶变换(FFT)——虽然能够使长卷积在序列长度$N$下以$O(N \log N)$时间复杂度运行,但硬件利用率较低。本文研究如何优化FFT卷积,发现两个关键瓶颈:FFT未能有效利用专用矩阵乘法单元,且在存储层级间产生高昂的输入/输出开销。为此,我们提出FlashFFTConv。该方法采用矩阵分解,通过矩阵乘法单元计算FFT,并实现长序列的核融合以减少I/O开销。同时提出两种稀疏卷积算法:1)部分卷积 2)频域稀疏卷积,可通过简单跳过矩阵分解中的计算块实现,从而进一步节省内存与计算资源。FlashFFTConv在精确FFT卷积上比PyTorch加速高达7.93倍,端到端加速达4.4倍。在相同计算预算下,FlashFFTConv使Hyena-GPT-s在PILE数据集上困惑度提升2.3个点,M2-BERT-base在GLUE评分上提升3.3个点——达到参数数量翻倍模型的表现。此外,FlashFFTConv在Path-512高分辨率视觉任务中达到96.1%准确率,而此前所有模型均未突破50%。部分卷积支持更长序列建模——首次实现能够处理最长人类基因(230万碱基对)的DNA模型——频域稀疏卷积则在保持或提升模型质量的同时加速预训练模型。