TokenFormer:基于参数令牌化的Transformer扩展再思考
TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters
摘要 Abstract
Transformer因在各个领域表现出色已成为基础模型的主要架构。然而,其模型扩展的巨大成本仍然是一个重要问题。这一问题主要源于线性投影中对固定数量参数的依赖。当引入架构修改(如通道维度)时,整个模型通常需要从头开始重新训练。随着模型规模的不断扩大,这种策略导致计算成本越来越高,变得不可持续。为了解决这个问题,我们提出了TokenFormer,这是一种原生可扩展的架构,不仅利用注意力机制处理输入标记之间的计算,还用于处理标记与模型参数之间的交互,从而增强架构灵活性。通过将模型参数视为标记,我们将Transformer中的所有线性投影替换为我们的标记-参数注意力层,其中输入标记作为查询,模型参数作为键和值。这种重构允许渐进且高效地扩展,而无需从头开始重新训练。我们的模型通过逐步添加新的键值参数对从1.24亿扩展到14亿参数,性能与从头开始训练的Transformer相当,同时大大降低了训练成本。代码和模型可在https://github.com/Haiyang-W/TokenFormer获取。
Transformers have become the predominant architecture in foundation models due to their excellent performance across various domains. However, the substantial cost of scaling these models remains a significant concern. This problem arises primarily from their dependence on a fixed number of parameters within linear projections. When architectural modifications (e.g., channel dimensions) are introduced, the entire model typically requires retraining from scratch. As model sizes continue growing, this strategy results in increasingly high computational costs and becomes unsustainable. To overcome this problem, we introduce TokenFormer, a natively scalable architecture that leverages the attention mechanism not only for computations among input tokens but also for interactions between tokens and model parameters, thereby enhancing architectural flexibility. By treating model parameters as tokens, we replace all the linear projections in Transformers with our token-parameter attention layer, where input tokens act as queries and model parameters as keys and values. This reformulation allows for progressive and efficient scaling without necessitating retraining from scratch. Our model scales from 124M to 1.4B parameters by incrementally adding new key-value parameter pairs, achieving performance comparable to Transformers trained from scratch while greatly reducing training costs. Code and models are available at https://github.com/Haiyang-W/TokenFormer.