Optimized Multi-Token Joint Decoding With Auxiliary Model for LLM Inference
Authors: Zongyue Qin, Ziniu Hu, Zifan He, Neha Prakriya, Jason Cong, Yizhou Sun
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Empirical evaluations across various tasks reveal that MTAD improves downstream performance by 25% compared to standard single-token sampling. Furthermore, MTAD achieves a 1.42× speed-up and consumes 1.54× less energy than vanilla speculative decoding methods. These results highlight MTAD's ability to make multi-token joint decoding both effective and efficient, promoting more productive and high-performance deployment of LLMs. |
| Researcher Affiliation | Academia | Department of Computer Science, University of California, Los Angeles, USA; California Institute of Technology, USA. Correspondence to: EMAIL |
| Pseudocode | Yes | Algorithm 1 One Iteration of MTAD Algorithm |
| Open Source Code | Yes | We release our code at https://github.com/ZongyueQin/MTAD |
| Open Datasets | Yes | In the main paper, we report results with three public datasets for evaluation: (1) Spider (Yu et al., 2018), (2) MTBench (Zheng et al., 2023), and (3) HumanEval (Chen et al., 2021). We use Llama-3-8B and Llama-3-8B-Instruct (Dubey et al., 2024) as target models, and Llama-3-1B and Llama-3-1B-Instruct as their draft models, respectively. We provide additional experiments with other datasets and model families in Appendix C. |
| Dataset Splits | No | The paper mentions using several datasets for evaluation and reports performance metrics on them but does not explicitly specify the training, validation, or test dataset splits used for the experiments. It describes experimental settings like generating a maximum of 128 tokens for each input and running for 1,000 seconds, but this pertains to experiment execution rather than dataset partitioning. |
| Hardware Specification | Yes | The experiments are conducted on a machine with 1 Nvidia L40 GPU (48 GB), 4 CPUs, and 50 GB main memory, using a batch size of 1, which is common for online serving (Schuster et al., 2022). |
| Software Dependencies | No | The paper references a public speculative decoding implementation (Bear, 2024) for warping sampling distributions and describes hyperparameter selection. However, it does not explicitly list specific software components like Python, PyTorch, or CUDA with their corresponding version numbers used in the experimental setup. |
| Experiment Setup | Yes | All the methods are stochastic with top-k and top-p sampling with temperature = 1. The details of the hyper-parameters (e.g., k and p) and machine configurations of the experiments are listed in Appendices D, E, and F. For MTAD, we choose the beam width from {4, 8}, the number of draft tokens from {3, 4}, and the acceptance threshold from {0.1, 0.3, 0.5, 0.7, 0.9}. |
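The reported hyper-parameter search for MTAD amounts to a small grid over beam width, number of draft tokens, and acceptance threshold. A minimal sketch of enumerating that grid is below; the dictionary keys are illustrative names, not the authors' actual configuration flags.

```python
from itertools import product

# Grid reported in the paper's experiment setup
# (key names here are assumptions, not MTAD's real config keys).
beam_widths = [4, 8]
num_draft_tokens = [3, 4]
acceptance_thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]

grid = [
    {"beam_width": w, "num_draft_tokens": n, "acceptance_threshold": t}
    for w, n, t in product(beam_widths, num_draft_tokens, acceptance_thresholds)
]
print(len(grid))  # 2 * 2 * 5 = 20 configurations
```

This makes the search cost explicit: 20 configurations per (dataset, model) pair, which is modest enough to tune exhaustively.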