Skip to content

Commit 28b2391

Browse files
committed
Use scipy distance_transform_edt to replace slow loop
Use scipy distance_transform_edt to replace slow loop, speeding up get_dist_to_edge_and_gl by about five orders of magnitude.
1 parent da9e861 commit 28b2391

1 file changed

Lines changed: 52 additions & 62 deletions

File tree

compass/landice/mesh.py

Lines changed: 52 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mpas_tools.mesh.creation.sort_mesh import sort_mesh
1717
from netCDF4 import Dataset
1818
from scipy.interpolate import NearestNDInterpolator, interpn
19+
from scipy.ndimage import distance_transform_edt
1920

2021

2122
def mpas_flood_fill(seed_mask, grow_mask, cellsOnCell, nEdgesOnCell,
@@ -435,8 +436,8 @@ def set_cell_width(self, section_name, thk, bed=None, vx=None, vy=None,
435436
return cell_width
436437

437438

438-
def get_dist_to_edge_and_gl(self, thk, topg, x, y,
439-
section_name, window_size=None):
439+
def get_dist_to_edge_and_gl(self, thk, topg, x, y, section_name,
440+
window_size=None):
440441
"""
441442
Calculate distance from each point to ice edge and grounding line,
442443
to be used in mesh density functions in
@@ -486,8 +487,10 @@ def get_dist_to_edge_and_gl(self, thk, topg, x, y,
486487
dist_to_grounding_line : numpy.ndarray
487488
Distance from each cell to the grounding line
488489
"""
490+
489491
logger = self.logger
490492
section = self.config[section_name]
493+
491494
tic = time.time()
492495

493496
high_dist = float(section.get('high_dist'))
@@ -496,78 +499,65 @@ def get_dist_to_edge_and_gl(self, thk, topg, x, y,
496499
if window_size is None:
497500
window_size = max(high_dist, high_dist_bed)
498501
elif window_size < min(high_dist, high_dist_bed):
499-
logger.info('WARNING: window_size was set to a value smaller'
500-
' than high_dist and/or high_dist_bed. Resetting'
501-
f' window_size to {max(high_dist, high_dist_bed)},'
502-
' which is max(high_dist, high_dist_bed)')
502+
logger.info(
503+
'WARNING: window_size was set smaller than high_dist and/or '
504+
'high_dist_bed. Resetting window_size to '
505+
f'{max(high_dist, high_dist_bed)}'
506+
)
503507
window_size = max(high_dist, high_dist_bed)
504508

505-
dx = x[1] - x[0] # assumed constant and equal in x and y
506-
nx = len(x)
507-
ny = len(y)
508-
sz = thk.shape
509+
dx = float(x[1] - x[0])
510+
dy = float(y[1] - y[0])
509511

510-
# Create masks to define ice edge and grounding line
511-
neighbors = np.array([[1, 0], [-1, 0], [0, 1], [0, -1],
512-
[1, 1], [-1, 1], [1, -1], [-1, -1]])
512+
# Same masks as the current implementation
513+
neighbors = np.array([
514+
[1, 0], [-1, 0], [0, 1], [0, -1],
515+
[1, 1], [-1, 1], [1, -1], [-1, -1]
516+
])
513517

514518
ice_mask = thk > 0.0
515519
grounded_mask = thk > (-1028.0 / 910.0 * topg)
516-
margin_mask = np.zeros(sz, dtype='i')
517-
grounding_line_mask = np.zeros(sz, dtype='i')
520+
521+
margin_mask = np.zeros(thk.shape, dtype=bool)
522+
grounding_line_mask = np.zeros(thk.shape, dtype=bool)
518523

519524
for n in neighbors:
520525
not_ice_mask = np.logical_not(np.roll(ice_mask, n, axis=[0, 1]))
521-
margin_mask = np.logical_or(margin_mask, not_ice_mask)
522-
523-
not_grounded_mask = np.logical_not(np.roll(grounded_mask,
524-
n, axis=[0, 1]))
525-
grounding_line_mask = np.logical_or(grounding_line_mask,
526-
not_grounded_mask)
527-
528-
# where ice exists and neighbors non-ice locations
529-
margin_mask = np.logical_and(margin_mask, ice_mask)
530-
# optional - plot mask
531-
# plt.pcolor(margin_mask); plt.show()
532-
533-
# Calculate dist to margin and grounding line
534-
[XPOS, YPOS] = np.meshgrid(x, y)
535-
dist_to_edge = np.zeros(sz)
536-
dist_to_grounding_line = np.zeros(sz)
537-
538-
d = int(np.ceil(window_size / dx))
539-
rng = np.arange(-1 * d, d, dtype='i')
540-
max_dist = float(d) * dx
526+
margin_mask |= not_ice_mask
541527

542-
# just look over areas with ice
543-
# ind = np.where(np.ravel(thk, order='F') > 0)[0]
544-
ind = np.where(np.ravel(thk, order='F') >= 0)[0] # do it everywhere
545-
for iii in range(len(ind)):
546-
[i, j] = np.unravel_index(ind[iii], sz, order='F')
547-
548-
irng = i + rng
549-
jrng = j + rng
550-
551-
# only keep indices in the grid
552-
irng = irng[np.nonzero(np.logical_and(irng >= 0, irng < ny))]
553-
jrng = jrng[np.nonzero(np.logical_and(jrng >= 0, jrng < nx))]
554-
555-
dist_to_here = ((XPOS[np.ix_(irng, jrng)] - x[j]) ** 2 +
556-
(YPOS[np.ix_(irng, jrng)] - y[i]) ** 2) ** 0.5
557-
558-
dist_to_here_edge = dist_to_here.copy()
559-
dist_to_here_grounding_line = dist_to_here.copy()
560-
561-
dist_to_here_edge[margin_mask[np.ix_(irng, jrng)] == 0] = max_dist
562-
dist_to_here_grounding_line[grounding_line_mask
563-
[np.ix_(irng, jrng)] == 0] = max_dist
564-
565-
dist_to_edge[i, j] = dist_to_here_edge.min()
566-
dist_to_grounding_line[i, j] = dist_to_here_grounding_line.min()
528+
not_grounded_mask = np.logical_not(
529+
np.roll(grounded_mask, n, axis=[0, 1])
530+
)
531+
grounding_line_mask |= not_grounded_mask
532+
533+
# Preserve current semantics for margin
534+
margin_mask &= ice_mask
535+
536+
# NOTE:
537+
# The current code does *not* apply "& ice_mask" to grounding_line_mask.
538+
# If that was intentional, keep it. If it was accidental, add:
539+
# grounding_line_mask &= ice_mask
540+
541+
# EDT computes distance to nearest zero.
542+
# So invert the boundary masks: non-boundary=1, boundary=0.
543+
dist_to_edge = distance_transform_edt(
544+
~margin_mask, sampling=(dy, dx)
545+
)
546+
dist_to_grounding_line = distance_transform_edt(
547+
~grounding_line_mask, sampling=(dy, dx)
548+
)
549+
550+
# Preserve the current "window_size" behavior by clipping large distances.
551+
# The old code returned max_dist for anything outside the search window.
552+
max_dist = float(np.ceil(window_size / max(dx, dy))) * max(dx, dy)
553+
dist_to_edge = np.minimum(dist_to_edge, max_dist)
554+
dist_to_grounding_line = np.minimum(dist_to_grounding_line, max_dist)
567555

568556
toc = time.time()
569-
logger.info('compass.landice.mesh.get_dist_to_edge_and_gl() took {:0.2f} '
570-
'seconds'.format(toc - tic))
557+
logger.info(
558+
'compass.landice.mesh.get_dist_to_edge_and_gl() took '
559+
f'{toc - tic:0.2f} seconds'
560+
)
571561

572562
return dist_to_edge, dist_to_grounding_line
573563

0 commit comments

Comments
 (0)