Large language models (LLMs) perform well on code translation between widely used programming languages. Translation is much less reliable for domain-specific code, however, where correctness depends on framework-specific APIs and execution semantics. One example is translating deep-learning code from PyTorch to JAX: LLM outputs often contain subtle bugs or non-idiomatic usage that prevent execution or change behavior. Prior work suggests that curated bug-fix data from LLM-generated code can improve code generation quality, but such resources remain limited for PyTorch-to-JAX translation. In this work, we introduce T2J, a benchmark of LLM translation bugs paired with developer-written fixes for PyTorch-to-JAX code. We start from 20 kernels in the TorchLeet dataset, translate them to JAX with the weak LLM gpt-4o-mini, and hire software developers to debug and repair the generated JAX implementations. We then use T2J to improve the PyTorch-to-JAX translation of gpt-4o-mini via in-context learning. Our evaluation shows that using T2J yields an improvement of up to 20% on our proposed metric, T2J-CodeTrans-Score.
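
As a rough illustration of how bug-fix pairs might be supplied for in-context learning, the following minimal sketch assembles a few-shot prompt from one pair. The `BugFixPair` type, the `build_prompt` helper, and the prompt layout are hypothetical and are not taken from T2J. The example pair is also illustrative rather than a benchmark entry: it shows a typical translation bug, where in-place index assignment is carried over from PyTorch even though JAX arrays are immutable and require a functional `.at[...].set(...)` update.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BugFixPair:
    """Hypothetical entry: PyTorch source, buggy JAX translation, repaired JAX."""
    pytorch_source: str
    buggy_jax: str
    fixed_jax: str


def build_prompt(examples, pytorch_code):
    """Assemble a few-shot translation prompt (the layout is an assumption)."""
    parts = ["Translate the PyTorch code to JAX. Avoid the bugs shown in the examples."]
    for i, ex in enumerate(examples, start=1):
        parts.append(
            f"### Example {i}\n"
            f"PyTorch:\n{ex.pytorch_source}\n"
            f"Buggy JAX:\n{ex.buggy_jax}\n"
            f"Fixed JAX:\n{ex.fixed_jax}"
        )
    # The new translation task goes last, leaving the JAX answer open.
    parts.append(f"### Task\nPyTorch:\n{pytorch_code}\nJAX:")
    return "\n\n".join(parts)


# Illustrative pair (not from T2J): JAX arrays are immutable, so in-place
# index assignment must become a functional `.at[].set()` update.
EXAMPLE = BugFixPair(
    pytorch_source="x = torch.zeros(3)\nx[0] = 1.0",
    buggy_jax="x = jnp.zeros(3)\nx[0] = 1.0",
    fixed_jax="x = jnp.zeros(3)\nx = x.at[0].set(1.0)",
)

prompt = build_prompt([EXAMPLE], "y = torch.ones(2)\ny[1] = 0.0")
print(prompt)
```

The resulting string would be sent to the model as a single prompt; how many pairs to include and how to select them are design choices this sketch leaves open.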