We present the Keras domain packages KerasCV and KerasNLP, extensions of the Keras API for Computer Vision and Natural Language Processing workflows, capable of running on any of JAX, TensorFlow, or PyTorch. These domain packages are designed to enable fast experimentation, with a focus on ease-of-use and performance. We adopt a modular, layered design: at the library's lowest level of abstraction, we provide building blocks for creating models and data preprocessing pipelines, and at the library's highest level of abstraction, we provide pretrained ``task'' models for popular architectures such as Stable Diffusion, YOLOv8, GPT2, BERT, Mistral, CLIP, Gemma, and T5. Task models include built-in preprocessing and pretrained weights, and can be fine-tuned on raw inputs. To enable efficient training, we support XLA compilation for all models, and run all preprocessing via a compiled graph of TensorFlow operations using the tf.data API. The libraries are fully open-source (Apache 2.0 license) and available on GitHub.