Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study

Authors: Shawn Tan, Songlin Yang, Aaron Courville, Rameswar Panda, Yikang Shen

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "1. We compare the different properties of stick-breaking attention against softmax attention. 2. We discuss numerically stable implementations of stick-breaking attention, and make stick-breaking amenable to large-scale training by implementing a kernel for stick-breaking in Triton. 3. We show the performance of stick-breaking attention on length-generalisation in language modelling, and evaluate 1B and 3B parameter models on various NLP tasks. [...] 5 EXPERIMENTS In this section, we compare existing attention methods against stick-breaking. We first look at a modification of a synthetic task from Arora et al. (2023) to understand the inductive biases of stick-breaking. We then compare stick-breaking against existing length extrapolation methods in a 350M model setting. We then pretrain a 1B parameter model and evaluate it on various NLP tasks. [...] Table 2: Results on the various NLP benchmarks for the 1B and 3B pretrained models."
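The excerpt above contrasts stick-breaking attention with softmax attention. In the stick-breaking formulation, each logit is squashed with a sigmoid, β = σ(z), and the weight on a key is its β times the product of (1 − β) over all keys closer to the query, so weights sum to at most 1 rather than exactly 1. A minimal NumPy sketch of that weighting for a single query position, computed in log space for numerical stability; this is an illustration under those assumptions, not the authors' Triton kernel:

```python
import numpy as np

def stick_breaking_weights(logits):
    """Toy stick-breaking attention weights for one query position.

    `logits` holds raw scores z_j for keys j = 0..n-1, ordered oldest
    to newest (the query attends causally). With beta_j = sigmoid(z_j),
    the weight on key j is
        A_j = beta_j * prod_{k > j} (1 - beta_k),
    i.e. keys nearer the query break off their share of the stick first.
    Everything is accumulated in log space to avoid underflow.
    """
    z = np.asarray(logits, dtype=np.float64)
    # log sigmoid(z) and log sigmoid(-z) = log(1 - sigmoid(z)), stably.
    log_beta = -np.logaddexp(0.0, -z)
    log_one_minus_beta = -np.logaddexp(0.0, z)
    # tail[j] = sum_{k > j} log(1 - beta_k): remaining stick after
    # the keys closer to the query have taken their share.
    rev_cumsum = np.cumsum(log_one_minus_beta[::-1])[::-1]
    tail = np.concatenate([rev_cumsum[1:], [0.0]])
    return np.exp(log_beta + tail)
```

With all-zero logits (β = 0.5 everywhere) the nearest key gets weight 0.5, the next 0.25, and so on; the total mass 1 − ∏(1 − β_k) stays below 1, which is one of the inductive-bias differences from softmax the paper studies.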
Researcher Affiliation | Collaboration | Shawn Tan (MIT-IBM Watson AI Lab), Songlin Yang (MIT), Aaron Courville (Mila, Université de Montréal), Rameswar Panda (MIT-IBM Watson AI Lab), Yikang Shen (MIT-IBM Watson AI Lab)
Pseudocode | Yes | "Algorithm 1 FORWARD thread i [...] Algorithm 2 BACKWARD thread i"
Open Source Code | Yes | https://github.com/shawntan/stickbreaking-attention
Open Datasets | Yes | "We trained on the first 15B tokens of SlimPajama (Soboleva et al., 2023), and evaluate it on the Wikitext benchmark in the LM evaluation harness (Gao et al., 2023), with context lengths of 2048 to 64K. [...] The Pile (https://huggingface.co/datasets/NeelNanda/pile-10k). Evaluating on 16K context [...] The multiple-choice tasks include: grade-school science questions (ARC; Clark et al. 2018), common sense reasoning (HellaSwag; Zellers et al. 2019), open book question answering (OpenBookQA; Mihaylov et al. 2018), physical questions (PIQA; Bisk et al. 2020), reading comprehension (RACE; Lai et al. 2017), and the Winograd schema task (Winogrande; Sakaguchi et al. 2021). [...] Finally, we evaluate our 3B model on the GSM8K dataset (Cobbe et al., 2021)."
Dataset Splits | Yes | "We trained on the first 15B tokens of SlimPajama (Soboleva et al., 2023), and evaluate it on the Wikitext benchmark in the LM evaluation harness (Gao et al., 2023), with context lengths of 2048 to 64K. [...] We also evaluated the models on MMLU with 0-shot and 5-shot settings. [...] Finally, we evaluate our 3B model on the GSM8K dataset (Cobbe et al., 2021). Interestingly, the 5-shot setting underperforms standard attention while CoT with 8-shot sees an improvement of 5.5%."
Hardware Specification | Yes | "We measure throughput on Dolomite Engine (Mishra, 2024) with a 1B class model on a node with 8 H100 GPUs."
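Throughput figures like the one quoted above are typically reported as tokens per second over a window of timed training steps, with warmup iterations excluded so one-off costs (kernel compilation, allocator growth) do not skew the estimate. A generic timing sketch; the helper name and step counts are illustrative and unrelated to Dolomite Engine internals:

```python
import time

def measure_throughput(step_fn, tokens_per_batch, n_warmup=3, n_steps=10):
    """Estimate training throughput in tokens/second.

    `step_fn` runs one training step (e.g. forward + backward + update);
    `tokens_per_batch` is the number of tokens processed per call.
    Warmup calls are run first and excluded from the timing window.
    """
    for _ in range(n_warmup):
        step_fn()
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    elapsed = time.perf_counter() - start
    return n_steps * tokens_per_batch / elapsed
```

On real GPU workloads one would also synchronize the device before reading the clock, since kernel launches are asynchronous.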
Software Dependencies | No | The paper mentions PyTorch, FlashAttention, Triton, and CUDA, but does not specify version numbers for these software components. For example: "Implementing stick-breaking attention naively in PyTorch results in realising the L² matrix for the attention logits." and "We modify the Triton implementation of FlashAttention for accelerating the stick-breaking attention."
Experiment Setup | Yes | "We sweep through 4 learning rates of {10⁻⁴, …, 10⁻³, 10⁻²}, and report the results of the best performing model. Softmax+RoPE is able to perform this task up to 128 key-value pairs, while stick-breaking is able to deal with sequences up to 192 key-value pairs. The full sequence length is 768. Both models are 2-layer Transformers with a hidden dimension of 256 and one attention head. [...] In the first stage, there is a warmup for the learning rate to 0.01, then we apply Power decay. Our training corpus has 1T tokens and mixes large-scale open-source datasets of medium quality with permissive licenses. In the second stage, we exponentially decay the learning rate to zero. [...] The training batch size is 1024 and uses padding-free sequence packing for training in Dolomite Engine (Mishra, 2024). [...] To extend their context lengths, we continue training for 6250 steps at a context length of 16k. Our learning rate is scheduled to increase from 10⁻⁵ to 0.000125 in 150 steps, and then decays exponentially to 0 until 6250 steps. The effective batch size has 4M tokens."

Table 1: Model hyperparameters and total size

        n_layer  d_hidden  d_inter  n_head  L     Total params.
350M    24       1024      2730     32      2048  367,526,912
1B      40       1536      4096     24      4096  1,208,083,968
3B      40       2304      9216     36      4096  3,513,473,280
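The context-extension schedule quoted above (linear warmup from 10⁻⁵ to 0.000125 over 150 steps, then exponential decay toward zero by step 6250) can be sketched as follows. The decay floor is an assumption introduced here, since a true exponential never reaches exactly zero and the paper only says the rate "decays exponentially to 0":

```python
import math

def context_extension_lr(step, warmup_steps=150, total_steps=6250,
                         lr_start=1e-5, lr_peak=1.25e-4, lr_floor=1e-8):
    """Warmup-then-exponential-decay schedule matching the quoted setup.

    Linear warmup from lr_start to lr_peak over `warmup_steps`, then
    exponential decay so that the rate reaches `lr_floor` (an assumed
    stand-in for "zero") at `total_steps`.
    """
    if step < warmup_steps:
        frac = step / warmup_steps
        return lr_start + frac * (lr_peak - lr_start)
    # Decay rate chosen so the LR hits lr_floor exactly at total_steps.
    decay_steps = total_steps - warmup_steps
    rate = math.log(lr_floor / lr_peak) / decay_steps
    return lr_peak * math.exp(rate * (step - warmup_steps))
```

Note 0.000125 is 12.5× the starting rate of 10⁻⁵, so the warmup covers just over one order of magnitude before the decay phase begins.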