Offline model-based optimization aims to find a design that maximizes a property of interest using only an offline dataset, with applications in robot, protein, and molecule design, among others. A prevalent approach is gradient ascent, where a proxy model is trained on the offline dataset and then used to optimize the design. This method suffers from an out-of-distribution issue: the proxy can be inaccurate for designs that lie outside the distribution of the offline data. To mitigate this issue, we explore using a pseudo-labeler to generate valuable data for fine-tuning the proxy. Specifically, we propose \textit{\textbf{I}mportance-aware \textbf{C}o-\textbf{T}eaching for Offline Model-based Optimization}~(\textbf{ICT}). This method maintains three symmetric proxies, uses their mean ensemble as the final proxy, and comprises two steps. The first step is \textit{pseudo-label-driven co-teaching}. In this step, one proxy is selected in turn as the pseudo-labeler for designs near the current optimization point, generating pseudo-labeled data. A co-teaching process then identifies small-loss samples as valuable data and exchanges them between the other two proxies for fine-tuning, promoting knowledge transfer. This procedure is repeated three times, with a different proxy chosen as the pseudo-labeler each time, ultimately enhancing the ensemble performance. To further improve the accuracy of the pseudo-labels, we perform a second step, \textit{meta-learning-based sample reweighting}, which assigns importance weights to samples in the pseudo-labeled dataset and updates them via meta-learning. ICT achieves state-of-the-art results across multiple Design-Bench tasks, attaining the best mean rank ($3.1$) and the best median rank ($2$) among $15$ methods. Our source code can be found here.
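The pseudo-label-driven co-teaching step can be illustrated with a minimal NumPy sketch. It is not the authors' implementation and rests on several assumptions: linear proxies $f_k(x) = x^\top\theta_k$ stand in for neural-network proxies, designs "near the current optimization point" are drawn by adding Gaussian noise of scale `sigma`, and the function names (`fit`, `small_loss_subset`, `co_teaching_round`) and all hyperparameters are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def fit(theta, X, y, lr=0.05, steps=200):
    """Gradient descent on the mean squared error of a linear proxy x @ theta."""
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * 2.0 * X.T @ (X @ theta - y) / len(y)
    return theta

def small_loss_subset(theta, X, y, keep_ratio):
    """Indices of the keep_ratio fraction of samples with the smallest squared error."""
    losses = (X @ theta - y) ** 2
    k = max(1, int(keep_ratio * len(y)))
    return np.argsort(losses)[:k]

def co_teaching_round(proxies, x, n_neighbors=32, sigma=0.1, keep_ratio=0.5,
                      ft_lr=0.01, ft_steps=20):
    """One pass in which each of the three proxies acts once as the pseudo-labeler."""
    proxies = [p.copy() for p in proxies]
    for k in range(3):
        # Sample designs near the current point x and pseudo-label them with proxy k.
        X_nb = x + sigma * rng.normal(size=(n_neighbors, x.size))
        y_nb = X_nb @ proxies[k]
        i, j = [m for m in range(3) if m != k]
        # Each remaining proxy selects the samples it finds small-loss ...
        idx_i = small_loss_subset(proxies[i], X_nb, y_nb, keep_ratio)
        idx_j = small_loss_subset(proxies[j], X_nb, y_nb, keep_ratio)
        # ... and hands them to its peer for fine-tuning (the co-teaching exchange).
        new_i = fit(proxies[i], X_nb[idx_j], y_nb[idx_j], lr=ft_lr, steps=ft_steps)
        new_j = fit(proxies[j], X_nb[idx_i], y_nb[idx_i], lr=ft_lr, steps=ft_steps)
        proxies[i], proxies[j] = new_i, new_j
    return proxies

# Toy offline dataset and three proxies trained from different random initializations.
n, d = 64, 5
X_off = rng.normal(size=(n, d))
y_off = X_off @ rng.normal(size=d) + 0.1 * rng.normal(size=n)
proxies = [fit(rng.normal(size=d), X_off, y_off) for _ in range(3)]

# Gradient ascent on the mean-ensemble proxy, starting from the best offline design.
x = X_off[np.argmax(y_off)].copy()
for _ in range(10):
    proxies = co_teaching_round(proxies, x)
    # For linear proxies, the gradient of the mean ensemble w.r.t. x is the mean parameter vector.
    x = x + 0.01 * np.mean(proxies, axis=0)
```

In this sketch the two non-labeling proxies never fine-tune on their own small-loss picks, only on their peer's, which is the exchange that co-teaching relies on to limit each proxy's confirmation of its own errors.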
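The meta-learning-based sample reweighting step can be illustrated with a learning-to-reweight style update. This technique is swapped in for illustration, because the exact meta-objective is not specified here. The sketch assumes a linear proxy, takes the proxy's MSE on the offline (ground-truth) dataset after one virtual weighted fine-tuning step as the meta-objective, and uses hypothetical names (`meta_reweight`, `weighted_finetune_step`) and hyperparameters.

```python
import numpy as np

def per_sample_grads(theta, X, y):
    """Gradients of the squared error (x_i . theta - y_i)^2 for each sample, shape (m, d)."""
    return 2.0 * (X @ theta - y)[:, None] * X

def meta_reweight(theta, X_pl, y_pl, X_off, y_off, weights, inner_lr=0.05, meta_lr=1.0):
    """One meta-gradient update of the importance weights (assumed meta-objective: offline MSE)."""
    m = len(y_pl)
    g = per_sample_grads(theta, X_pl, y_pl)
    # Virtual fine-tuning step on the weighted pseudo-label loss.
    theta_virtual = theta - inner_lr * (weights @ g) / m
    # Gradient of the offline (ground-truth) MSE at the virtual parameters.
    g_off = 2.0 * X_off.T @ (X_off @ theta_virtual - y_off) / len(y_off)
    # Chain rule: d L_off(theta_virtual) / d w_i = -(inner_lr / m) * (g_off . g_i).
    meta_grad = -(inner_lr / m) * (g @ g_off)
    # Descend on the weights, clip at zero, and renormalize to mean 1.
    new_w = np.clip(weights - meta_lr * meta_grad, 0.0, None)
    total = new_w.sum()
    return new_w * m / total if total > 0 else np.ones(m)

def weighted_finetune_step(theta, X_pl, y_pl, weights, lr=0.05):
    """Actual fine-tuning step of the proxy on the importance-weighted pseudo-label loss."""
    return theta - lr * (weights @ per_sample_grads(theta, X_pl, y_pl)) / len(y_pl)

# Toy check: one pseudo-label is corrupted and should receive a lower weight.
rng = np.random.default_rng(0)
d = 4
theta_true = rng.normal(size=d)
X_off = rng.normal(size=(50, d))
y_off = X_off @ theta_true
X_pl = rng.normal(size=(8, d))
y_pl = X_pl @ theta_true
y_pl[0] += 5.0
w = meta_reweight(theta_true, X_pl, y_pl, X_off, y_off, np.ones(8), meta_lr=10.0)
theta_new = weighted_finetune_step(theta_true, X_pl, y_pl, w)
```

In the toy check, fitting the corrupted pseudo-label would move the proxy away from the offline data, so its gradient is anti-aligned with the offline-loss gradient, its meta-gradient is positive, and its weight drops below those of the consistent samples.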