Diffusion Self-Weighted Guidance for Offline Reinforcement Learning

Authors: Augusto Tagle, Javier Ruiz-del-Solar, Felipe Tobar

TMLR 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Through an experimental proof of concept for SWG, we show that the proposed method i) generates samples from the desired distribution on toy examples, ii) performs competitively against state-of-the-art methods on D4RL when using resampling, and iii) exhibits robustness and scalability via ablation studies.
Researcher Affiliation | Academia | Augusto Tagle EMAIL, Initiative for Data & AI, Universidad de Chile; Javier Ruiz-del-Solar EMAIL, AMTC & Dept. of Electrical Eng., Universidad de Chile; Felipe Tobar EMAIL, Department of Mathematics, Imperial College London
Pseudocode | Yes | Algorithm 1: Training of joint DM over weights and actions; Algorithm 2: Self-Weighted Guidance sampling; Algorithm 3: Expectile-loss critic learning
Open Source Code | Yes | The code used for training and experimentation is available in the SWG repository.
Open Datasets | Yes | Using the D4RL benchmark (Fu et al., 2021), we implemented two variants of the proposed method: SWG as described above, and SWG-R, which includes action resampling on top of our method.
Dataset Splits | No | The paper uses the D4RL benchmark, which ships with predefined datasets, but it does not state explicit training/validation/test splits for its own experiments or declare that standard splits were used. It reports "averaged mean returns over 5 random seeds and 20 independent evaluation trajectories per seed" for D4RL Locomotion tasks (and similarly for Ant Maze), which implies a held-out evaluation protocol but gives no split percentages or counts.
Hardware Specification | Yes | All experiments were conducted on a 12GB NVIDIA RTX 3080 Ti GPU.
Software Dependencies | No | Our implementation is based on the jaxrl repository (Kostrikov, 2021) using the JAX (Bradbury et al., 2018) and Flax (Heek et al., 2024) libraries. We also provide an implementation of our method in PyTorch (Paszke et al., 2019). While these libraries are named, explicit version numbers for JAX, Flax, or PyTorch are not provided.
Experiment Setup | Yes | We trained a diffusion model ϵθ on this extended dataset for K = 15 diffusion steps in all experiments, using the Adam optimizer (Kingma & Ba, 2015) with learning rate 3e-4 for all tasks, except for Adroit Pen where we use 3e-5. We used a batch size of 1024, a variance-preserving noise schedule (Song et al., 2021), and trained for 1M gradient steps in all tasks, except for Ant Maze, where we used 3M gradient steps... The diffusion model consisted of a ResNet of 3 residual blocks, with hidden dim 256 and Layer Normalization (Ba et al., 2016). We used a dropout (Srivastava et al., 2014) of 0.1 in all tasks. For the weight component output, we used the features from the middle layer of the ResNet and passed them through a simple MLP with a hidden dimension of 256... During inference, we tuned the guidance scale for each task. We swept over ρ ∈ {1, 5, 10, 15, 20, 25, 30}.
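Two quantitative pieces of the setup above can be made concrete: the variance-preserving noise schedule (Song et al., 2021) discretized to K = 15 diffusion steps, and the expectile loss named in Algorithm 3 for critic learning. The sketch below is a minimal NumPy illustration under stated assumptions: the function names, the beta range, and the expectile parameter τ are hypothetical choices for demonstration, not values taken from the SWG codebase.

```python
import numpy as np

def vp_schedule(K=15, beta_min=1e-4, beta_max=0.02):
    """Discrete variance-preserving schedule: per-step noise rates (betas)
    and the cumulative fraction of signal remaining (alpha_bar).
    The beta range here is an illustrative assumption."""
    betas = np.linspace(beta_min, beta_max, K)
    alpha_bar = np.cumprod(1.0 - betas)  # monotonically decreasing signal level
    return betas, alpha_bar

def expectile_loss(u, tau=0.7):
    """Asymmetric squared loss |tau - 1(u < 0)| * u^2 used in expectile
    regression; tau > 0.5 penalizes underestimation more than overestimation."""
    weight = np.where(u < 0.0, 1.0 - tau, tau)
    return weight * u ** 2
```

For τ = 0.5 the expectile loss reduces to a symmetric 0.5·u², i.e. half the squared error; larger τ pushes the critic toward an optimistic (upper-expectile) value estimate.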