ByteCheckpoint: A Unified Checkpointing System for Large Foundation Model Development
Abstract
Checkpointing training state is crucial in the development of Large Foundation Models (LFMs): it allows training to resume after failures or after changes in GPU resources and parallelism configurations. Saved checkpoints are also dispatched to evaluation tasks and transferred across training stages (e.g., from pre-training to post-training). All of these scenarios require resharding distributed checkpoints from one parallelism configuration to another. In production environments, different LFMs are trained with different frameworks and storage backends, depending on model size and training scale. A high-performance checkpointing system is therefore needed to manage checkpoints efficiently at scale throughout the LFM development lifecycle. We introduce ByteCheckpoint, an industrial-grade checkpointing system for large-scale LFM training. ByteCheckpoint features: a parallelism-agnostic checkpoint representation that enables efficient load-time checkpoint resharding; a generic checkpoint saving/loading workflow that accommodates multiple training frameworks and supports different storage backends; full-stack optimizations that ensure high I/O efficiency and scalability; and a suite of monitoring tools that streamline large-scale performance analysis and bottleneck detection. Compared to existing open-source checkpointing systems [52, 58], ByteCheckpoint reduces runtime checkpoint stalls by 54.20x on average. For saving and loading times, it achieves improvements of up to 9.96x and 8.80x, respectively.
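To make the idea of load-time resharding concrete, the following is a minimal sketch, not ByteCheckpoint's actual format or API. It assumes a 1-D tensor split into contiguous shards, with each saved shard described by its offset in the global tensor rather than by the rank that wrote it. The names `ShardMeta`, `split_range`, `save_sharded`, and `load_resharded`, and the in-memory `storage` dictionary standing in for a storage backend, are all hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ShardMeta:
    """Hypothetical metadata: where one saved shard sits in the global tensor."""
    name: str      # tensor name
    offset: int    # start index in the global (unsharded) 1-D tensor
    length: int    # number of elements in this shard
    file_key: str  # key of the storage object holding the shard data


def split_range(total: int, parts: int, idx: int) -> Tuple[int, int]:
    """Return (offset, length) of the idx-th of `parts` near-even contiguous slices."""
    base, rem = divmod(total, parts)
    offset = idx * base + min(idx, rem)
    return offset, base + (1 if idx < rem else 0)


def save_sharded(name: str, tensor: List[float], world_size: int,
                 storage: Dict[str, List[float]]) -> List[ShardMeta]:
    """Simulate every rank writing its shard plus parallelism-agnostic metadata."""
    metas = []
    for rank in range(world_size):
        offset, length = split_range(len(tensor), world_size, rank)
        key = f"{name}/shard{rank}"
        storage[key] = tensor[offset:offset + length]
        metas.append(ShardMeta(name, offset, length, key))
    return metas


def load_resharded(name: str, total: int, metas: List[ShardMeta],
                   storage: Dict[str, List[float]],
                   new_world_size: int, new_rank: int) -> List[float]:
    """Build the slice needed by `new_rank` by intersecting it with saved shards."""
    want_off, want_len = split_range(total, new_world_size, new_rank)
    want_end = want_off + want_len
    out = [0.0] * want_len
    for m in metas:
        if m.name != name:
            continue
        lo = max(want_off, m.offset)
        hi = min(want_end, m.offset + m.length)
        if lo < hi:  # the saved shard overlaps the requested slice
            src = storage[m.file_key][lo - m.offset:hi - m.offset]
            out[lo - want_off:hi - want_off] = src
    return out


# Save with 4 ranks, then load with 3 ranks without rewriting any files.
storage: Dict[str, List[float]] = {}
full = [float(i) for i in range(10)]
metas = save_sharded("layer.weight", full, world_size=4, storage=storage)
reloaded = [load_resharded("layer.weight", len(full), metas, storage, 3, r)
            for r in range(3)]
```

Because the metadata records global offsets only, the loader never needs to know how many ranks produced the checkpoint; each new rank reads just the overlapping portions of the saved shards. Real checkpoints involve multi-dimensional tensors, optimizer state, and several parallelism dimensions, which this sketch deliberately leaves out.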