Offline model-based optimization aims to find a design that maximizes a property of interest using only an offline dataset, with applications in robot, protein, and molecule design, among others. A prevalent approach is gradient ascent, where a proxy model is trained on the offline dataset and then used to optimize the design. This method suffers from an out-of-distribution issue: the proxy is inaccurate for designs unseen during training. To mitigate this issue, we explore using a pseudo-labeler to generate valuable data for fine-tuning the proxy. Specifically, we propose \textit{\textbf{I}mportance-aware \textbf{C}o-\textbf{T}eaching for Offline Model-based Optimization}~(\textbf{ICT}). This method maintains three symmetric proxies, uses their mean ensemble as the final proxy, and comprises two steps. The first step is \textit{pseudo-label-driven co-teaching}. In this step, one proxy is selected as the pseudo-labeler for designs near the current optimization point, generating pseudo-labeled data. A co-teaching process then identifies small-loss samples as valuable data and exchanges them between the other two proxies for fine-tuning, promoting knowledge transfer. This procedure is repeated three times, with a different proxy chosen as the pseudo-labeler each time, ultimately enhancing the ensemble performance. To further improve the accuracy of the pseudo-labels, we perform a second step of \textit{meta-learning-based sample reweighting}, which assigns importance weights to samples in the pseudo-labeled dataset and updates them via meta-learning. ICT achieves state-of-the-art results across multiple Design-Bench tasks, with the best mean rank of $3.1$ and median rank of $2$ among $15$ methods. Our source code can be found here.
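The first step described above can be illustrated with a minimal sketch. It is not the authors' implementation: it assumes linear proxies with a squared-error loss, Gaussian perturbations to sample designs near the current point, and a single gradient step for fine-tuning. All function names and hyperparameters (`n_neighbors`, `noise`, `keep_ratio`, `lr`) are hypothetical, and the meta-learning-based sample reweighting step is omitted.

```python
import numpy as np

def predict(w, x):
    """Linear proxy (an assumption for this sketch): score = x @ w."""
    return x @ w

def small_loss_indices(w, x, y, keep_ratio):
    """Indices of the keep_ratio fraction of samples with the smallest squared error."""
    losses = (predict(w, x) - y) ** 2
    n_keep = max(1, int(round(keep_ratio * len(y))))
    return np.argsort(losses)[:n_keep]

def fine_tune(w, x, y, lr=0.01):
    """One gradient-descent step on the mean squared error; returns new weights."""
    grad = 2.0 * x.T @ (predict(w, x) - y) / len(y)
    return w - lr * grad

def co_teaching_step(proxies, x_current, rng, n_neighbors=32, noise=0.1, keep_ratio=0.5):
    """Rotate the pseudo-labeler over the three proxies; the other two swap small-loss samples."""
    proxies = [w.copy() for w in proxies]
    for labeler in range(3):
        a, b = [i for i in range(3) if i != labeler]
        # Sample designs near the current optimization point and pseudo-label them.
        x_near = x_current + noise * rng.standard_normal((n_neighbors, x_current.size))
        y_pseudo = predict(proxies[labeler], x_near)
        # Each student picks the samples it fits best under the pseudo-labels.
        idx_a = small_loss_indices(proxies[a], x_near, y_pseudo, keep_ratio)
        idx_b = small_loss_indices(proxies[b], x_near, y_pseudo, keep_ratio)
        # Exchange: a is fine-tuned on b's selection and b on a's selection.
        new_a = fine_tune(proxies[a], x_near[idx_b], y_pseudo[idx_b])
        new_b = fine_tune(proxies[b], x_near[idx_a], y_pseudo[idx_a])
        proxies[a], proxies[b] = new_a, new_b
    return proxies

def ensemble_predict(proxies, x):
    """Mean ensemble of the three proxies, used as the final proxy."""
    return np.mean([predict(w, x) for w in proxies], axis=0)

def ascent_step(proxies, x_current, lr=0.1):
    """Gradient ascent on the ensemble; for linear proxies the gradient is the mean weight."""
    return x_current + lr * np.mean(proxies, axis=0)
```

A typical loop under these assumptions would alternate `co_teaching_step` and `ascent_step`, starting from a design taken from the offline dataset. With neural-network proxies, the closed-form gradients here would be replaced by automatic differentiation.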