Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study
Authors: Shawn Tan, Songlin Yang, Aaron Courville, Rameswar Panda, Yikang Shen
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | 1. We compare the different properties of stick-breaking attention against softmax attention, 2. We discuss numerically stable implementations of stick-breaking attention, and make stick-breaking amenable to large-scale training by implementing a kernel for stick-breaking in Triton, 3. We show the performance of stick-breaking attention on length generalisation in language modelling, and evaluate 1B and 3B parameter models on various NLP tasks. [...] 5 EXPERIMENTS In this section, we compare existing attention methods against stick-breaking. We first look at a modification of a synthetic task from Arora et al. (2023) to understand the inductive biases of stick-breaking. We then compare stick-breaking against existing length-extrapolation methods in a 350M model setting. We then pretrain a 1B parameter model and evaluate it on various NLP tasks. [...] Table 2: Results on the various NLP benchmarks for the 1B and 3B pretrained models. |
| Researcher Affiliation | Collaboration | Shawn Tan (MIT-IBM Watson AI Lab), Songlin Yang (MIT), Aaron Courville (Mila, Université de Montréal), Rameswar Panda (MIT-IBM Watson AI Lab), Yikang Shen (MIT-IBM Watson AI Lab) |
| Pseudocode | Yes | Algorithm 1 FORWARD thread i [...] Algorithm 2 BACKWARD thread i |
| Open Source Code | Yes | https://github.com/shawntan/stickbreaking-attention |
| Open Datasets | Yes | We trained on the first 15B tokens of Slim Pajama (Soboleva et al., 2023), and evaluate it on the Wikitext benchmark in the LM evaluation harness (Gao et al., 2023), with context lengths of 2048 to 64K. [...] The Pile (https://huggingface.co/datasets/NeelNanda/pile-10k), evaluating on 16K context. [...] The multiple-choice tasks include: grade-school science questions (ARC; Clark et al. 2018), common sense reasoning (Hellaswag; Zellers et al. 2019), open book question answering (Open Book QA; Mihaylov et al. 2018), physical questions (PIQA; Bisk et al. 2020), reading comprehension (RACE; Lai et al. 2017), and the Winograd schema task (Winogrande; Sakaguchi et al. 2021). [...] Finally, we evaluate our 3B model on the GSM8K dataset (Cobbe et al., 2021). |
| Dataset Splits | Yes | We trained on the first 15B tokens of Slim Pajama (Soboleva et al., 2023), and evaluate it on the Wikitext benchmark in the LM evaluation harness (Gao et al., 2023), with context lengths of 2048 to 64K. [...] We also evaluated the models on MMLU with 0-shot and 5-shot settings. [...] Finally, we evaluate our 3B model on the GSM8K dataset (Cobbe et al., 2021). Interestingly, the 5-shot setting underperforms standard attention while CoT with 8-shot sees an improvement of 5.5%. |
| Hardware Specification | Yes | We measure throughput on Dolomite Engine (Mishra, 2024) with a 1B class model on a node with 8 H100 GPUs. |
| Software Dependencies | No | The paper mentions PyTorch, Flash Attention, Triton, and CUDA, but does not specify version numbers for these software components. For example: "Implementing stick-breaking attention naively in PyTorch results in realising the L² matrix for the attention logits." and "We modify the Triton implementation of Flash Attention for accelerating the stick-breaking attention." |
| Experiment Setup | Yes | We sweep through 4 learning rates in {10⁻⁴, …, 10⁻²}, and report the results of the best performing model. Softmax+RoPE is able to perform this task up to 128 key-value pairs, while stick-breaking is able to deal with sequences up to 192 key-value pairs. The full sequence length is 768. Both models are 2-layer Transformers with a hidden dimension of 256 and one attention head. [...] In the first stage, there is a warmup for the learning rate to 0.01, then we apply Power decay. Our training corpus has 1T tokens and mixes large-scale open-source datasets of medium quality with permissive licenses. In the second stage, we exponentially decay the learning rate to zero. [...] The training batch size is 1024 and uses padding-free sequence packing for training in Dolomite Engine (Mishra, 2024). [...] To extend their context lengths, we continue training for 6250 steps at a context length of 16k. Our learning rate is scheduled to increase from 10⁻⁵ to 1.25 × 10⁻⁴ in 150 steps, and then decays exponentially to 0 until 6250 steps. The effective batch size is 4M tokens. Table 1: Model hyperparameters and total size (n_layer / d_hidden / d_inter / n_head / L / total params): 350M: 24 / 1024 / 2730 / 32 / 2048 / 367,526,912; 1B: 40 / 1536 / 4096 / 24 / 4096 / 1,208,083,968; 3B: 40 / 2304 / 9216 / 36 / 4096 / 3,513,473,280. |
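The table above quotes the paper's comparison of stick-breaking attention against softmax attention without showing the mechanism itself. As a minimal sketch of the commonly described stick-breaking formulation — each query i assigns weight A[i, j] = σ(z[i, j]) · ∏_{j<t<i}(1 − σ(z[i, t])) to earlier keys, so no softmax is needed — the following NumPy code illustrates the idea (function names and the exclusion of self-attention are illustrative assumptions, not taken from the paper's released kernel):

```python
import numpy as np

def stickbreaking_attention(q, k, v):
    """Naive stick-breaking attention for one head (illustrative sketch).

    For query i and key j < i:
        A[i, j] = beta[i, j] * prod_{j < t < i} (1 - beta[i, t]),
    with beta = sigmoid(q @ k.T / sqrt(d)). Keys closest to the query
    claim their share of the probability "stick" first, so each row of
    A is non-negative and sums to at most 1 without any softmax.
    """
    L, d = q.shape
    logits = q @ k.T / np.sqrt(d)          # (L, L) attention logits
    beta = 1.0 / (1.0 + np.exp(-logits))   # sigmoid of the logits
    A = np.zeros((L, L))
    for i in range(L):
        remaining = 1.0                    # unclaimed portion of the stick
        for j in range(i - 1, -1, -1):     # walk from nearest key backwards
            A[i, j] = beta[i, j] * remaining
            remaining *= 1.0 - beta[i, j]
    return A @ v, A
```

Note the contrast with softmax attention: the weights are built by a strictly causal recency-first product, which is the inductive bias the synthetic-task experiment in the row above is probing.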
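The Software Dependencies row quotes the paper on naive PyTorch realising the full L² logit matrix, and the Research Type row mentions numerically stable implementations. Multiplying many sigmoids underflows quickly, so a standard remedy is to accumulate in log space via log σ(z) = −softplus(−z) and log(1 − σ(z)) = −softplus(z). The sketch below applies that identity to the stick-breaking product; it is a plausible reconstruction of the stability trick, not the paper's Triton kernel:

```python
import numpy as np

def softplus(z):
    # Overflow-safe softplus: log(1 + exp(z)) for any magnitude of z.
    return np.maximum(z, 0.0) + np.log1p(np.exp(-np.abs(z)))

def stickbreaking_weights_stable(logits):
    """Stick-breaking weights accumulated in log space (sketch).

    log A[i, j] = log sigmoid(z[i, j]) - sum_{j < t < i} softplus(z[i, t]),
    using log sigmoid(z) = -softplus(-z) and
          log(1 - sigmoid(z)) = -softplus(z).
    """
    L = logits.shape[0]
    sp = softplus(logits)
    logA = np.full((L, L), -np.inf)        # exp(-inf) = 0 for masked entries
    for i in range(L):
        acc = 0.0                          # running sum over keys between j and i
        for j in range(i - 1, -1, -1):
            logA[i, j] = -softplus(-logits[i, j]) - acc
            acc += sp[i, j]
    return np.exp(logA)
```

Summing softplus terms instead of multiplying (1 − σ) factors keeps intermediate values in a well-conditioned range even for large logits.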
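The Experiment Setup row describes the context-extension schedule only by its endpoints: warmup from 10⁻⁵ to 1.25 × 10⁻⁴ over 150 steps, then exponential decay toward zero by step 6250. A minimal sketch of such a schedule follows; the linear warmup shape, the decay constant, and the small floor value are assumptions, since the paper states only the endpoints and step counts:

```python
import math

def context_extension_lr(step, warmup_steps=150, total_steps=6250,
                         lr_start=1e-5, lr_peak=1.25e-4, lr_floor=1e-8):
    """Hypothetical continued-training LR schedule matching the quoted numbers.

    Linear warmup from lr_start to lr_peak over warmup_steps, then
    exponential decay chosen so the rate reaches lr_floor (near zero)
    at total_steps.
    """
    if step <= warmup_steps:
        frac = step / warmup_steps
        return lr_start + frac * (lr_peak - lr_start)
    # Decay constant fixed so that lr(total_steps) == lr_floor.
    decay = math.log(lr_peak / lr_floor) / (total_steps - warmup_steps)
    return lr_peak * math.exp(-decay * (step - warmup_steps))
```

With an exact exponential, the rate never reaches zero, which is why a floor must be picked; the paper's "decays exponentially to 0" presumably tolerates the same approximation.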