Skip to content

Commit 64ba8fe

Browse files
committed
enable subpixel shifts
1 parent 8929344 commit 64ba8fe

2 files changed

Lines changed: 71 additions & 72 deletions

File tree

src/batsim/_gsinterface.cpp

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ py::array_t<double> convolvePsf(
7171
const int downsample_ratio,
7272
const int ngrid
7373
){
74-
7574
bool test = is_c_contiguous(gal_prof);
7675
if (! test) {
7776
throw std::runtime_error(
@@ -130,46 +129,37 @@ py::array_t<double> convolvePsf(
130129
fftw_destroy_plan(p_backward);
131130
fftw_free(out2);
132131

133-
// Wrap the result in a numpy array
134-
// prevent memory leakage
132+
// Normalize once; FFTW inverse is unnormalized
133+
const double inv_norm = 1.0 / static_cast<double>(dim2) / static_cast<double>(dim2);
134+
135135
auto result = py::array_t<double>({ngrid, ngrid});
136-
// Use unchecked for faster access
137136
auto r = result.mutable_unchecked<2>();
138137

139-
int dim2_center = dim2 / 2;
140-
int res_center = ngrid / 2;
141-
142-
// Normalize the inverse FFT result
143-
const int norm_factor = dim2 * dim2;
144-
if (dim2_center >= res_center) {
145-
// shrinking
146-
for (int y = 0; y < ngrid; ++y) {
147-
int yf = y - res_center + dim2_center;
148-
for (int x = 0; x < ngrid; ++x) {
149-
// Calculate the corresponding source index in ifft_out
150-
int xf = x - res_center + dim2_center;
151-
// Copy and normalize
152-
r(y, x) = ifft_out[yf * dim2 + xf] / norm_factor;
138+
// Define source and destination rectangles centered
139+
const int src_w = dim2, src_h = dim2;
140+
const int dst_w = ngrid, dst_h = ngrid;
141+
142+
// Centers with explicit floor for clarity
143+
const int src_cx = src_w / 2; // floor
144+
const int src_cy = src_h / 2;
145+
const int dst_cx = dst_w / 2;
146+
const int dst_cy = dst_h / 2;
147+
148+
// Compute the overlap box in destination coordinates
149+
// We want to place the src centered into dst.
150+
for (int dy = 0; dy < dst_h; ++dy) {
151+
int sy = dy - dst_cy + src_cy;
152+
bool in_y = (0 <= sy && sy < src_h);
153+
for (int dx = 0; dx < dst_w; ++dx) {
154+
int sx = dx - dst_cx + src_cx;
155+
bool in_x = (0 <= sx && sx < src_w);
156+
if (in_x && in_y) {
157+
r(dy, dx) = ifft_out[sy * src_w + sx] * inv_norm;
158+
} else {
159+
r(dy, dx) = 0.0; // pad outside
153160
}
154161
}
155162
}
156-
else {
157-
// padding zeros (dim2 < dim_res)
158-
std::fill(result.mutable_data(), result.mutable_data() + ngrid * ngrid, 0.0);
159-
int start = res_center - dim2_center;
160-
int end = res_center + dim2_center;
161-
for (int y = start; y < end; ++y) {
162-
int yf = y - res_center + dim2_center;
163-
for (int x = start; x < end; ++x) {
164-
// Calculate the corresponding source index in ifft_out
165-
int xf = x - res_center + dim2_center;
166-
// Copy and normalize
167-
r(y, x) = ifft_out[yf * dim2 + xf] / norm_factor;
168-
}
169-
}
170-
171-
}
172-
173163
// Cleanup fftw
174164
fftw_free(ifft_out);
175165
return result;

src/batsim/sim.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def simulate_galaxy(
1616
truncate_ratio=1.0,
1717
maximum_num_grids=4096,
1818
draw_method="auto",
19-
force_ngrid=False
19+
force_ngrid=False,
20+
delta_image_x=0.0,
21+
delta_image_y=0.0,
2022
):
2123
"""The function samples the surface density field of a galaxy at the grids
2224
This function only conduct sampling; PSF and pixel response are not
@@ -32,14 +34,19 @@ def simulate_galaxy(
3234
truncate at truncate_ratio times good_image_size
3335
maximum_num_grids (int):
3436
maximum number of grids for simulation in real space
35-
draw_method (str): method to draw the galaxy image, "auto" will convolve with
36-
pixel response, "no_pixel" is as it implies
37-
force_ngrid (bool): If True, force the number of grids to be ngrid even if a smaller
38-
number of grids is sufficient for the simulation
37+
draw_method (str): method to draw the galaxy image, "auto" will convolve
38+
with pixel response, "no_pixel" is as it implies
39+
force_ngrid (bool): If True, force the number of grids to be ngrid even if
40+
a smaller number of grids is sufficient for the
41+
simulation
3942
Returns:
4043
outcome (ndarray): 2D galaxy image on the grids
4144
"""
4245

46+
gobj = gal_obj.shift(
47+
delta_image_x * pix_scale,
48+
delta_image_y * pix_scale,
49+
)
4350
# Initialize variables based on PSF presence
4451
if psf_obj is None and draw_method == "no_pixel":
4552
# In this case we just get the fluxes for the requested stamp size
@@ -52,25 +59,24 @@ def simulate_galaxy(
5259
pad_arcsec = 0.0
5360
downsample_ratio = 1
5461
else:
55-
scale = min(gal_obj.nyquist_scale, psf_obj.nyquist_scale / 4.0, pix_scale / 4.0)
56-
pad_arcsec = psf_obj.calculateMomentRadius(size=32, scale=pix_scale / 2.0)
62+
scale = min(gobj.nyquist_scale, pix_scale / 4.0)
63+
pad_arcsec = psf_obj.calculateMomentRadius(
64+
size=32, scale=pix_scale / 2.0)
5765
downsample_ratio = min(int(2 ** np.ceil(np.log2(pix_scale / scale))), 128)
58-
66+
5967
scale = pix_scale / downsample_ratio
6068

6169
# Calculate the number of grids considering padding and truncation
6270
npad = int(pad_arcsec / scale + 0.5) * 4
63-
nn = npad * 2 + min(gal_obj.getGoodImageSize(pixel_scale=scale)
64-
* truncate_ratio, ngrid * downsample_ratio
65-
)
71+
nn = npad * 2 + min(
72+
gobj.getGoodImageSize(scale)
73+
* truncate_ratio, ngrid * downsample_ratio
74+
)
6675
nn = min(int(2 ** np.ceil(np.log2(nn))), maximum_num_grids)
6776

68-
print(nn)
6977
if force_ngrid and nn < ngrid:
7078
nn = ngrid
7179
scale = pix_scale
72-
73-
print(nn, scale)
7480
# Initialize and Distort Coordinates
7581
stamp = Stamp(nn=nn, scale=scale)
7682
if transform_obj is not None:
@@ -81,32 +87,35 @@ def simulate_galaxy(
8187
# Sample the galaxy flux
8288
gal_prof = _gsinterface.getFluxVec(
8389
scale=scale,
84-
gsobj=gal_obj._sbp,
90+
gsobj=gobj._sbp,
8591
xy_coords=gal_coords
86-
)
87-
92+
)
8893
# No convolution necessary in this case so just return the fluxes
89-
if draw_method == "no_pixel" and psf_obj is None:
90-
return gal_prof
91-
92-
# Construct pixel response
93-
pixel_response = galsim.Pixel(scale=pix_scale)
94-
if psf_obj is None:
95-
psf_obj = pixel_response
94+
if draw_method == "no_pixel":
95+
if psf_obj is None:
96+
return gal_prof
97+
else:
98+
pass
99+
elif draw_method == "auto":
100+
# Construct pixel response
101+
pixel_response = galsim.Pixel(scale=pix_scale)
102+
if psf_obj is None:
103+
psf_obj = pixel_response
104+
else:
105+
psf_obj = galsim.Convolve([psf_obj, pixel_response])
96106
else:
97-
psf_obj = galsim.Convolve([psf_obj, pixel_response])
98-
107+
raise ValueError("do not support draw_method=%s" %draw_method)
99108
# Convolution in Fourier space
100109
gal_prof = _gsinterface.convolvePsf(
101110
scale=scale,
102111
gsobj=psf_obj._sbp,
103112
gal_prof=gal_prof,
104113
downsample_ratio=downsample_ratio,
105114
ngrid=ngrid
106-
)
107-
115+
)
108116
return gal_prof
109117

118+
110119
def simulate_galaxy_batch(
111120
ngrid,
112121
pix_scale,
@@ -119,7 +128,7 @@ def simulate_galaxy_batch(
119128
nproc=4,
120129
force_ngrid=False
121130
):
122-
131+
123132
"""
124133
The function samples the surface density field of a galaxy at the grids
125134
@@ -143,21 +152,21 @@ def simulate_galaxy_batch(
143152
mp.set_start_method('spawn', force=True)
144153

145154
with mp.Pool(nproc) as p:
146-
155+
147156
args_list = [
148157
(
149158
ngrid,
150-
pix_scale,
151-
gal_obj,
152-
transform_obj,
153-
psf_obj,
154-
truncate_ratio,
155-
maximum_num_grids,
159+
pix_scale,
160+
gal_obj,
161+
transform_obj,
162+
psf_obj,
163+
truncate_ratio,
164+
maximum_num_grids,
156165
draw_method,
157166
force_ngrid
158167
) for gal_obj in gal_obj_list
159168
]
160-
169+
161170
outcome = p.starmap(simulate_galaxy, args_list)
162171

163172
if original_omp_num_threads is None:

0 commit comments

Comments
 (0)