Importance Corrected Neural JKO Sampling
Authors: Johannes Hertrich, Robert Gruhlke
ICML 2025
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Numerical examples show that our method yields accurate results on various test distributions including high-dimensional multimodal targets and outperforms the state of the art in almost all cases significantly. |
| Researcher Affiliation | Academia | ¹Université Paris Dauphine-PSL, ²University College London, ³FU Berlin. |
| Pseudocode | Yes | Algorithm 1: Sampling and density propagation for importance-based rejection steps; Algorithm 2: Importance corrected neural JKO sampling; Algorithm 3: Training of neural JKO steps; Algorithm 4: Sampling and density propagation for neural JKO steps; Algorithm 5: Parameter selection for importance-based rejection steps; Algorithm 6: Density evaluation of importance corrected neural JKO models. |
| Open Source Code | Yes | The code is available at https://github.com/johertrich/neural_JKO_ic |
| Open Datasets | Yes | We evaluate our method on the following test distributions. Mustache: ... Shifted 8 Modes: ... Shifted 8 Peaky: ... Funnel: ... GMM-d: ... GMM40-50D: Another Gaussian mixture model, which was used as an example in (Blessing et al., 2024; Chen et al., 2025) based on (Midgley et al., 2023). ... LGCP: This is a high dimensional standard example taken from (Arbel et al., 2021; Matthews et al., 2022; Vargas et al., 2023a). It describes a Log-Gaussian Cox process on a 40x40 grid as arising from spatial statistics (Møller et al., 1998). |
| Dataset Splits | No | The paper evaluates sampling performance against ground truth and other methods on various distributions. It specifies sample counts for evaluation (e.g., N = 50000 samples) but does not describe conventional training/test/validation splits for datasets, as the task is sampling from a target distribution rather than supervised learning on pre-split data. |
| Hardware Specification | Yes | The execution times are measured on a single NVIDIA RTX 4090 GPU with 24 GB memory. |
| Software Dependencies | No | For implementing the CNFs, we use the code from FFJORD (Grathwohl et al., 2019) and the torchdiffeq library (Chen, 2018). While these libraries are named, explicit version numbers for the software components are not provided. |
| Experiment Setup | Yes | To build our importance corrected neural JKO model, we first apply n1 ∈ ℕ JKO steps followed by n2 ∈ ℕ blocks, each consisting of a JKO step and three importance-based rejection steps. The velocity fields of the normalizing flows are parameterized by dense three-layer feed-forward neural networks. For the JKO steps, we choose an initial step size τ0 > 0 and then increase the step size exponentially via τk+1 = 4τk. The choices of n1, n2, τ0, and the number of hidden neurons in the networks are given in Table 6, together with the execution times for training and sampling. For instance, for 'Mustache', n1 = 6, n2 = 6, τ0 = 0.05, and 54 hidden neurons are used with a batch size of 5000. |
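The experiment setup above fixes two simple structural choices: an exponentially growing JKO step-size schedule (τk+1 = 4τk) and a block layout of n1 plain JKO steps followed by n2 blocks of one JKO step plus three rejection steps. The following minimal sketch (not the authors' code; function names `step_sizes` and `block_layout` are illustrative) shows both, using the reported 'Mustache' settings (n1 = 6, n2 = 6, τ0 = 0.05) as an example.

```python
# Illustrative sketch of the setup described above; the function names
# and layout encoding are assumptions, not the paper's implementation.

def step_sizes(n_steps, tau0, factor=4.0):
    """JKO step sizes grown exponentially: tau_{k+1} = factor * tau_k."""
    taus = [tau0]
    for _ in range(n_steps - 1):
        taus.append(taus[-1] * factor)
    return taus

def block_layout(n1, n2, rejections_per_block=3):
    """n1 plain JKO steps, then n2 blocks of one JKO step followed by
    rejections_per_block importance-based rejection steps."""
    layout = ["jko"] * n1
    for _ in range(n2):
        layout.append("jko")
        layout.extend(["reject"] * rejections_per_block)
    return layout

# 'Mustache' settings: n1 = 6, n2 = 6 gives 12 JKO steps in total.
taus = step_sizes(12, 0.05)
layout = block_layout(6, 6)
print(taus[:3])                                      # [0.05, 0.2, 0.8]
print(layout.count("jko"), layout.count("reject"))   # 12 18
```

Note how quickly the schedule grows: after a handful of steps the step size is orders of magnitude larger than τ0, which matches the paper's strategy of taking small early steps and coarse late ones.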