Beyond Next Token Prediction: Patch-Level Training for Large Language Models
Authors: Chenze Shao, Fandong Meng, Jie Zhou
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experiments on a diverse range of models (370M–2.7B parameters) demonstrate that patch-level training can reduce the overall training costs to 0.5×, without compromising the model performance compared to token-level training. |
| Researcher Affiliation | Industry | Chenze Shao, Fandong Meng, Jie Zhou; Pattern Recognition Center, WeChat AI, Tencent Inc, China |
| Pseudocode | Yes | Appendix A (Pseudocode). Model input: `num_patches = seq_length // self.patch_size`; `inputs_embeds = inputs_embeds.view(batch_size, num_patches, self.patch_size, -1).mean(2)`; `position_ids = position_ids[:, :num_patches]`. Model output: `shift_logits = logits[..., :-1, :].reshape(-1, self.config.vocab_size)`; `shift_labels = labels[..., self.patch_size:].reshape(-1, self.patch_size)`; `loss = 0`; `log_probs = F.log_softmax(shift_logits, dim=1)`; `for i in range(self.patch_size): loss = loss + F.nll_loss(log_probs, shift_labels[:, i])`; `loss = loss / self.patch_size` |
| Open Source Code | Yes | Source code: https://github.com/shaochenze/PatchTrain. |
| Open Datasets | Yes | Datasets. We evaluate our approach on standard language modeling tasks, using the Pile dataset (Gao et al., 2020) containing 360B tokens for training. We assess the performance of LLMs from multiple aspects, including their perplexity, zero-shot accuracy, and instruction-following ability. Perplexity is calculated on the WikiText-103 test set (Merity et al., 2017). |
| Dataset Splits | Yes | Specifically, we conduct patch-level training on a fraction λ of the training data, and then use the resulting parameters to initialize the token-level model. Following this, the token-level model continues training on the remaining data to adapt the knowledge gained during patch-level training to the token level. ... Perplexity is calculated on the WikiText-103 test set (Merity et al., 2017). |
| Hardware Specification | Yes | Table 6 gives the actual running speed of patch-level training in comparison with the token-level baseline setting (patch size=1, block size=2048, per device train batch size=4, accumulation steps=8), measured on 8 NVIDIA A100 GPUs. |
| Software Dependencies | No | The paper mentions 'FlashAttention-2 (Dao, 2024)', 'the tokenizer of LLaMA2', and 'the AdamW optimizer (Loshchilov & Hutter, 2019)', but does not specify version numbers for these software components or any other libraries. |
| Experiment Setup | Yes | Unless otherwise specified, the patch size K is 4. The context length for token-level training is 2048. For patch-level training, the context length is K × 2048. The global batch size is 2M tokens, and the total number of training steps is N = 180000. For patch-level training, the number of training steps is Nλ, and then the model proceeds with token-level training for N(1 − λ) steps. ... Our models are optimized by the AdamW optimizer (Loshchilov & Hutter, 2019) with β1 = 0.9, β2 = 0.95, ε = 1e-8. The learning rate is 3e-4 and the cosine learning rate schedule is applied with warmup of 2000 steps. We use a weight decay of 0.1 and gradient clipping of 1.0, and no dropout is applied during training. |
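The pseudocode quoted in the table (patch embeddings as the mean of K consecutive token embeddings, with the loss averaged over the K next tokens predicted from each patch) can be sketched as two standalone functions. This is a minimal NumPy version for illustration, assuming logits of shape (batch, num_patches, vocab); the helper names `patch_inputs` and `patch_loss` are not from the paper.

```python
import numpy as np

def patch_inputs(inputs_embeds, patch_size):
    """Average each run of patch_size consecutive token embeddings into one
    patch embedding. inputs_embeds: (batch, seq_len, hidden); seq_len must be
    divisible by patch_size. Returns (batch, num_patches, hidden)."""
    batch, seq_len, hidden = inputs_embeds.shape
    num_patches = seq_len // patch_size
    return inputs_embeds.reshape(batch, num_patches, patch_size, hidden).mean(axis=2)

def patch_loss(logits, labels, patch_size):
    """Average negative log-likelihood over the patch_size tokens of the next
    patch, predicted from each patch position (mirrors the quoted pseudocode).
    logits: (batch, num_patches, vocab); labels: (batch, seq_len) token ids."""
    # log-softmax over the vocabulary axis
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    # drop the last patch position: it has no next patch to predict
    shift_log_probs = log_probs[:, :-1, :]
    # tokens of each next patch: (batch, num_patches - 1, patch_size)
    batch = labels.shape[0]
    shift_labels = labels[:, patch_size:].reshape(batch, -1, patch_size)
    loss = 0.0
    for i in range(patch_size):
        tgt = shift_labels[:, :, i][..., None]
        loss += -np.take_along_axis(shift_log_probs, tgt, axis=-1).mean()
    return loss / patch_size
```

With patch_size = 1 this reduces to ordinary next-token prediction, which is why the model can be switched back to token-level training without architectural changes.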
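The 0.5× cost figure in the first row is consistent with simple arithmetic: a fraction λ of the data is processed at patch level, where sequences are K times shorter and therefore cost roughly 1/K of token-level compute. A back-of-envelope helper (name illustrative; assumes compute scales linearly with sequence length, and that λ = 2/3 with K = 4 reproduces the reported figure):

```python
def relative_training_cost(lam, patch_size):
    """Total compute relative to pure token-level training, assuming the
    patch-level phase costs 1/patch_size per unit of data (linear scaling)."""
    return lam / patch_size + (1.0 - lam)

# e.g. lam = 2/3, K = 4: 2/3 / 4 + 1/3 = 0.5
```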