Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction
Authors: Ziyang Wu, Tianjiao Ding, Yifu Lu, Druv Pai, Jingyuan Zhang, Weida Wang, Yaodong Yu, Yi Ma, Benjamin Haeffele
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this section, we conduct experiments to verify and study the properties and performance of our proposed Token Statistics Transformer (ToST) on real-world datasets and tasks. As detailed in Section 3, we design our white-box attention operator by following the principle of MCR2 and unrolling the compression objective defined in equation 4. We thus adopt a simple implementation that stays as close to our theoretical derivation as possible. Hence, it is not the goal of this work to show that the current implementation of ToST can outperform existing highly engineered architectures. Instead, our empirical studies aim to provide answers and evidence for the following questions: 1. Does our proposed TSSA attention operator optimize the compression objective in equation 4 in practice? 2. If we simply replace standard self-attention with our TSSA attention operator, leaving the rest of the architecture largely unchanged, do we maintain (or improve) task performance? We provide positive answers to both questions. |
| Researcher Affiliation | Collaboration | Ziyang Wu (UC Berkeley), Tianjiao Ding (UPenn), Yifu Lu (UMich), Druv Pai (UC Berkeley), Jingyuan Zhang (THU & Transcengram), Weida Wang (Tsinghua SIGS), Yaodong Yu (UC Berkeley), Yi Ma (UC Berkeley & HKU), Benjamin D. Haeffele (JHU & UPenn) |
| Pseudocode | Yes | We visualize one layer of ToST in Figure 3, and provide pseudocode for the model in Appendix E. [...] We provide a PyTorch-style pseudocode in Appendix E that reflects the architecture modifications introduced above. In particular, Algorithm 1 and Algorithm 2 implement the TSSA attention layer and its causal variant, respectively. |
| Open Source Code | Yes | Code is available at https://github.com/RobinWu218/ToST. |
| Open Datasets | Yes | For vision experiments, we pre-train the proposed ToST models on the ImageNet-1k (Deng et al., 2009) dataset. We also use these pre-trained networks as initialization and fine-tune them on smaller datasets including CIFAR10/100 (Krizhevsky et al., 2009), Oxford Flowers (Nilsback & Zisserman, 2008), and Oxford-IIIT-Pets (Parkhi et al., 2012) for transfer learning evaluations. We also adopt the Long-Range Arena (Tay et al., 2021) benchmark to analyze the long-sequence modeling capability of ToST. For causal language modeling, we train ToST autoregressively on the OpenWebText dataset (Gokaslan et al., 2019) and test the trained model on its test split as well as other datasets as shown in (Radford et al., 2019). |
| Dataset Splits | Yes | We measure this term on both the training and validation sets of the ImageNet-1k dataset by taking 500 samples from each. [...] For causal language modeling, [...] We use a context length of 1024, and optimize the models using the AdamW optimizer (Loshchilov & Hutter, 2019) for 600,000 iterations, with batch size 480. |
| Hardware Specification | Yes | We conduct all pre-training experiments on 128 NVIDIA V100 GPUs. [...] Regarding computational resources, we conducted experiments on two NVIDIA RTX 4090 GPUs. [...] we further conduct experiments to provide empirical evidence by evaluating real-world memory cost and inference speed on an NVIDIA H800 GPU. |
| Software Dependencies | No | We provide a PyTorch-style pseudocode in Appendix E that reflects the architecture modifications introduced above. [...] Our implementation is based on Karpathy (2022). |
| Experiment Setup | Yes | We train our models using the AdamW optimizer with a learning rate of 2e-4 for 400 epochs throughout our pre-training experiments. We configure our batch size to be 2048 for all our training experiments. [...] We use a context length of 1024, and optimize the models using the AdamW optimizer (Loshchilov & Hutter, 2019) for 600,000 iterations, with batch size 480. We set the learning rate to 0.0006 with 2000 warm-up iterations and cosine decay, and weight decay to 0.1. |
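The language-modeling schedule quoted above (learning rate 0.0006, 2,000 warm-up iterations, cosine decay over 600,000 iterations) can be sketched as a plain Python function; a reproducer could pass its output to a PyTorch `LambdaLR` or set optimizer learning rates manually. This is a minimal sketch under stated assumptions, not the authors' exact code: the paper does not specify the minimum learning rate at the end of decay, so `min_lr` here is a hypothetical parameter defaulting to zero, and the linear-warmup shape is assumed from common GPT-style training practice (Karpathy, 2022).

```python
import math

def lr_at(step, base_lr=6e-4, warmup_iters=2000, total_iters=600_000, min_lr=0.0):
    """Learning rate for a given iteration: linear warm-up, then cosine decay.

    Values from the reported setup: base_lr=6e-4, warmup_iters=2000,
    total_iters=600_000. min_lr is an assumption (paper does not state it).
    """
    if step < warmup_iters:
        # Linear warm-up from ~0 up to base_lr over the first 2,000 iterations.
        return base_lr * (step + 1) / warmup_iters
    # Cosine decay from base_lr down to min_lr over the remaining iterations.
    progress = (step - warmup_iters) / (total_iters - warmup_iters)
    progress = min(progress, 1.0)  # clamp past the end of training
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

For example, `lr_at(1999)` returns the full 6e-4 at the end of warm-up, and `lr_at(600_000)` returns `min_lr`; the AdamW weight decay of 0.1 would be set separately on the optimizer.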