DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference

Authors: Jinwei Yao, Kaiqi Chen, Kexun Zhang, Jiaxuan You, Binhang Yuan, Zeke Wang, Tao Lin

ICLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We empirically verify its effectiveness on few-shot prompting, multi-step reasoning, and speculative-decoding tasks. DEFT-Flatten can achieve a decoding latency speedup of 1.3× for few-shot prompting, 2.2× for speculative decoding, and 1.1× for multi-step reasoning, due to up to 3.59× faster attention calculation, compared with the baseline implementations (Dao et al., 2023; Cai et al., 2024; Zheng et al., 2023). We compare different tree-split strategies (DEFT-Node, DEFT-Node-Chunk, and DEFT-Flatten) in ablation studies (see Section 4.4), showing that balanced partitioning of QKV groups matters.
Researcher Affiliation | Academia | Westlake University, Zhejiang University, Carnegie Mellon University, University of Illinois Urbana-Champaign, Hong Kong University of Science and Technology
Pseudocode | Yes | Two-phase DEFT-Node and DEFT-Flatten algorithms in Python style can be found in Appendix A.8 and Appendix A.9, respectively. Algorithm 1: DEFT-Node, Phase 1 (QKV Preparation). Algorithm 2: DEFT-Node, Phase 2 (Attention Calculation). Algorithm 3: DEFT-Flatten, Phase 1 (QKV Preparation). Algorithm 4: DEFT-Flatten, Phase 2 (Attention Calculation).
Open Source Code | Yes | Our code is available at https://github.com/LINs-lab/DeFT.
Open Datasets | Yes | Workloads generation. To ensure fairness for workloads of different baselines, we reconstruct decoding trees from real multi-step reasoning and speculative decoding tasks, as shown in Table 4. For multi-step reasoning, we include four tasks from Besta et al. (2023)... For speculative decoding tasks, we used the token tree topology from Medusa (Cai et al., 2024) and recorded real interaction data with APPS (Hendrycks et al., 2021) as the prompt dataset...
Dataset Splits | No | The paper describes how specific evaluation scenarios (decoding trees) were generated or reconstructed from existing work (e.g., Besta et al. (2023), Medusa (Cai et al., 2024), and APPS (Hendrycks et al., 2021)) and sets stopping criteria for these evaluations (e.g., '400 iterations', '1000 steps'). However, it does not specify traditional dataset splits (e.g., train/validation/test percentages or counts), as the paper focuses on inference acceleration of pre-trained LLMs.
Hardware Specification | Yes | We evaluate the performance of DEFT on an NVIDIA A100 (80GB) with the Llama3-8B model (Touvron et al., 2023b)... Table 19: [Different GPUs] Speedup of DEFT in average attention latency (seconds) with an NVIDIA RTX 4090 (24GB) for the Llama3-8B model (GQA).
Software Dependencies | No | We implement the DEFT attention kernel with OpenAI Triton (Tillet et al., 2019)... Tree Attention-Medusa (Cai et al., 2024)... uses PyTorch's General Matrix Multiply (GEMM)... While these tools are mentioned, specific version numbers for Triton, PyTorch, or other critical libraries are not provided.
Experiment Setup | Yes | In this section, to demonstrate the effectiveness of DEFT under different tree topologies, we comprehensively conduct experiments on three types of tree-based decoding tasks: (1) few-shot prompting... (2) multi-step reasoning... (3) speculative decoding... We evaluate the performance of DEFT on an NVIDIA A100 (80GB) with the Llama3-8B model (Touvron et al., 2023b)... For few-shot prompting tasks, we used a prompt with 4k tokens and performed 400 decoding iterations... For speculative decoding tasks... we reconstruct decoding trees from real multi-step reasoning and speculative decoding tasks, as shown in Table 4.