UltraTWD: Optimizing Ultrametric Trees for Tree-Wasserstein Distance

Authors: Fangchen Yu, Yanzhen Chen, Jiaxing Wei, Jianfeng Mao, Wenye Li, Qiang Sun

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Section 4. Experiments: We evaluate our methods using four benchmark text datasets: BBCSport, Reuters, Ohsumed, and Recipe (Huang et al., 2016), following previous studies (Takezawa et al., 2021; Chen et al., 2024). These datasets are publicly available. Table 3. Comprehensive comparison of tree-Wasserstein distance methods across four text datasets. Our UltraTWD-IP and -GD methods outperform the baseline methods in most cases across (a) approximation error, (b) document retrieval, (c) document ranking, and (d) document classification.
Researcher Affiliation | Academia | 1) The Chinese University of Hong Kong, Shenzhen; 2) Mohamed bin Zayed University of Artificial Intelligence; 3) Shenzhen Research Institute of Big Data; 4) The Hong Kong University of Science and Technology (Guangzhou); 5) University of Toronto.
Pseudocode | Yes |
Algorithm 1 UltraTWD-MST (Minimum Spanning Tree)
Input: D ∈ R^{n×n}: cost matrix.
Output: T*: optimal ultrametric tree under the ℓ∞ norm.
1: Compute the ℓ∞-nearest ultrametric via Eq. (12): D_T = MST(D) + (1/2)‖MST(D) − D‖_∞ · 1.
2: Construct the ultrametric tree: T* = MST(D_T).
Algorithm 2 UltraTWD-IP (Iterative Projection)
Input: D ∈ R^{n×n}: cost matrix; m: maximum number of iterations (default m = 1).
Output: T*: the ultrametric tree under the Frobenius norm.
1: Initialize D_0 = D.
2: for t = 1 to m do
3:   Update D_t ← (1/(t+1)) D + (t/(t+1)) D_{t−1}  (Movement step)
4:   for each triplet (i, j, k) do
5:     D_t ← P_{Ω_{ijk}}(D_t)  (Projection step)
6:   end for
7: end for
8: Construct the ultrametric tree: T* = MST(D_t).
Algorithm 3 UltraTWD-GD (Gradient Descent)
Input: D ∈ R^{n×n}: cost matrix; m: maximum iterations (default m = 8); α: learning rate (default α = 0.02).
Output: T*: the ultrametric tree under the Frobenius norm.
1: Initialize T_0 = MST(D) with node heights H_T^0.
2: for t = 1 to m do
3:   Update heights: H_T^t ← H_T^{t−1} − α ∇F(H_T^{t−1}).
4:   Compute the ultrametric: D_t = f(H_T^t).
5:   Adjust entries of D_t: d_{ij}^t ← (1/2)(2 d_{ij}^t − d_{ii}^t − d_{jj}^t).
6:   Update tree: T_t = MST(D_t) with heights H_T^t.
7: end for
8: Return the optimized ultrametric tree: T* = T_t.
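A minimal numpy sketch of Algorithm 1's first step, assuming MST(D) denotes the subdominant (single-linkage) ultrametric induced by the minimum spanning tree; the closure below computes it via minimax path distances rather than an explicit tree, and the function names are illustrative, not from the paper's code:

```python
import numpy as np

def subdominant_ultrametric(D):
    """Largest ultrametric below D: minimax path distances over the MST.

    u[i, j] = min over paths i -> j of the maximum edge weight, computed
    with a Floyd-Warshall-style closure (O(n^3), fine for a sketch).
    """
    U = D.astype(float).copy()
    n = U.shape[0]
    for k in range(n):
        U = np.minimum(U, np.maximum(U[:, k][:, None], U[None, k, :]))
    np.fill_diagonal(U, 0.0)
    return U

def linf_nearest_ultrametric(D):
    """Eq. (12): shift the subdominant ultrametric up by half the max gap."""
    U = subdominant_ultrametric(D)
    shift = 0.5 * np.abs(D - U).max()   # (1/2) * ||MST(D) - D||_inf
    DT = U + shift                      # adding a constant off-diagonal
    np.fill_diagonal(DT, 0.0)           # keeps the ultrametric property
    return DT
```

Adding a single constant to all off-diagonal entries preserves the ultrametric inequality, which is why the shifted matrix is still an ultrametric and lies within half the maximum gap of D in the ℓ∞ norm.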
Open Source Code | Yes | Code is available at: https://github.com/NeXAIS/UltraTWD.
Open Datasets | Yes | We evaluate our methods using four benchmark text datasets: BBCSport, Reuters, Ohsumed, and Recipe (Huang et al., 2016), following previous studies (Takezawa et al., 2021; Chen et al., 2024). These datasets are publicly available (https://github.com/mkusner/wmd), with detailed statistics provided in Table 2.
Dataset Splits | Yes | Each dataset contains 5 test sets. For each test set, we construct the vocabulary X = [x_1, …, x_n] ∈ R^{d×n}, where d = 300 is the word embedding dimension and n is the number of unique words (average values shown in Table 2). Each test document µ_i is represented as a normalized bag-of-words distribution. BBCSport: This dataset consists of 737 BBC sports articles labeled by 5 classes. It contains five test sets with 220 articles each, averaging 6,051 words per test set. Reuters: This dataset contains 7,674 news articles across 8 classes. We randomly generate five test sets with 1,000 articles each, averaging 6,416 words per test set. Ohsumed: This dataset contains 9,152 medical abstracts within 10 classes. We randomly create five test sets with 1,000 abstracts each, averaging 9,467 words per test set. Recipe: This dataset consists of 4,370 recipe procedures with 15 classes. It contains five test sets with 1,311 documents each, averaging 4,084 words per test set. For UltraTree, five training sets are randomly generated for each dataset, each containing 1,000 document distributions. Each document distribution is created using numpy.random.randint(1, 11, size=(1, n_word)), with 1% of the entries randomly selected to form a sparse vector, which is then normalized to a probability distribution.
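The training-distribution generation described above can be sketched as follows; the function name, the `rng` argument, and the use of the modern Generator API (instead of the quoted numpy.random.randint) are illustrative assumptions:

```python
import numpy as np

def random_sparse_distribution(n_words, density=0.01, rng=None):
    """Sketch of the synthetic document distributions described above.

    Draw integer counts in [1, 10], keep ~1% of the entries to form a
    sparse vector, then normalize to a probability distribution.
    """
    rng = np.random.default_rng(rng)
    counts = rng.integers(1, 11, size=n_words).astype(float)
    keep = max(1, int(density * n_words))            # ~1% of entries survive
    mask = np.zeros(n_words, dtype=bool)
    mask[rng.choice(n_words, size=keep, replace=False)] = True
    counts[~mask] = 0.0
    return counts / counts.sum()
```

For a 6,000-word vocabulary this yields a distribution supported on roughly 60 words, matching the 1% sparsity stated in the setup.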
Hardware Specification | Yes | All experiments were conducted using Python 3.8 on a Linux server with an AMD EPYC 7742 64-Core Processor, 256 logical CPUs, and 512 GB RAM.
Software Dependencies | No | All experiments were conducted using Python 3.8 on a Linux server. The gradient is efficiently computed using PyTorch's automatic differentiation. The Sinkhorn distance is computed using the ot.sinkhorn2 function from the Python Optimal Transport library (Flamary et al., 2021).
Experiment Setup | Yes | Sinkhorn method: The Sinkhorn distance is computed using the ot.sinkhorn2 function from the Python Optimal Transport library (Flamary et al., 2021), with a regularization parameter of λ = 1 in Eq. (14) and a maximum number of iterations numItermax = 100. For qTWD and cTWD, the regularization parameter in Eq. (15) is set to λ = 0.001, following the configuration in Yamada et al. (2022). For the tree-sliced methods, the number of multiple trees is set to K = 3 in Eq. (16), consistent with Yamada et al. (2022). For UltraTree, five training sets are randomly generated for each dataset, each containing 1,000 document distributions. The official code is used with default settings: a batch size of 32, the Adam optimizer with a learning rate of 0.01, and a maximum of 5 iterations (typically converging within 2). Algorithm 2 UltraTWD-IP (Iterative Projection), Input: D ∈ R^{n×n}: cost matrix; m: maximum number of iterations (default m = 1). Algorithm 3 UltraTWD-GD (Gradient Descent), Input: D ∈ R^{n×n}: cost matrix; m: maximum iterations (default m = 8); α: learning rate (default α = 0.02).
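For reference, a self-contained numpy sketch of the entropy-regularized Sinkhorn iteration that ot.sinkhorn2 implements, using the settings quoted above (reg = 1, 100 iterations); this is an illustrative re-implementation under those assumptions, not the POT library's code:

```python
import numpy as np

def sinkhorn_distance(a, b, M, reg=1.0, num_iter=100):
    """Entropy-regularized OT cost <P, M> via Sinkhorn scaling.

    a, b: source/target probability vectors; M: cost matrix;
    reg: regularization (lambda = 1 in the setup above).
    """
    K = np.exp(-M / reg)                 # Gibbs kernel
    u = np.ones_like(a, dtype=float)
    for _ in range(num_iter):            # alternate marginal projections
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]      # approximate transport plan
    return float((P * M).sum())
```

When all of a's mass sits on one point and all of b's on another, the marginal constraints pin the plan down exactly, so the Sinkhorn cost matches the unregularized transport cost regardless of reg.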