Revisiting End-To-End Sparse Autoencoder Training: A Short Finetune Is All You Need
Abstract
Sparse autoencoders (SAEs) are widely used for interpreting language model activations. A key evaluation metric is the increase in cross-entropy loss when model activations are replaced with their SAE reconstructions, that is, the gap between the loss computed from the original model's logits and the loss computed from the logits produced with the reconstructed activations. Typically, SAEs are trained solely with a mean squared error (MSE) objective on precomputed, shuffled activations. Recent work introduced training SAEs directly with a combination of KL divergence and MSE ("end-to-end" SAEs), which significantly improves reconstruction accuracy but substantially increases computation, limiting widespread adoption. We propose a brief KL+MSE fine-tuning step applied only to the final 25M training tokens (just a few percent of typical training budgets) that achieves comparable improvements, reducing the cross-entropy loss gap by 20-50% while incurring minimal additional computational cost. We further find that multiple fine-tuning methods (KL fine-tuning, LoRA adapters, linear adapters) yield similar, non-additive cross-entropy improvements, suggesting a common, easily correctable source of error in MSE-trained SAEs. We demonstrate a straightforward method for transferring hyperparameters and sparsity penalties between training phases despite the difference in scale between the KL and MSE losses. While both ReLU and TopK SAEs see significant cross-entropy loss improvements, evaluations on supervised SAEBench metrics yield mixed results, with improvements on some metrics and decreases on others, depending on both the SAE architecture and the downstream task. Nonetheless, our method may offer meaningful improvements in interpretability applications such as circuit analysis at minor additional cost.
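To make the two quantities in the abstract concrete, the following is a minimal sketch for a single token position, written in plain Python. It is not the authors' implementation: the function names and the `kl_weight` and `mse_weight` parameters are hypothetical, and the abstract does not specify how the two loss terms are actually balanced.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_divergence(orig_logits, recon_logits):
    """KL(p_orig || p_recon) between the next-token distributions."""
    lp = log_softmax(orig_logits)
    lq = log_softmax(recon_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))


def mse(acts, recon_acts):
    """Mean squared error between activations and their SAE reconstruction."""
    return sum((a - r) ** 2 for a, r in zip(acts, recon_acts)) / len(acts)


def finetune_loss(orig_logits, recon_logits, acts, recon_acts,
                  kl_weight=1.0, mse_weight=1.0):
    """Combined KL+MSE objective (weights are illustrative assumptions)."""
    return (kl_weight * kl_divergence(orig_logits, recon_logits)
            + mse_weight * mse(acts, recon_acts))


def ce_loss_gap(orig_logits, recon_logits, target):
    """Increase in cross-entropy on the target token caused by splicing in
    the SAE reconstruction (the evaluation metric described above)."""
    ce_orig = -log_softmax(orig_logits)[target]
    ce_recon = -log_softmax(recon_logits)[target]
    return ce_recon - ce_orig
```

Here `orig_logits` stands for the model's output on the unmodified activations and `recon_logits` for its output when the SAE reconstruction is substituted in. In practice both terms would be averaged over a batch of tokens and computed with a tensor library so that gradients can flow back through the model into the SAE.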