Round and Round We Go! What makes Rotary Positional Encodings useful?
Authors: Federico Barbero, Alex Vitvitskyi, Christos Perivolaropoulos, Razvan Pascanu, Petar Veličković
ICLR 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We study the internals of a trained Gemma 7B model to understand how RoPE is being used at a mechanical level. We find that Gemma learns to use RoPE to construct robust positional attention patterns by exploiting the highest frequencies. We also find that, in general, Gemma greatly prefers to use the lowest frequencies of RoPE, which we suspect are used to carry semantic information. We mathematically prove interesting behaviours of RoPE and conduct experiments to verify our findings, proposing a modification of RoPE that fixes some highlighted issues and improves performance. We train Gemma 2B models from scratch on the Wiki and Flan V2 training datasets (see Section D for details). We set the base wavelength θ to 10,000 and train and evaluate using the standard Gemma 8k token context. In Table 2, we show the perplexity on the validation set. We can see how truncating the lowest frequencies not only maintains the same performance, but even seems to improve the validation perplexity, supporting our claims regarding the low frequencies acting as non-robust semantic channels in standard RoPE. |
| Researcher Affiliation | Collaboration | Federico Barbero (University of Oxford); Alex Vitvitskyi, Christos Perivolaropoulos, Razvan Pascanu, Petar Veličković (Google DeepMind) |
| Pseudocode | Yes | `def apply_p_rope( inputs: jax.Array, # [B, L, N, H] positions: jax.Array, # [B, L] head_dim: int, max_wavelength: int = _MAX_WAVELENGTH, rope_percentage: float = 1.0, ) -> jax.Array: """Applies p-RoPE.""" rope_angles = int(rope_percentage * head_dim // 2) nope_angles = head_dim // 2 - rope_angles fraction = 2. * jnp.arange(0, rope_angles) / head_dim timescale = jnp.pad( max_wavelength**fraction, (0, nope_angles), mode='constant', constant_values=(0, jnp.inf) ) sinusoid_inp = ( positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :] ) sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :] sin = jnp.sin(sinusoid_inp) cos = jnp.cos(sinusoid_inp) first_half, second_half = jnp.split(inputs, 2, axis=-1) first_part = first_half * cos - second_half * sin second_part = second_half * cos + first_half * sin out = jnp.concatenate([first_part, second_part], axis=-1) return out.astype(inputs.dtype)` |
| Open Source Code | No | Mechanisms similar to p-RoPE have been discovered independently, although we were unaware of this at the time of writing this work. An example is this GitHub issue, which suggests using partial rotary embeddings: https://github.com/lucidrains/x-transformers/issues/40. Other works have also found partial RoPE to be useful (Black et al., 2022; Liu et al., 2024). |
| Open Datasets | Yes | We train Gemma 2B models from scratch on the Wiki and Flan V2 training datasets (see Section D for details). ... Wiki is a dataset based on English Wikipedia articles, built from the Wikipedia dump. ... The dataset is available at: https://www.tensorflow.org/datasets/catalog/wikipedia. There are 6,672,479 documents and the total size is 19.88 GiB. ... Flan V2 was introduced by Longpre et al. (2023). It is a dataset for instruction tuning which combines collections from FLAN, P3/T0, and Natural Instructions with dialog, program synthesis, and complex reasoning tasks. The dataset contains 15,000,000 samples. |
| Dataset Splits | Yes | Wiki is a dataset based on English Wikipedia articles, built from the Wikipedia dump. Each sample contains the contents of a full Wikipedia article, with processing done to strip markdown and unwanted sections. The dataset is available at: https://www.tensorflow.org/datasets/catalog/wikipedia. There are 6,672,479 documents and the total size is 19.88 GiB. We held out 10% of the data randomly for the validation split. |
| Hardware Specification | No | The paper does not provide specific hardware details (e.g., GPU/CPU models, processor types, memory amounts, or detailed computer specifications) used for running its experiments. It only mentions the models used (Gemma 7B, Gemma 2B) and the datasets, but not the hardware infrastructure. |
| Software Dependencies | No | The paper includes a code snippet written using 'jax' and 'jnp' (JAX NumPy), indicating the use of these libraries. However, it does not specify any version numbers for JAX or other software dependencies. |
| Experiment Setup | Yes | In all experiments we train for 10,000 steps using a batch size of 512 and a sequence length of 8,192. We fix the wavelength to the standard value of 10,000. We use the standard Gemma 2B architecture (Gemma Team et al., 2024). For p-Ro PE, we use the algorithm apply_p_rope described below. |
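For reference, the flattened `apply_p_rope` listing quoted in the table above can be reconstructed into a self-contained, runnable JAX sketch. The 4-D input shape `[B, L, N, H]` (batch, sequence, heads, head dimension) is our assumption based on the broadcasting in the code; the paper's listing only shows the function body. With `rope_percentage=1.0` this reduces to standard RoPE; smaller values truncate the lowest frequencies by assigning them an infinite timescale, i.e. a zero rotation angle.

```python
import jax
import jax.numpy as jnp

_MAX_WAVELENGTH = 10_000


def apply_p_rope(
    inputs: jax.Array,     # [B, L, N, H] (assumed shape)
    positions: jax.Array,  # [B, L]
    head_dim: int,
    max_wavelength: int = _MAX_WAVELENGTH,
    rope_percentage: float = 1.0,
) -> jax.Array:
    """Applies p-RoPE: rotates only a fraction of the frequency pairs."""
    rope_angles = int(rope_percentage * head_dim // 2)
    nope_angles = head_dim // 2 - rope_angles
    # Geometric timescales for the rotated pairs; the truncated
    # (lowest-frequency) pairs get an infinite timescale, so their
    # rotation angle is zero at every position.
    fraction = 2.0 * jnp.arange(0, rope_angles) / head_dim
    timescale = jnp.pad(
        max_wavelength**fraction, (0, nope_angles),
        mode='constant', constant_values=(0, jnp.inf),
    )
    sinusoid_inp = (
        positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
    )
    sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]  # broadcast over heads
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    # Standard rotary update applied pairwise to the two halves.
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)
    return out.astype(inputs.dtype)


# Hypothetical demo shapes, chosen only for illustration.
B, L, N, H = 1, 3, 2, 4
x = jax.random.normal(jax.random.PRNGKey(0), (B, L, N, H))
pos = jnp.arange(L)[None, :]
out = apply_p_rope(x, pos, head_dim=H, rope_percentage=0.5)
```

With `rope_percentage=0.5` and `head_dim=4`, only one of the two frequency pairs is rotated: dimensions 1 and 3 pass through unchanged at every position, and position 0 is never rotated at all, while every rotation preserves the per-vector norm.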