Skip to content

Commit 2db1d6d

Browse files
committed
Added PSF cache, optimised IA transform, fixed multi-threading.
1 parent 15c721d commit 2db1d6d

12 files changed

Lines changed: 1083 additions & 78 deletions

.codex

Whitespace-only changes.

benchmarks/benchmarks.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Benchmarks to check computation time and memory usage for batsim."""
22
import batsim.stamp as batstamp
33
import batsim.transforms as batforms
4+
import batsim
5+
import contextlib
46
import galsim
7+
import io
8+
import numpy as np
59
import time
610

711
def time_shear_speed(nn=64, scale=0.2):
@@ -74,7 +78,131 @@ def time_ia_speed(nn=128, scale=0.1):
7478
aff_time = aff_end - aff_start
7579

7680
return {'IA time' : ia_time, 'Lens time' : aff_time}
77-
81+
82+
83+
def _parse_simulate_profile_logs(log_lines):
84+
stats = {}
85+
timings = {}
86+
for line in log_lines:
87+
msg = line.split("] ", 1)[-1]
88+
if msg.startswith("stats "):
89+
for token in msg[6:].split():
90+
if "=" not in token:
91+
continue
92+
key, value = token.split("=", 1)
93+
try:
94+
stats[key] = int(value)
95+
except ValueError:
96+
try:
97+
stats[key] = float(value)
98+
except ValueError:
99+
stats[key] = value
100+
elif "=" in msg:
101+
key, value = msg.split("=", 1)
102+
value = value[:-1] if value.endswith("s") else value
103+
try:
104+
timings[key] = float(value)
105+
except ValueError:
106+
timings[key] = value
107+
return {"timings": timings, "stats": stats}
108+
109+
110+
def _extract_parametric_profile_info(cosmos_catalog, catalog_index, gal_obj):
111+
info = {
112+
"catalog_index": int(catalog_index),
113+
"gsobject_type": type(gal_obj).__name__,
114+
}
115+
for attr in ("flux", "nyquist_scale"):
116+
if hasattr(gal_obj, attr):
117+
try:
118+
info[attr] = float(getattr(gal_obj, attr))
119+
except Exception:
120+
pass
121+
param_cat = getattr(cosmos_catalog, "param_cat", None)
122+
if param_cat is None:
123+
return info
124+
keys = []
125+
if hasattr(param_cat, "colnames"):
126+
keys = ["mag_auto", "flux_radius", "zphot"] + [k for k in ("use_bulgefit", "viable_sersic") if k in param_cat.colnames]
127+
elif hasattr(param_cat, "dtype") and param_cat.dtype.names:
128+
keys = [k for k in ("mag_auto", "flux_radius", "zphot", "use_bulgefit", "viable_sersic") if k in param_cat.dtype.names]
129+
if not keys:
130+
return info
131+
row = param_cat[int(catalog_index)]
132+
for key in keys:
133+
try:
134+
value = row[key]
135+
if hasattr(value, "item"):
136+
value = value.item()
137+
info[key] = value
138+
except Exception:
139+
pass
140+
return info
141+
142+
143+
def benchmark_parametric_cosmos_profiles(
144+
n_galaxies=5,
145+
ngrid=128,
146+
pix_scale=0.2,
147+
psf_obj=None,
148+
draw_method="auto",
149+
truncate_ratio=1.0,
150+
maximum_num_grids=4096,
151+
force_ngrid=False,
152+
seed=1234,
153+
cosmos_catalog=None,
154+
):
155+
"""Run a lightweight per-galaxy benchmark using parametric COSMOS profiles.
156+
157+
Returns a list of dictionaries containing profile metadata, parsed
158+
`simulate_galaxy(profile=True)` logs, and end-to-end elapsed time.
159+
"""
160+
cosmos_catalog = cosmos_catalog or galsim.COSMOSCatalog()
161+
rng = np.random.RandomState(seed)
162+
indices = rng.choice(len(cosmos_catalog), size=n_galaxies, replace=(n_galaxies > len(cosmos_catalog)))
163+
164+
records = []
165+
for i, idx in enumerate(indices):
166+
gal = cosmos_catalog.makeGalaxy(index=int(idx), gal_type="parametric")
167+
profile_info = _extract_parametric_profile_info(cosmos_catalog, idx, gal)
168+
169+
log_buf = io.StringIO()
170+
t0 = time.perf_counter()
171+
with contextlib.redirect_stdout(log_buf):
172+
image = batsim.simulate_galaxy(
173+
ngrid=ngrid,
174+
pix_scale=pix_scale,
175+
gal_obj=gal,
176+
psf_obj=psf_obj,
177+
truncate_ratio=truncate_ratio,
178+
maximum_num_grids=maximum_num_grids,
179+
draw_method=draw_method,
180+
force_ngrid=force_ngrid,
181+
profile=True,
182+
)
183+
elapsed_s = time.perf_counter() - t0
184+
185+
profile_logs = [line for line in log_buf.getvalue().splitlines() if line.startswith("[simulate_galaxy]")]
186+
parsed_logs = _parse_simulate_profile_logs(profile_logs)
187+
record = {
188+
"galaxy_number": i,
189+
"profile": profile_info,
190+
"logger": parsed_logs,
191+
"elapsed_s": elapsed_s,
192+
"image_shape": tuple(image.shape),
193+
"image_sum": float(np.sum(image)),
194+
}
195+
records.append(record)
196+
197+
print(
198+
f"[benchmark_parametric_cosmos_profiles] i={i} idx={int(idx)} "
199+
f"nn={parsed_logs['stats'].get('nn')} downsample_ratio={parsed_logs['stats'].get('downsample_ratio')} "
200+
f"elapsed_s={elapsed_s:.4e}"
201+
)
202+
print(f"[benchmark_parametric_cosmos_profiles] profile={profile_info}")
203+
for line in profile_logs:
204+
print(line)
205+
return records
78206

79207
if __name__ == "__main__":
80208
time_shear_speed()

notebooks/dev/optimise_c.ipynb

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/dev/test_render_size.ipynb

Lines changed: 367 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/ia-analysis/test_models.ipynb

Lines changed: 51 additions & 11 deletions
Large diffs are not rendered by default.

notebooks/ia-analysis/test_non-affine_response.ipynb

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
5+
"execution_count": 2,
66
"id": "3295c8e5",
77
"metadata": {},
88
"outputs": [],
@@ -89,6 +89,7 @@
8989
" phi = np.radians(0),\n",
9090
" clip_radius=5 # clip the transform at 5*hlr to prevent edge effects\n",
9191
" )\n",
92+
"\n",
9293
" gal_img = batsim.simulate_galaxy(\n",
9394
" ngrid=nn,\n",
9495
" pix_scale=scale,\n",
@@ -100,16 +101,15 @@
100101
"\n",
101102
" # Apply lensing shear directly\n",
102103
" gal_img = galsim.Image(gal_img, scale=scale)\n",
103-
" gal = galsim.InterpolatedImage(gal_img, scale=scale)\n",
104104
" else:\n",
105105
" # convolve, shift, and draw the galaxy\n",
106106
" gal = gal.shift(0.5*scale, 0.5*scale) # shift the galaxy to center\n",
107107
"\n",
108-
" # Apply lensing shear\n",
109-
" gal = gal.shear(g1=lens_shear, g2=0.0)\n",
108+
" # Apply lensing shear\n",
109+
" gal = gal.shear(g1=lens_shear, g2=0.0)\n",
110110
"\n",
111-
" # Convolve after both IA and lensing\n",
112-
" gal = galsim.Convolve([gal, psf])\n",
111+
" # Convolve after both IA and lensing\n",
112+
" gal = galsim.Convolve([gal, psf])\n",
113113
"\n",
114114
" # Set the subimage in the stamp\n",
115115
" gal_img = gal.drawImage(nx=nn, ny=nn, scale=scale).array\n",
@@ -119,14 +119,6 @@
119119
"\n",
120120
" return stamp"
121121
]
122-
},
123-
{
124-
"cell_type": "code",
125-
"execution_count": null,
126-
"id": "d33c0296",
127-
"metadata": {},
128-
"outputs": [],
129-
"source": []
130122
}
131123
],
132124
"metadata": {
@@ -145,7 +137,7 @@
145137
"name": "python",
146138
"nbconvert_exporter": "python",
147139
"pygments_lexer": "ipython3",
148-
"version": "3.10.19"
140+
"version": "3.10.20"
149141
}
150142
},
151143
"nbformat": 4,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def uniq(xs):
6868
include_dirs=[pybind11.get_include()],
6969
libraries=["galsim"],
7070
language="c++",
71-
extra_compile_args=["-std=c++11", "-fopenmp", "-O3"],
71+
extra_compile_args=["-std=c++17", "-fopenmp", "-O3"],
7272
extra_link_args=["-flto", "-fopenmp"],
7373
)
7474

226 KB
Binary file not shown.
242 KB
Binary file not shown.

src/batsim/_gsinterface.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,33 @@
77
#include <fftw3.h>
88
#include <cmath>
99

10-
1110
namespace py = pybind11;
1211

1312
py::array_t<double> getFluxVec(
1413
const double scale,
1514
const galsim::SBProfile& gsobj,
1615
const py::array_t<double>& xy_coords
1716
){
18-
if (xy_coords.ndim() != 2 || xy_coords.shape(0) != 2) {
19-
throw std::runtime_error("xy_coords must be a 2D array with shape (2, n)");
20-
}
21-
2217
auto xy = xy_coords.unchecked<2>();
2318
const int n_points = xy_coords.shape(1);
24-
std::vector<double> fluxes(n_points);
25-
19+
const int dim = std::sqrt(n_points);
20+
const int n_used = dim * dim;
21+
auto result = py::array_t<double>({dim, dim});
22+
auto out = result.mutable_data();
2623
double area = scale * scale;
27-
#pragma omp parallel for
28-
for(int i = 0; i < n_points; ++i) {
29-
fluxes[i] = gsobj.xValue(
24+
25+
// Pre-warm GalSim's internal cache with a single serial call
26+
// before entering the parallel region
27+
gsobj.xValue(galsim::Position<double>(xy(0, 0), xy(1, 0)));
28+
29+
#pragma omp parallel for schedule(static)
30+
for(int i = 0; i < n_used; ++i) {
31+
out[i] = gsobj.xValue(
3032
galsim::Position<double>(xy(0, i), xy(1, i))
3133
) * area;
3234
}
3335

34-
int dim = std::sqrt(n_points);
35-
return py::array_t<double>({dim, dim}, fluxes.data());
36+
return result;
3637
}
3738

3839
// Utility function to generate rfftfreq
@@ -85,8 +86,8 @@ py::array_t<double> convolvePsf(
8586
int dim2 = dim / downsample_ratio;
8687

8788
// Frequency grids for the down sampled signal
88-
const auto x_freqs2 = rfftfreq(dim2, scale2 / M_PI / 2.0);
89-
const auto y_freqs2 = fftfreq(dim2, scale2 / M_PI / 2.0);
89+
const auto x_freqs2 = rfftfreq(dim2, scale2);
90+
const auto y_freqs2 = fftfreq(dim2, scale2);
9091

9192
// Allocate FFTW arrays with pointers
9293
double* in = static_cast<double*>(info.ptr);
@@ -97,6 +98,10 @@ py::array_t<double> convolvePsf(
9798
fftw_plan p_forward = fftw_plan_dft_r2c_2d(dim, dim, in, out, FFTW_ESTIMATE);
9899
fftw_execute(p_forward);
99100

101+
// Pre-warm GalSim's internal cache with a single serial call
102+
// before entering the parallel region
103+
gsobj.kValue(galsim::Position<double>(x_freqs2[0], y_freqs2[0]));
104+
100105
// Process FFT result using gsobj
101106
#pragma omp parallel for
102107
for (int y2 = 0; y2 < dim2; ++y2) {
@@ -108,8 +113,8 @@ py::array_t<double> convolvePsf(
108113
std::complex<double> fft_val(out[index][0], out[index][1]);
109114
std::complex<double> result = fft_val * gsobj.kValue(
110115
galsim::Position<double>(
111-
x_freqs2[x2],
112-
y_freqs2[y2]
116+
2.0 * M_PI * x_freqs2[x2],
117+
2.0 * M_PI * y_freqs2[y2]
113118
)
114119
);
115120
out2[index2][0] = result.real();

0 commit comments

Comments
 (0)