FlashMask: Efficient and Rich Mask Extension of FlashAttention

Authors: Guoxia Wang, Jinle Zeng, Xiyuan Xiao, Siming Wu, Jiabin Yang, Lujing Zheng, Zeyu Chen, Jiang Bian, Dianhai Yu, Haifeng Wang

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We evaluate FLASHMASK's performance in fine-tuning and alignment training of LLMs such as SFT, LoRA, DPO, and RM. FLASHMASK achieves significant throughput improvements, with end-to-end speedups ranging from 1.65x to 3.22x compared to the existing FlashAttention dense method. Additionally, our kernel-level comparisons demonstrate that FLASHMASK surpasses the latest counterpart, FlexAttention, by 12.1% to 60.7% in terms of kernel TFLOPs/s.
Researcher Affiliation | Industry | EMAIL
Pseudocode | Yes | Algorithm 1 details the forward computation process of FLASHMASK extended from FlashAttention-2, with blue-shaded parts indicating FLASHMASK computations. ... Algorithm 2 in the Appendix.
Open Source Code | Yes | The code is open-sourced in PaddlePaddle (https://github.com/PaddlePaddle/Paddle) and integrated into PaddleNLP (https://github.com/PaddlePaddle/PaddleNLP), supporting models with over 100 billion parameters for contexts extending up to 128K tokens.
Open Datasets | Yes | SFT and LoRA utilized the same dataset, validated using allenai/tulu-v2-sft-mixture (Ivison et al., 2023). For DPO and RM, which both employ (Question, Answer) data formats, we used the HuggingFaceH4/ultrafeedback_binarized dataset (Tunstall et al., 2023) on Hugging Face for validation.
Dataset Splits | No | The paper describes how synthetic data were constructed and sampled for experiments (e.g., 'For each sequence length N, we collected 240 valid samples and categorized them into 10 bins by sparsity ρ') and names the datasets used for validation (e.g., the HuggingFaceH4/ultrafeedback_binarized dataset, Tunstall et al., 2023), but it gives no explicit percentages, sample counts, or citations to predefined train/validation/test splits, and no splitting methodology detailed enough for reproduction.
Hardware Specification | Yes | All experiments were conducted on machines equipped with NVIDIA A100-SXM 80G GPUs, Intel(R) Xeon(R) Platinum 8350C CPUs... All end-to-end throughput experiments were conducted on four servers, each equipped with eight NVIDIA A800-SXM 80G GPUs, totaling 32 GPUs.
Software Dependencies | Yes | All experiments were conducted on machines equipped with ... CUDA 12.0, and driver version 525.125.06. ... FlexAttention using PyTorch 2.6.0.dev20240920+cu124. ... FlashInfer version 0.1.6, CUDA 12.1, PyTorch 2.4, and BF16 data type.
Experiment Setup | Yes | Detailed information about datasets and hyperparameter settings is provided in Appendix A. ... The hyperparameters and distributed strategies for different scales are detailed in Table 1. ... We consistently applied a linear learning rate decay strategy, with warm-up steps set to 3% of the total training steps. The AdamW optimizer was used with β1 = 0.9 and β2 = 0.999. ... The maximum training sequence length was set to 8K. ... Additional hyperparameters are listed in Table 3.
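The Pseudocode row above refers to Algorithm 1, FLASHMASK's forward pass extended from FlashAttention-2. As a rough illustration of the column-wise mask idea only (a hedged sketch, not the paper's kernel: it materializes dense scores, models just a single masked row range [start, end) per key column, and the names `lts`/`lte` are illustrative simplifications of the paper's column-wise mask vectors):

```python
import math

def masked_attention_reference(Q, K, V, lts, lte):
    """Naive dense reference: for key column j, query rows in [lts[j], lte[j]) are masked.

    Q, K, V are lists of row vectors (lists of floats); lts/lte give per-column
    masked row ranges. A causal mask, for example, is lts[j] = 0, lte[j] = j.
    Assumes every query row has at least one unmasked key column.
    """
    n, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i in range(n):
        scores = []
        for j in range(n):
            if lts[j] <= i < lte[j]:  # column-wise mask hit
                scores.append(float("-inf"))
            else:
                scores.append(scale * sum(Q[i][k] * K[j][k] for k in range(d)))
        m = max(scores)
        exps = [math.exp(s - m) if s != float("-inf") else 0.0 for s in scores]
        z = sum(exps)
        out.append([sum(exps[j] * V[j][k] for j in range(n)) / z for k in range(d)])
    return out
```

The real kernel never materializes the n-by-n score matrix; it uses the per-column ranges to skip fully masked tiles inside the FlashAttention-2 block loop, which is where the reported speedups come from.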
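The Dataset Splits row quotes a binning protocol: 240 valid samples per sequence length, grouped into 10 bins by sparsity ρ. A minimal sketch of such equal-width binning, assuming ρ is the fraction of masked-out attention entries in [0, 1] (the function name is illustrative, not from the paper's code):

```python
import random

def bin_by_sparsity(sparsities, num_bins=10):
    """Assign each sample index to one of num_bins equal-width sparsity bins over [0, 1]."""
    bins = [[] for _ in range(num_bins)]
    for i, rho in enumerate(sparsities):
        # Clamp rho == 1.0 into the last bin instead of overflowing the bin index.
        idx = min(int(rho * num_bins), num_bins - 1)
        bins[idx].append(i)
    return bins

# Example: 240 synthetic sparsity values, as in the quoted protocol.
random.seed(0)
samples = [random.random() for _ in range(240)]
bins = bin_by_sparsity(samples)
assert sum(len(b) for b in bins) == 240
```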
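The Experiment Setup row describes linear learning-rate decay with warm-up over the first 3% of training steps. A minimal sketch of the multiplier such a schedule applies to the base learning rate, assuming linear ramp-up to 1 and then linear decay to 0 (the exact endpoint conventions are our assumption; the function name is illustrative):

```python
def lr_multiplier(step, total_steps, warmup_frac=0.03):
    """Return the factor applied to the base learning rate at 0-indexed `step`."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # linear ramp up to 1.0
    # Linear decay from 1.0 at the end of warm-up to 0.0 at the final step.
    remaining = total_steps - warmup_steps
    return max(0.0, (total_steps - (step + 1)) / remaining)
```

In a PyTorch-style setup this multiplier would typically be handed to a lambda-based scheduler alongside AdamW configured with betas=(0.9, 0.999), matching the quoted hyperparameters.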