Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.

#### New features
#### New features

- Refactor lazy EMD network simplex storage to avoid dense per-arc cost,
endpoint, flow, and state storage where possible, and return sparse lazy
transport plans instead of materializing dense plans internally (PR #813)
- Add sliced transport plans (min-pivot sliced and expected sliced) solvers (PR #767)
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
- Add Warmstart feature to the EMD solver for existing potentials (PR #793)
Expand All @@ -17,17 +20,12 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Add "BSP-OT: Sparse transport plans between discrete measures in loglinear time" (PR #768)
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
- Add `ot.utils.DataScaler` class for backend-aware joint normalization of input
distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and
support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808)
- Add `ot.utils.DataScaler` class for backend-aware joint normalization of input distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808)
- Add `ot.utils.apply_scaler` helper that dispatches preprocessing to a scaler object,
a callable, or a no-op (PR #808)
- Add optional `scaler` parameter to `sliced_wasserstein_distance` and
`max_sliced_wasserstein_distance` (PR #808)
- Add optional `scaler` parameter to `sliced_wasserstein_distance` and `max_sliced_wasserstein_distance` (PR #808)
- Add a numerically stable log-domain solver for entropic partial Wasserstein, selectable via the new `method` parameter of `entropic_partial_wasserstein` (`method='sinkhorn_log'`) or directly through `entropic_partial_wasserstein_logscale` (Issue #723)
- Add cost functions between linear operators following
[A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920),
implemented in `ot.sgot` (PR #792)
- Add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920), implemented in `ot.sgot` (PR #792)
- Build wheels on ubuntu ARM to avoid QEMU emulation (PR #818)

#### Closed issues
Expand Down
9 changes: 7 additions & 2 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ int EMD_wrap_sparse(
uint64_t *flow_sources_out, // Output: source indices of non-zero flows
uint64_t *flow_targets_out, // Output: target indices of non-zero flows
double *flow_values_out, // Output: flow values
uint64_t *n_flows_out,
uint64_t *n_flows_out,
uint64_t max_flows_out,
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
Expand All @@ -62,7 +63,11 @@ int EMD_wrap_lazy(
double *coords_b, // Target coordinates (n2 x dim)
int dim, // Dimension of coordinates
int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock
double *G, // Output: transport plan (n1 x n2)
uint64_t *flow_sources_out, // Output: source indices of non-zero flows
uint64_t *flow_targets_out, // Output: target indices of non-zero flows
double *flow_values_out, // Output: flow values
uint64_t *n_flows_out,
uint64_t max_flows_out,
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
Expand Down
158 changes: 100 additions & 58 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,62 @@ inline void extract_compressed_support(
}
}

template <
typename NetType,
typename DigraphType,
typename InvalidType,
typename SourceIndexVector,
typename TargetIndexVector,
typename CostAccessor
>
inline bool extract_sparse_solution(
const NetType& net,
DigraphType& di,
InvalidType invalid,
const SourceIndexVector& idx_a,
const TargetIndexVector& idx_b,
double* alpha,
double* beta,
double* cost,
uint64_t* flow_sources_out,
uint64_t* flow_targets_out,
double* flow_values_out,
uint64_t* n_flows_out,
uint64_t max_flows_out,
CostAccessor cost_accessor,
double min_output_flow
) {
const int n = static_cast<int>(idx_a.size());

for (int i = 0; i < n; i++) {
alpha[static_cast<uint64_t>(idx_a[i])] = -net.potential(i);
}
for (int j = 0; j < static_cast<int>(idx_b.size()); j++) {
beta[static_cast<uint64_t>(idx_b[j])] = net.potential(j + n);
}

typename DigraphType::Arc a;
di.first(a);
for (; a != invalid; di.next(a)) {
const int i = di.source(a);
const int j = di.target(a) - n;
const double flow = net.flow(a);
if (flow != 0) {
*cost += flow * cost_accessor(a, i, j);
}
if (flow > min_output_flow) {
if (*n_flows_out >= max_flows_out) {
return false;
}
flow_sources_out[*n_flows_out] = static_cast<uint64_t>(idx_a[i]);
flow_targets_out[*n_flows_out] = static_cast<uint64_t>(idx_b[j]);
flow_values_out[*n_flows_out] = flow;
++(*n_flows_out);
}
}
return true;
}

} // namespace


Expand Down Expand Up @@ -186,9 +242,9 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
const SetupPolicy policy = make_setup_policy(n, m, n1, n2, true);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(
di, policy.use_arc_mixing, (int) (n + m), n * m, maxIter
);
typedef NetworkSimplexSimple<Digraph, double, double, node_id_type> Simplex;
Simplex::SimplexOptions simplex_options(policy.use_arc_mixing);
Simplex net(di, simplex_options, (int) (n + m), n * m, maxIter);

// Set supply and demand, don't account for 0 values (faster)

Expand Down Expand Up @@ -341,6 +397,7 @@ int EMD_wrap_sparse(
uint64_t *flow_targets_out,
double *flow_values_out,
uint64_t *n_flows_out,
uint64_t max_flows_out,
double *alpha,
double *beta,
double *cost,
Expand Down Expand Up @@ -432,9 +489,9 @@ int EMD_wrap_sparse(

di.buildFromEdges(edges);

NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
di, true, (int)(n + m), di.arcNum(), maxIter
);
typedef NetworkSimplexSimple<Digraph, double, double, node_id_type> Simplex;
Simplex::SimplexOptions simplex_options(true);
Simplex net(di, simplex_options, (int)(n + m), di.arcNum(), maxIter);

net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m);

Expand Down Expand Up @@ -463,41 +520,27 @@ int EMD_wrap_sparse(
int ret = net.run();
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
*cost = 0;
*n_flows_out = 0;
*n_flows_out = 0;

Arc a;
di.first(a);
for (; a != INVALID; di.next(a)) {
uint64_t i = di.source(a);
uint64_t j = di.target(a);
double flow = net.flow(a);

uint64_t orig_i = indI[i];
uint64_t orig_j = indJ[j - n];


double arc_cost = arc_costs[a];

*cost += flow * arc_cost;


*(alpha + orig_i) = -net.potential(i);
*(beta + orig_j) = net.potential(j);

if (flow > 1e-15) {
flow_sources_out[*n_flows_out] = orig_i;
flow_targets_out[*n_flows_out] = orig_j;
flow_values_out[*n_flows_out] = flow;
(*n_flows_out)++;
}
auto sparse_cost = [&arc_costs](Arc a, int, int) {
return arc_costs[a];
};
if (!extract_sparse_solution(
net, di, INVALID, indI, indJ, alpha, beta, cost,
flow_sources_out, flow_targets_out, flow_values_out,
n_flows_out, max_flows_out, sparse_cost, 1e-15)) {
return (int)net.MAX_ITER_REACHED;
}
}
return ret;
}

int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b,
int dim, int metric, double *G, double *alpha, double *beta,
double *cost, uint64_t maxIter, double *alpha_init, double *beta_init) {
int dim, int metric, uint64_t *flow_sources_out,
uint64_t *flow_targets_out, double *flow_values_out,
uint64_t *n_flows_out, uint64_t max_flows_out,
double *alpha, double *beta, double *cost, uint64_t maxIter,
double *alpha_init, double *beta_init) {
using namespace lemon;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
Expand Down Expand Up @@ -552,8 +595,19 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double
// Create full bipartite graph
Digraph di(n, m);

NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter
typedef NetworkSimplexSimple<Digraph, double, double, node_id_type> Simplex;
Simplex::SimplexOptions simplex_options(false);
// Lazy mode does not store costs or endpoints for the real complete
// bipartite arcs. Artificial root arcs are still explicit because the
// simplex initialization assigns them costs 0 or ART_COST.
simplex_options.cost_storage_mode = Simplex::CostStorageMode::ArtificialArcCosts;
simplex_options.flow_storage_mode = Simplex::FlowStorageMode::SparseArcFlows;
simplex_options.endpoint_storage_mode =
Simplex::EndpointStorageMode::ArcEndpoints;
simplex_options.state_storage_mode = Simplex::StateStorageMode::PackedArcStates;

Simplex net(
di, simplex_options, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter
);

// Set supplies
Expand Down Expand Up @@ -583,32 +637,20 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double

if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
*cost = 0;
*n_flows_out = 0;

// Initialize output arrays
for (int i = 0; i < n1 * n2; i++) G[i] = 0.0;
for (int i = 0; i < n1; i++) alpha[i] = 0.0;
for (int i = 0; i < n2; i++) beta[i] = 0.0;

// Extract solution
Arc a;
di.first(a);
for (; a != INVALID; di.next(a)) {
int i = di.source(a);
int j = di.target(a) - n;

int orig_i = idx_a[i];
int orig_j = idx_b[j];

double flow = net.flow(a);
G[orig_i * n2 + orig_j] = flow;

alpha[orig_i] = -net.potential(i);
beta[orig_j] = net.potential(j + n);

if (flow > 0) {
double c = net.computeLazyCost(i, j);
*cost += flow * c;
}

auto lazy_cost = [&net](Arc, int i, int j) {
return net.computeLazyCost(i, j);
};
if (!extract_sparse_solution(
net, di, INVALID, idx_a, idx_b, alpha, beta, cost,
flow_sources_out, flow_targets_out, flow_values_out,
n_flows_out, max_flows_out, lazy_cost, 0.0)) {
return (int)net.MAX_ITER_REACHED;
}
}

Expand Down
14 changes: 11 additions & 3 deletions ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def emd2_lazy(
alpha_init_np = np.asarray(alpha_init_np, dtype=np.float64, order="C")
beta_init_np = np.asarray(beta_init_np, dtype=np.float64, order="C")

G, cost, u, v, result_code = emd_c_lazy(
flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_lazy(
a_np, b_np, X_a_np, X_b_np, metric, numItermax, alpha_init_np, beta_init_np
)

Expand All @@ -1053,8 +1053,6 @@ def emd2_lazy(
stacklevel=2,
)

G_backend = nx.from_numpy(G, type_as=type_as)

cost_backend = nx.set_gradients(
nx.from_numpy(cost, type_as=type_as),
(a0, b0),
Expand All @@ -1075,6 +1073,16 @@ def emd2_lazy(
"result_code": result_code,
}
if return_matrix:
flow_values_backend = nx.from_numpy(flow_values, type_as=type_as)
flow_sources_backend = nx.from_numpy(flow_sources.astype(np.int64))
flow_targets_backend = nx.from_numpy(flow_targets.astype(np.int64))
G_backend = nx.coo_matrix(
flow_values_backend,
flow_sources_backend,
flow_targets_backend,
shape=(n1, n2),
type_as=type_as,
)
log_dict["G"] = G_backend
return cost_backend, log_dict
else:
Expand Down
21 changes: 15 additions & 6 deletions ot/lp/emd_wrap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, uint64_t max_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, uint64_t max_flows_out, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED


Expand Down Expand Up @@ -306,7 +306,7 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
n_edges,
<uint64_t*> edge_sources.data, <uint64_t*> edge_targets.data, <double*> edge_costs.data,
<uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data, <double*> flow_values.data,
&n_flows_out,
&n_flows_out, n_edges,
<double*> alpha.data, <double*> beta.data, &cost, max_iter,
alpha_init_ptr, beta_init_ptr
)
Expand All @@ -329,6 +329,8 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
cdef int result_code = 0
cdef double cost = 0
cdef int metric_code
cdef uint64_t n_flows_out = 0
cdef uint64_t max_flows_out = n1 + n2

# Validate dimension consistency
if coords_b.shape[1] != dim:
Expand All @@ -345,9 +347,11 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
except KeyError:
raise ValueError(f"Unknown metric: '{metric}'. Supported metrics are: {list(metric_map.keys())}")

cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_sources = np.zeros(max_flows_out, dtype=np.uint64)
cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_targets = np.zeros(max_flows_out, dtype=np.uint64)
cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(max_flows_out, dtype=np.float64)
cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2])
if not len(a):
a = np.ones((n1,)) / n1
if not len(b):
Expand All @@ -360,5 +364,10 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
beta_init_ptr = <double*> beta_init.data

with nogil:
result_code = EMD_wrap_lazy(n1, n2, <double*> a.data, <double*> b.data, <double*> coords_a.data, <double*> coords_b.data, dim, metric_code, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
return G, cost, alpha, beta, result_code
result_code = EMD_wrap_lazy(n1, n2, <double*> a.data, <double*> b.data, <double*> coords_a.data, <double*> coords_b.data, dim, metric_code, <uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data, <double*> flow_values.data, &n_flows_out, max_flows_out, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)

flow_sources = flow_sources[:n_flows_out]
flow_targets = flow_targets[:n_flows_out]
flow_values = flow_values[:n_flows_out]

return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code
Loading
Loading