Skip to content

Commit 2070126

Browse files
authored
Merge pull request #365 from MiraGeoscience/GEOPY-2781
GEOPY-2781: Inversion stalls on tiling for large problems during redistribution of clusters
2 parents 7033425 + a9b8088 commit 2070126

3 files changed

Lines changed: 31 additions & 47 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ requires = [
88
"jinja2 ==3.*",
99
"packaging >=24.0",
1010
"tomlkit >=0.13",
11+
"setuptools>=80",
12+
"setuptools-scm[simple]>=9.2.*"
1113
]
1214
build-backend = "poetry_dynamic_versioning.backend"
1315

simpeg_drivers/utils/nested.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,13 @@
1313
import warnings
1414
from collections.abc import Iterable
1515
from copy import copy
16-
from itertools import chain
1716
from pathlib import Path
1817

1918
import numpy as np
20-
from dask import compute, delayed
21-
from dask.distributed import get_client
2219
from discretize import TensorMesh, TreeMesh
2320
from geoh5py.shared.utils import uuid_from_values
24-
from scipy.optimize import linear_sum_assignment
2521
from scipy.spatial import cKDTree
26-
from scipy.spatial.distance import cdist
2722
from simpeg import data, data_misfit, maps, meta, objective_function
28-
from simpeg.dask.objective_function import DistributedComboMisfits
2923
from simpeg.data_misfit import L2DataMisfit
3024
from simpeg.electromagnetics.base_1d import BaseEM1DSimulation
3125
from simpeg.electromagnetics.frequency_domain.simulation import BaseFDEMSimulation
@@ -539,21 +533,9 @@ def tile_locations(
539533
from sklearn.cluster import KMeans
540534

541535
kmeans = KMeans(n_clusters=n_tiles, random_state=0, n_init="auto")
542-
cluster_size = int(np.ceil(grid_locs.shape[0] / n_tiles))
543536
kmeans.fit(grid_locs)
544537

545-
if labels is not None:
546-
cluster_id = kmeans.labels_
547-
else:
548-
# Redistribute cluster centers to even out the number of points
549-
centers = kmeans.cluster_centers_
550-
centers = (
551-
centers.reshape(-1, 1, grid_locs.shape[1])
552-
.repeat(cluster_size, 1)
553-
.reshape(-1, grid_locs.shape[1])
554-
)
555-
distance_matrix = cdist(grid_locs, centers)
556-
cluster_id = linear_sum_assignment(distance_matrix)[1] // cluster_size
538+
cluster_id = kmeans.labels_
557539

558540
tiles = []
559541
for tid in set(cluster_id):

tests/locations_test.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -115,34 +115,34 @@ def test_filter(tmp_path: Path):
115115
assert np.all(filtered_data["key"] == [2, 3, 4])
116116

117117

118-
def test_tile_locations(tmp_path: Path):
119-
with Workspace.create(tmp_path / f"{__name__}.geoh5") as ws:
120-
grid_x, grid_y = np.meshgrid(np.arange(100), np.arange(100))
121-
choices = np.c_[grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size)]
122-
inds = np.random.randint(0, 10000, 1000)
123-
pts = Points.create(
124-
ws,
125-
name="test-points",
126-
vertices=choices[inds],
127-
)
128-
tiles = tile_locations(pts.vertices[:, :2], n_tiles=8)
129-
130-
values = np.zeros(pts.n_vertices)
131-
pop = []
132-
for ind, tile in enumerate(tiles):
133-
values[tile] = ind
134-
pop.append(len(tile))
135-
136-
pts.add_data(
137-
{
138-
"values": {
139-
"values": values,
140-
}
141-
}
142-
)
143-
assert np.std(pop) / np.mean(pop) < 0.02, (
144-
"Population of tiles are not almost equal."
145-
)
118+
# TODO Find a scalable algo better than linear_sum_assignment to do even split
119+
# The tiling strategy should yield even "densities" (area x n_receivers)
120+
# def test_tile_locations(tmp_path: Path):
121+
# with Workspace.create(tmp_path / f"{__name__}.geoh5") as ws:
122+
# grid_x, grid_y = np.meshgrid(np.arange(100), np.arange(100))
123+
# choices = np.c_[grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size)]
124+
# inds = np.random.randint(0, 10000, 1000)
125+
# pts = Points.create(
126+
# ws,
127+
# name="test-points",
128+
# vertices=choices[inds],
129+
130+
131+
def test_tile_locations():
132+
n_points = 1000
133+
rng = np.random.default_rng(0)
134+
locations = rng.standard_normal((n_points, 2))
135+
136+
tiles = tile_locations(locations, n_tiles=8)
137+
138+
# All indices should be covered exactly once across tiles
139+
all_indices = np.concatenate(tiles)
140+
assert np.array_equal(np.sort(all_indices), np.arange(n_points))
141+
142+
# Tiles should be reasonably balanced in population
143+
pop = np.array([len(tile) for tile in tiles])
144+
assert pop.min() > 0
145+
assert np.std(pop) / np.mean(pop) < 0.5
146146

147147

148148
def test_tile_locations_labels(tmp_path: Path):

0 commit comments

Comments (0)