Forgetting Transformer: Softmax Attention with a Forget Gate

Authors: Zhixuan Lin, Evgenii Nikishin, Xu He, Aaron Courville

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We show that FoX outperforms the Transformer on long-context language modeling, length extrapolation, and short-context downstream tasks, while performing on par with the Transformer on long-context downstream tasks. Moreover, it is compatible with the FlashAttention algorithm and does not require any positional embeddings. Several analyses, including the needle-in-the-haystack test, show that FoX also retains the Transformer's superior long-context capabilities over recurrent sequence models such as Mamba-2, HGRN2, and DeltaNet. We also introduce a Pro block design that incorporates some common architectural components in recurrent sequence models and find it significantly improves the performance of both FoX and the Transformer. Our code is available at https://github.com/zhixuan-lin/forgetting-transformer. ... The advantages of Transformers in long-context abilities over recurrent sequence models have been verified multiple times (Hsieh et al., 2024; Waleffe et al., 2024; Shen et al., 2024; Qin et al., 2024a). However, forget gates introduce a recency bias. It is thus natural to ask whether FoX still maintains this advantage. Therefore, our empirical study places a special focus on long-context capabilities.
Researcher Affiliation | Collaboration | Zhixuan Lin (Mila & Université de Montréal), Evgenii Nikishin (Mila & Université de Montréal), Xu Owen He (MakerMaker AI), Aaron Courville (Mila & Université de Montréal)
Pseudocode | Yes | In Algorithms 1 and 2, we provide the algorithms for computing the forward pass and backward pass of Forgetting Attention in a hardware-aware way. The algorithm is reproduced from the FlashAttention-2 paper (Dao, 2023), with the changes needed to implement Forgetting Attention added and highlighted.
Open Source Code | Yes | Our code is available at https://github.com/zhixuan-lin/forgetting-transformer.
Open Datasets | Yes | Dataset and baselines: We focus on long-context language modeling and train all models on LongCrawl64 (Buckman, 2024), a long-sequence subset of RedPajama-V2 (Together Computer, 2023) pre-tokenized with the tiktoken tokenizer (OpenAI, 2022) for GPT-2 (Radford et al., 2019). For baselines, we focus on two types of comparisons. First, we compare FoX with the Transformer. For the Transformer, we also test both the LLaMA and the Pro architecture (referred to as Transformer (LLaMA) and Transformer (Pro), respectively). Similar to Xiong et al. (2023), we find it crucial to use a large RoPE angle θ for the Transformer for long-context training. Following Xiong et al. (2023), we use θ = 500000. Second, to show the advantage of FoX over recurrent sequence models in long-context capabilities, we compare it with Mamba-2 (Dao & Gu, 2024), HGRN2 (Qin et al., 2024a), and DeltaNet (Yang et al., 2024). The implementation of all models is based on the Flash Linear Attention repository (Yang & Zhang, 2024). ... We use two sets of downstream tasks: a set of short-context tasks from lm-evaluation-harness (Gao et al., 2024) and a set of long-context tasks from LongBench (Bai et al., 2023). ... To complement our main results, in which we perform long-context training on LongCrawl64, we have also run short-context training on the more commonly used SlimPajama dataset (Soboleva et al., 2023).
Dataset Splits | Yes | For our main experiments, we train models with 760M (non-embedding) parameters on a 45 * 2^30-token (roughly 48B tokens) subset of LongCrawl64 with a training context length of 16384 tokens. For the validation set, we use a 2 * 2^30-token subset of the LongCrawl64 held-out set with sequences of 65536 tokens.
Hardware Specification | Yes | Throughput is measured on 4 NVIDIA L40S GPUs.
Software Dependencies | No | The FlashAttention kernels for FoX (Pro), Transformer (Pro), and FoX (LLaMA) are implemented in Triton by us on top of FlagAttention (FlagOpen, 2023) without significant optimization, while Transformer (LLaMA) uses the official FlashAttention implementation (Dao, 2023) in CUDA. ... In practice, we implement Forgetting Attention based on the Triton (OpenAI, 2021) FlashAttention implementation in FlagAttention (FlagOpen, 2023).
Experiment Setup | Yes | All models are trained with AdamW (Loshchilov, 2017) with (β1, β2) = (0.9, 0.95). We use a linear learning rate warmup from 0 to the peak learning rate for the first 256 * 2^20 tokens and then a cosine decay to 0. Each training batch contains 0.5 * 2^20 tokens. All models use a weight decay of 0.1 and gradient clipping of 1.0. We search the learning rate for each model within {1 * 10^i, 2 * 10^i, 5 * 10^i} for different values of i until we identify a locally optimal value. We tune the head dimensions for FoX and the Transformer in {64, 128}. We find that FoX often prefers higher learning rates and more heads/smaller head dimensions than the Transformer, and the Pro models often prefer higher learning rates than the LLaMA models. Details of the hyperparameters and experimental setup can be found in Appendix B.
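The mechanism behind the quoted claims is compact: per the paper, Forgetting Attention adds the cumulative sum of log forget-gate values as a data-dependent bias to the causal softmax logits, so keys are down-weighted by how much forgetting happens after them. A minimal single-head NumPy sketch of that scoring rule (the naive O(T^2) form; the function name and shapes are our own, not the authors' API):

```python
import numpy as np

def forgetting_attention(q, k, v, f):
    """Naive single-head Forgetting Attention.

    q, k, v: (T, d) arrays; f: (T,) forget-gate values in (0, 1].
    The logit for query i attending to key j gets the bias
    D[i, j] = sum_{l=j+1}^{i} log f[l], computed via cumulative sums.
    """
    T, d = q.shape
    c = np.cumsum(np.log(f))            # c[t] = sum_{l <= t} log f[l]
    D = c[:, None] - c[None, :]         # D[i, j] = sum_{l=j+1}^{i} log f[l]
    logits = q @ k.T / np.sqrt(d) + D
    mask = np.tril(np.ones((T, T), dtype=bool))
    logits = np.where(mask, logits, -np.inf)   # causal masking
    w = np.exp(logits - logits.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)         # row-wise softmax
    return w @ v
```

With all gates equal to 1 the bias vanishes and this reduces to standard causal softmax attention, which also matches the abstract's point that no positional embeddings are required: the bias itself breaks permutation symmetry.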
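The paper's Algorithms 1 and 2 describe hardware-aware Triton kernels adapted from FlashAttention-2. To illustrate why the forget-gate bias composes cleanly with FlashAttention-style tiling, here is a block-wise online-softmax sketch in NumPy; this is our own simplification for exposition, not the authors' kernel, and the block size and names are illustrative:

```python
import numpy as np

def tiled_forgetting_attention(q, k, v, f, block=2):
    """Process key/value tiles with an online softmax, never
    materializing the full T x T attention matrix. The forget-gate
    bias only needs the cumulative log-gate vector c, so each tile's
    bias is c[i] - c[j] for the tile's key indices j."""
    T, d = q.shape
    c = np.cumsum(np.log(f))                 # running sum of log forget gates
    out = np.zeros_like(v, dtype=float)
    m = np.full(T, -np.inf)                  # running row-wise max
    s = np.zeros(T)                          # running softmax denominator
    for j0 in range(0, T, block):
        j1 = min(j0 + block, T)
        logits = q @ k[j0:j1].T / np.sqrt(d)           # (T, tile)
        logits += c[:, None] - c[None, j0:j1]          # forget-gate bias
        causal = np.arange(T)[:, None] >= np.arange(j0, j1)[None, :]
        logits = np.where(causal, logits, -np.inf)
        m_new = np.maximum(m, logits.max(axis=-1))
        scale = np.exp(m - m_new)                      # rescale old partials
        p = np.exp(logits - m_new[:, None])
        out = out * scale[:, None] + p @ v[j0:j1]
        s = s * scale + p.sum(axis=-1)
        m = m_new
    return out / s[:, None]
```

Because the bias decomposes as a difference of per-position cumulative sums, each tile needs only two slices of `c`, which is what makes the method compatible with the FlashAttention algorithm as the abstract states.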
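The quoted training schedule (linear warmup from 0 to the peak learning rate over the first 256 * 2^20 tokens, then cosine decay to 0) can be written as a small helper; `lr_at` and its argument names are our own, not from the paper's code:

```python
import math

def lr_at(tokens, peak_lr, total_tokens, warmup_tokens=256 * 2**20):
    """Learning rate after `tokens` tokens of training: linear warmup
    to peak_lr over warmup_tokens, then cosine decay to 0 at total_tokens."""
    if tokens < warmup_tokens:
        return peak_lr * tokens / warmup_tokens
    progress = (tokens - warmup_tokens) / (total_tokens - warmup_tokens)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

For the main runs described above, `total_tokens` would be the 45 * 2^30-token budget, and with 0.5 * 2^20 tokens per batch the warmup spans 512 optimizer steps.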