Mixture-of-Transformers: A Sparse and Scalable Architecture for Multi-Modal Foundation Models

Authors: Weixin Liang, Lili Yu, Liang Luo, Srini Iyer, Ning Dong, Chunting Zhou, Gargi Ghosh, Mike Lewis, Wen-tau Yih, Luke Zettlemoyer, Xi Victoria Lin

TMLR 2025

Reproducibility Variable Result LLM Response
Research Type Experimental We evaluate MoT across multiple settings and model scales. In the Chameleon 7B setting (autoregressive text-and-image generation), MoT matches the dense baseline's performance using only 55.8% of the FLOPs. When extended to include speech, MoT reaches speech performance comparable to the dense baseline with only 37.2% of the FLOPs. In the Transfusion setting, where text and image are trained with different objectives, a 7B MoT model matches the image modality performance of the dense baseline with one third of the FLOPs, and a 760M MoT model outperforms a 1.4B dense baseline across key image generation metrics. System profiling further highlights MoT's practical benefits, achieving dense baseline image quality in 47.2% of the wall-clock time and text quality in 75.6% of the wall-clock time (measured on AWS p4de.24xlarge instances with NVIDIA A100 GPUs).
Researcher Affiliation Collaboration Weixin Liang* EMAIL, Department of Computer Science, Stanford University; Lili Yu, Liang Luo, Srinivasan Iyer, Ning Dong, Chunting Zhou, Gargi Ghosh, Mike Lewis, Wen-tau Yih, Luke Zettlemoyer, Xi Victoria Lin EMAIL, Meta AI
Pseudocode Yes Algorithm 1 Mixture-of-Transformers (MoT) Computation
1: Let x = (x_1, ..., x_n) be the input sequence, where x_i ∈ R^d and m_i ∈ {text, image, speech} is the modality of x_i
2: Let M = {text, image, speech} be the set of modalities
3: for each modality m ∈ M do
4:   I_m ← {i : m_i = m}  ▷ Indices of tokens for modality m
5:   X_m ← {x_i : i ∈ I_m}  ▷ Group tokens by modality
6:   Q_m ← W_Q^m X_m, K_m ← W_K^m X_m, V_m ← W_V^m X_m  ▷ Modality-specific projections
7: end for
8: Q ← ∪_{m∈M} Q_m, K ← ∪_{m∈M} K_m, V ← ∪_{m∈M} V_m  ▷ Restore original sequence order
9: A ← softmax(QK^T)V  ▷ Global self-attention
10: for each modality m ∈ M do
11:   O_m ← W_O^m A_{I_m}  ▷ Modality-specific output projection
12:   H_m ← X_m + LayerNorm_m^attn(O_m)  ▷ Residual connection and layer norm
13:   F_m ← FFN_m(H_m)  ▷ Modality-specific feed-forward network
14:   Y_m ← H_m + LayerNorm_m^ffn(F_m)  ▷ Residual connection and layer norm
15: end for
16: return {Y_m : m ∈ M}  ▷ Return transformer layer outputs
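Algorithm 1 can be sketched as follows. This is a minimal NumPy illustration of the MoT routing pattern, not the paper's implementation: layer norms are omitted, attention is single-headed and unscaled by head count, and the parameter names (`Wq`, `Wk`, `Wv`, `Wo`, `W1`, `W2`) are hypothetical. The key point it demonstrates is that Q/K/V/output projections and the FFN are modality-specific, while self-attention is computed globally over the full sequence.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mot_layer(x, modalities, params):
    """One simplified MoT layer (Algorithm 1; layer norms omitted).

    x          : (n, d) token embeddings
    modalities : (n,) array of modality labels, one per token
    params     : dict {modality: {"Wq","Wk","Wv","Wo","W1","W2"}}
    """
    n, d = x.shape
    Q, K, V = np.zeros((n, d)), np.zeros((n, d)), np.zeros((n, d))
    # Modality-specific Q/K/V projections, written back in sequence order.
    for m, p in params.items():
        idx = np.where(modalities == m)[0]
        Q[idx] = x[idx] @ p["Wq"]
        K[idx] = x[idx] @ p["Wk"]
        V[idx] = x[idx] @ p["Wv"]
    # Global self-attention over ALL tokens (the shared computation).
    A = softmax(Q @ K.T / np.sqrt(d)) @ V
    # Modality-specific output projection and FFN, with residuals.
    y = np.zeros_like(x)
    for m, p in params.items():
        idx = np.where(modalities == m)[0]
        h = x[idx] + A[idx] @ p["Wo"]
        y[idx] = h + np.maximum(h @ p["W1"], 0.0) @ p["W2"]
    return y
```

Because tokens are regrouped by modality before each projection, every modality gets its own non-embedding parameters while attention still mixes information across the whole interleaved sequence.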
Open Source Code No The paper does not explicitly state that source code for the methodology described is publicly available, nor does it provide a link to a code repository.
Open Datasets Yes We evaluated the 7B model performance using validation losses on held-out sets of the Obelisc (Laurençon et al., 2023), MS-COCO (Lin et al., 2014), Flickr30k (Plummer et al., 2015), and Shutterstock datasets. More specifically, for MS-COCO and Flickr30k, we take the Karpathy test split of MS-COCO (Lin et al., 2014) and the Karpathy test split of Flickr30k (Plummer et al., 2015), and report text-to-image and image-to-text conditional perplexity using these two datasets. We utilized the training dataset from SpiRit-LM (Nguyen et al., 2024) (Table 2) as our speech dataset. For text, we utilize the Llama 2 tokenizer and corpus (Touvron et al., 2023b), which contains 2 trillion tokens across diverse domains. For text-to-image tasks, we report the diffusion validation loss following SD3 (Esser et al., 2024) on held-out Conceptual 12M (CC12M; Changpinyo et al. (2021)) data.
Dataset Splits Yes More specifically, for MS-COCO and Flickr30k, we take the Karpathy test split of MS-COCO (Lin et al., 2014) and the Karpathy test split of Flickr30k (Plummer et al., 2015), and report text-to-image and image-to-text conditional perplexity using these two datasets.
Hardware Specification Yes System profiling further highlights MoT's practical benefits, achieving dense baseline image quality in 47.2% of the wall-clock time and text quality in 75.6% of the wall-clock time (measured on AWS p4de.24xlarge instances with NVIDIA A100 GPUs). We conducted our experiments and system profiling on AWS, using p4de.24xlarge instances equipped with NVIDIA A100 Tensor Core GPUs.
Software Dependencies No The paper mentions the "PyTorch 2 compiler (Ansel et al., 2024)" but does not provide specific version numbers for PyTorch itself or other key software libraries used in their implementation.
Experiment Setup Yes Model Hyperparameters. We evaluated MoT across multiple model scales ranging from 37M to 7B parameters, comparing it to dense transformer and MoE-4x baselines. All models were pre-trained from scratch with controlled FLOPs for fair comparison. Table 1 details the architectural specifications and training configurations for each model scale. Model architectures were scaled progressively, with hidden dimensions increasing from 256 to 4096, and layer counts from 4 to 32. Attention heads scaled from 8 to 32, while sequence length remained constant at 4096 tokens across all scales. As model size increases, we reduce batch sizes per GPU from 12 to 2, while increasing the number of GPUs from 32 to 384. Training steps were set at 160,000 for smaller models (37M to 443M) and 120,000 for larger models (880M to 7B). Total training tokens ranged from 0.168 to 0.377 trillion, with most configurations processing approximately 0.252 trillion tokens. We randomly initialize all model parameters, and optimize them using AdamW (β1 = 0.9, β2 = 0.95, ε = 1e-8) with a learning rate of 3e-4, warmed up for 4000 steps and decaying to 1.5e-5 using a cosine scheduler. We train on sequences of 4096 tokens in batches of 2M tokens for 250k steps, reaching 0.5T tokens in total. We regularize with weight decay of 0.1 and clip gradients by norm (1.0). We conduct 250 diffusion steps during inference.
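The stated learning-rate schedule (peak 3e-4, 4000 warmup steps, cosine decay to 1.5e-5) can be sketched as a step-to-LR function. This is a minimal reconstruction assuming linear warmup and a cosine horizon equal to the 250k total training steps; the paper's exact scheduler details may differ.

```python
import math

def lr_at_step(step, peak=3e-4, floor=1.5e-5, warmup=4000, total=250_000):
    """Learning rate at a given training step: linear warmup to `peak`,
    then cosine decay to `floor` by step `total` (assumed horizon)."""
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)  # 0 at warmup end, 1 at total
    return floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * progress))
```

For example, `lr_at_step(0)` is 0, the rate reaches 3e-4 exactly at the end of warmup, and decays smoothly to 1.5e-5 at step 250k.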