Efficient Dictionary Learning with Switch Sparse Autoencoders
Authors: Anish Mudide, Josh Engels, Eric Michaud, Max Tegmark, Christian Schroeder de Witt
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We present experiments comparing Switch SAEs with other SAE architectures, and find that Switch SAEs deliver a substantial Pareto improvement in the reconstruction vs. sparsity frontier for a given fixed training compute budget. We also study the geometry of features across experts, analyze features duplicated across experts, and verify that Switch SAE features are as interpretable as features found by other SAE architectures. |
| Researcher Affiliation | Academia | Anish Mudide Massachusetts Institute of Technology Joshua Engels Massachusetts Institute of Technology Eric J. Michaud Massachusetts Institute of Technology Max Tegmark Massachusetts Institute of Technology Christian Schroeder de Witt University of Oxford |
| Pseudocode | No | The paper describes the architecture and training methodology in Section 3 and its subsections, using equations and prose, but it does not contain any clearly labeled pseudocode or algorithm blocks. |
| Open Source Code | Yes | Our code can be found at https://github.com/amudide/switch_sae |
| Open Datasets | Yes | We train sparse autoencoders on the residual stream activations of GPT-2 small, which have a dimension of 768 (Radford et al., 2019). Early layers of language models handle detokenization and feature engineering, while later layers refine representations for next-token prediction (Lad et al., 2024). In this work, we follow Gao et al. (2024) and train SAEs on activations from layer 8, which we expect to hold rich feature representations not yet specialized for next-token prediction. Using text data from Open Web Text (Gokaslan & Cohen, 2019), we train for 100K steps using a batch size of 8192, for a total of approximately 820 million tokens. |
| Dataset Splits | No | The paper mentions training on 'text data from Open Web Text' for 100K steps with a batch size of 8192, but it does not specify how this data was split into distinct training, validation, or test sets for evaluation. It describes the overall training process rather than dataset partitioning. |
| Hardware Specification | No | The paper mentions "huge training runs on large clusters of GPUs" in the conclusion as a potential application but does not specify the exact GPU models, CPU models, or other hardware specifications used for the experiments presented in the paper. |
| Software Dependencies | No | The paper mentions using "Adam (Kingma, 2014)" as an optimizer and refers to a "dictionary learning repository (Marks et al., 2024)", but it does not provide specific version numbers for any software libraries, frameworks (e.g., PyTorch, TensorFlow), or programming languages used. |
| Experiment Setup | Yes | We train for 100K steps using a batch size of 8192, for a total of approximately 820 million tokens. We set the learning rate based on the 1/M scaling law from Gao et al. (2024) and linearly decay the learning rate over the last 20% of training. We optimize Ltotal with Adam using the default parameters β1 = 0.9, β2 = 0.999. We find that results are not overly sensitive to the value of α, but that α = 3 is a reasonable default based on a hyperparameter sweep (see Appendix A.2 for details). All SAEs are trained with k = 64 with a fixed width; we train the Gemma 2 2B SAEs with width 65536 and the GPT-2 SAEs with width 24576. |
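The experiment-setup row above pins down a few concrete pieces: TopK sparsity with k = 64, Adam with default betas, and a learning rate that linearly decays to zero over the last 20% of 100K training steps. A minimal sketch of those two mechanisms is below; the function names and the `base_lr` value are placeholders (the paper instead derives the learning rate from the 1/M scaling law of Gao et al., 2024), and this is an illustration of the stated hyperparameters, not the authors' implementation.

```python
import numpy as np

def topk_activation(pre_acts, k=64):
    """Keep only the k largest pre-activations per sample, zeroing the rest.

    This mirrors the TopK sparsity constraint the paper trains all SAEs
    with (k = 64); here applied row-wise to a (batch, width) array.
    """
    acts = np.zeros_like(pre_acts)
    # Indices of the k largest entries along the last axis.
    idx = np.argpartition(pre_acts, -k, axis=-1)[..., -k:]
    np.put_along_axis(acts, idx,
                      np.take_along_axis(pre_acts, idx, axis=-1), axis=-1)
    return acts

def lr_at_step(step, total_steps=100_000, base_lr=1e-4):
    """Constant learning rate, then linear decay to 0 over the last 20%.

    base_lr is a placeholder; the paper sets it via the 1/M scaling law.
    """
    decay_start = int(0.8 * total_steps)
    if step < decay_start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - decay_start)
```

For example, with the GPT-2 width of 24576, `topk_activation` leaves at most 64 nonzero latents per activation vector, and `lr_at_step(90_000)` returns half of `base_lr`, since 90K is halfway through the decay window.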