Sneaking Syntax into Transformer Language Models with Tree Regularization

Abstract

While compositional accounts of human language understanding are based on a hierarchical, tree-like process, neural models like transformers lack a direct inductive bias for such tree structures. Introducing syntactic inductive biases could unlock more robust and data-efficient learning in transformer language models (LMs), but existing methods for incorporating such structure greatly restrict models, either limiting their expressivity or increasing inference complexity. This work instead aims to softly inject syntactic inductive biases into given transformer circuits through a structured regularizer. We introduce TreeReg, an auxiliary loss function that converts bracketing decisions from silver parses into a set of differentiable orthogonality constraints on vector hidden states. TreeReg integrates seamlessly with the standard LM objective and requires no architectural changes. LMs pre-trained with TreeReg on natural language corpora such as WikiText-103 achieve up to 10% lower perplexities on out-of-distribution data and up to 9.5 point improvements in syntactic generalization, and they require less than half the training data to outperform standard LMs. TreeReg still provides gains for pre-trained LLMs: continued pre-training of Sheared Llama with TreeReg results in improved syntactic generalization, and fine-tuning on MultiNLI with TreeReg mitigates degradation of performance on adversarial NLI benchmarks by 41.2 points. We release all code to guide future research.
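To make the idea of an auxiliary orthogonality loss concrete, the following is a minimal sketch and not the TreeReg formulation itself, which the abstract does not spell out. It assumes, purely for illustration, that each bracketed span is summarized by the mean of its hidden states, that its context is the mean of the remaining hidden states, and that the penalty is the squared cosine similarity between the two. The names `orthogonality_penalty`, `total_loss`, and the `weight` parameter are hypothetical. Plain Python is used for readability; an actual implementation would perform the same arithmetic in an autodiff framework so that gradients flow into the hidden states.

```python
import math


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]


def cosine(u, v):
    """Cosine similarity; defined as 0.0 if either vector has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def orthogonality_penalty(hidden, brackets):
    """Illustrative penalty (an assumption, not the paper's definition).

    hidden:   list of per-token hidden-state vectors.
    brackets: list of (start, end) token spans, end exclusive, taken
              from a parse of the same sentence.
    The penalty is 0 when every span summary is orthogonal to its context.
    """
    penalties = []
    for start, end in brackets:
        inside = hidden[start:end]
        outside = hidden[:start] + hidden[end:]
        if not inside or not outside:
            continue  # a span covering the whole sentence has no context
        sim = cosine(mean_vector(inside), mean_vector(outside))
        penalties.append(sim ** 2)
    return sum(penalties) / len(penalties) if penalties else 0.0


def total_loss(lm_loss, hidden, brackets, weight=0.1):
    """Standard LM loss plus the weighted auxiliary penalty."""
    return lm_loss + weight * orthogonality_penalty(hidden, brackets)
```

Because the penalty is simply added to the language-modeling loss, a regularizer of this shape leaves the model architecture and the inference procedure unchanged, which is the property the abstract emphasizes.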