Commit 45ba6ae

MaxGhenis and claude authored
Fix 12 bugs from audit (security + correctness + tests) (#458)

* Gate gateway write and job-status endpoints behind auth
  Fixes #446

* Guard version registry latest pointer against downgrade
  Fixes #447

* Persist budget-window failure state across polls
  Fixes #448

* Back off scheduler poll interval exponentially
  Fixes #449

* Forbid unknown gateway request fields and cap payload size
  Fixes #450

* Accumulate budget-window totals in Decimal
  Fixes #451

* Make state_tax_revenue_impact optional for UK results
  Fixes #452

* Redact exception details from gateway error responses
  Fixes #453

* Guard budget-window child state transitions from missing entries
  Fixes #454

* Redact Logfire span payloads for simulation functions
  Fixes #455

* Clean up GCP credential temp file after use
  Fixes #456

* Rename scheduler test and scaffold real-Modal integration bucket
  Fixes #457

* Cache JWTDecoder at module scope to preserve JWKS LRU

  The gateway auth dependency was constructing a new JWTDecoder per request,
  which defeated PyJWKClient's internal JWKS cache and triggered a live JWKS
  fetch on every gated call. Hoist decoder construction behind a
  functools.lru_cache keyed on (issuer, audience) so one decoder is reused
  across the lifetime of the process while still honoring env rotation and
  app.dependency_overrides.

  New coverage: same-env returns same instance, rotated audience returns
  fresh instance, and a smoke test asserts JWTDecoder.__init__ runs at most
  once across five requests.

  Also adds a positive-path test verifying that
  app.dependency_overrides[require_auth] keeps gated endpoints reachable
  (guards the override path other tests rely on).

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Refuse to start when GATEWAY_AUTH_DISABLED leaks into production

  Previously GATEWAY_AUTH_DISABLED was documented as dev-only but had no
  runtime enforcement, so a single stray env var in a deploy config could
  silently disable bearer-token auth in production.

  Add a startup guard (enforce_production_auth_guard) invoked from the ASGI
  app factory that refuses to boot when the bypass is set but the Modal
  environment looks like production (missing, main, prod, production), and
  otherwise requires an explicit
  GATEWAY_AUTH_DISABLED_ACK=I_UNDERSTAND_THIS_IS_DEV acknowledgement. Even on
  the happy path we emit a CRITICAL log plus logfire.error so the bypass is
  unmistakable in audit trails.

  New coverage: guard no-ops when auth is enabled, refuses on missing
  MODAL_ENVIRONMENT, refuses on every variant of prod MODAL_ENVIRONMENT
  (main/prod/production/PROD), refuses in dev without ACK, refuses in dev
  with wrong ACK, and allows+logs the banner in dev with correct ACK.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Cover the 256 KB payload-cap boundary on both sides

  #450 added MAX_GATEWAY_REQUEST_BYTES=262_144 and an oversized-payload
  rejection test but the existing coverage didn't pin the exact cap; a
  future refactor could regress the limit to, say, 1 MB without any test
  failing. Add two boundary tests: one constructs a payload encoded
  just-under the cap and asserts SimulationRequest accepts it, the other
  tops the cap by a handful of bytes and asserts the same ``too large``
  ValidationError fires. The assertions log the actual encoded size so
  diagnosing a drift on either side is straightforward.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
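
The decoder caching described in the commit message can be illustrated with a minimal, stdlib-only sketch. The `JWTDecoder` class here is a hypothetical stand-in that only counts constructions, and the `AUTH_ISSUER` / `AUTH_AUDIENCE` environment variable names are assumptions; the real gateway code is not shown on this page.

```python
import os
from functools import lru_cache


class JWTDecoder:
    """Hypothetical stand-in for the real decoder; it only counts constructions."""

    instances_created = 0

    def __init__(self, issuer: str, audience: str) -> None:
        JWTDecoder.instances_created += 1
        self.issuer = issuer
        self.audience = audience


@lru_cache(maxsize=8)
def _decoder_for(issuer: str, audience: str) -> JWTDecoder:
    # One decoder per (issuer, audience) pair for the life of the process,
    # so any per-decoder key cache survives across requests.
    return JWTDecoder(issuer, audience)


def get_decoder() -> JWTDecoder:
    # The environment is read on every call, so a rotated value becomes a new
    # cache key. The variable names are assumptions made for this sketch.
    return _decoder_for(
        os.environ.get("AUTH_ISSUER", ""),
        os.environ.get("AUTH_AUDIENCE", ""),
    )


os.environ["AUTH_ISSUER"] = "https://issuer.example"
os.environ["AUTH_AUDIENCE"] = "api-a"
first = get_decoder()
second = get_decoder()   # same environment: the cached instance is reused

os.environ["AUTH_AUDIENCE"] = "api-b"
rotated = get_decoder()  # rotated audience: a fresh instance is built
```

Keying the cache on the values read from the environment, rather than caching the decoder unconditionally, is what lets rotation take effect without a restart.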
1 parent 5a8e95b commit 45ba6ae
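
The startup guard from the commit message might look roughly like the sketch below. Only the function name, the environment variable names, the production-like values and the acknowledgement string come from the commit message; the set of spellings treated as "bypass requested", the `env` parameter, and the error type are assumptions, and the real guard additionally calls `logfire.error`.

```python
import logging
from collections.abc import Mapping

logger = logging.getLogger("gateway.auth")

# A missing MODAL_ENVIRONMENT is treated as production, hence the empty string.
PRODUCTION_ENVIRONMENTS = {"", "main", "prod", "production"}
REQUIRED_ACK = "I_UNDERSTAND_THIS_IS_DEV"


def enforce_production_auth_guard(env: Mapping[str, str]) -> None:
    """Raise RuntimeError if the auth bypass is set where it must not be."""
    # Assumption: these spellings count as "bypass requested".
    bypass = env.get("GATEWAY_AUTH_DISABLED", "").strip().lower()
    if bypass not in {"1", "true", "yes"}:
        return  # Auth is enabled; nothing to guard.

    modal_env = env.get("MODAL_ENVIRONMENT", "").strip().lower()
    if modal_env in PRODUCTION_ENVIRONMENTS:
        raise RuntimeError(
            "GATEWAY_AUTH_DISABLED is set in a production-like environment"
        )

    if env.get("GATEWAY_AUTH_DISABLED_ACK") != REQUIRED_ACK:
        raise RuntimeError(
            "GATEWAY_AUTH_DISABLED requires an explicit acknowledgement"
        )

    # Happy path for dev: still make the bypass loud in the logs.
    logger.critical("Gateway auth is DISABLED in environment %r", modal_env)
```

Passing the environment in as a mapping keeps the sketch easy to test; a caller would typically hand it `os.environ` at application start.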

27 files changed

Lines changed: 1741 additions & 86 deletions
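
The payload-cap boundary tests mentioned in the commit message can be approximated as follows. This is a sketch under stated assumptions: the real check lives in `SimulationRequest` validation and raises a `ValidationError`, whereas `check_payload_size`, `encoded_size` and `payload_of_size` are hypothetical helpers, and measuring the UTF-8 length of the JSON encoding is a guess at how size is computed.

```python
import json

MAX_GATEWAY_REQUEST_BYTES = 262_144  # 256 KB, the value quoted for #450


def encoded_size(payload: dict) -> int:
    # Assumption: size is the UTF-8 length of the JSON encoding.
    return len(json.dumps(payload).encode("utf-8"))


def check_payload_size(payload: dict) -> None:
    size = encoded_size(payload)
    if size > MAX_GATEWAY_REQUEST_BYTES:
        raise ValueError(f"payload too large: {size} bytes")


def payload_of_size(target_bytes: int) -> dict:
    # Pad a single string field so the encoded payload hits the target exactly.
    overhead = encoded_size({"pad": ""})
    return {"pad": "x" * (target_bytes - overhead)}


under = payload_of_size(MAX_GATEWAY_REQUEST_BYTES - 1)  # just under the cap
over = payload_of_size(MAX_GATEWAY_REQUEST_BYTES + 8)   # a few bytes over

check_payload_size(under)  # accepted
try:
    check_payload_size(over)
except ValueError as error:
    print(error)
```

Building both payloads from the cap constant itself is what pins the limit: if the constant changed, the "just under" payload would no longer sit at the boundary.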

projects/policyengine-api-simulation/fixtures/gateway/shared.py

Lines changed: 11 additions & 2 deletions
@@ -4,17 +4,26 @@
 from fastapi import FastAPI
 from fastapi.testclient import TestClient

+from src.modal.gateway.auth import require_auth
 from src.modal.gateway.endpoints import router


-def create_gateway_app() -> FastAPI:
-    """Create a FastAPI app with the gateway router for testing."""
+def create_gateway_app(*, authenticate: bool = True) -> FastAPI:
+    """Create a FastAPI app with the gateway router for testing.
+
+    By default the auth dependency is overridden with a no-op callable so
+    individual tests don't need to stage JWT material. Tests that exercise
+    the auth failure path can pass ``authenticate=False`` to keep the real
+    dependency wired up.
+    """
     app = FastAPI(
         title="Test PolicyEngine Simulation API",
         description="Test instance for unit tests",
         version="0.0.1",
     )
     app.include_router(router)
+    if authenticate:
+        app.dependency_overrides[require_auth] = lambda: None
     return app


projects/policyengine-api-simulation/pyproject.toml

Lines changed: 7 additions & 0 deletions
@@ -41,3 +41,10 @@ pythonpath = [
 ]

 testpaths = ["tests"]
+
+# Skip real-Modal integration smoke tests unless the operator explicitly
+# opts in with ``-m integration``. The default test run stays hermetic.
+addopts = "-m 'not integration'"
+markers = [
+    "integration: runs against a real (ephemeral) Modal deployment",
+]

projects/policyengine-api-simulation/src/modal/app.py

Lines changed: 14 additions & 10 deletions
@@ -11,6 +11,7 @@
 import os

 from src.modal._image_setup import snapshot_models
+from src.modal.logging_redaction import redact_params_for_logging

 # Get versions from environment or use defaults
 US_VERSION = os.environ.get("POLICYENGINE_US_VERSION", "1.562.3")
@@ -94,14 +95,18 @@ def run_simulation(params: dict) -> dict:

     configure_logfire()

+    # We deliberately avoid sending full ``params`` or ``result`` blobs to
+    # Logfire: both can embed signed URLs, reform parameter trees with
+    # sensitive policy details, or result payloads large enough to blow the
+    # span attribute size budget. The redacted summary keeps correlation
+    # traceability via run_id while leaving the heavy payload in memory.
+    redacted_params = redact_params_for_logging(params)
     try:
         with logfire.span(
             "run_simulation",
-            input_params=params,
-        ) as span:
-            result = run_simulation_impl(params)
-            span.set_attribute("output_result", result)
-            return result
+            **redacted_params,
+        ):
+            return run_simulation_impl(params)
     finally:
         logfire.force_flush()

@@ -123,13 +128,12 @@ def run_budget_window_batch(params: dict) -> dict:

     configure_logfire()

+    redacted_params = redact_params_for_logging(params)
     try:
         with logfire.span(
             "run_budget_window_batch",
-            input_params=params,
-        ) as span:
-            result = run_budget_window_batch_impl(params)
-            span.set_attribute("output_result", result)
-            return result
+            **redacted_params,
+        ):
+            return run_budget_window_batch_impl(params)
     finally:
         logfire.force_flush()
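
The body of `redact_params_for_logging` is not part of this excerpt. A minimal sketch of what such a helper could do, assuming an allow-list of scalar fields plus a size summary, is shown below; the `SAFE_KEYS` tuple and the `param_keys` / `param_bytes` attribute names are hypothetical.

```python
import json
from typing import Any

# Assumption: only these fields are safe to attach to a span verbatim.
SAFE_KEYS = ("run_id", "country", "scope")


def redact_params_for_logging(params: dict[str, Any]) -> dict[str, Any]:
    """Return a flat, small summary suitable for use as span attributes."""
    summary: dict[str, Any] = {
        key: params[key]
        for key in SAFE_KEYS
        if isinstance(params.get(key), (str, int, float, bool))
    }
    # Record the shape of the payload without any of its values.
    summary["param_keys"] = sorted(params)
    summary["param_bytes"] = len(json.dumps(params, default=str))
    return summary


example = redact_params_for_logging(
    {
        "run_id": "abc-123",
        "country": "us",
        "reform": {"some.parameter": 1},
        "data": "https://example.invalid/file?sig=secret",
    }
)
```

Because the result is a flat dict with string keys, it can be splatted into a span call as keyword arguments, which matches how the diff uses `**redacted_params`.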

projects/policyengine-api-simulation/src/modal/budget_window_results.py

Lines changed: 51 additions & 14 deletions
@@ -2,6 +2,7 @@

 from __future__ import annotations

+from decimal import Decimal
 from typing import Any

 from src.modal.gateway.models import (
@@ -10,12 +11,25 @@
     BudgetWindowTotals,
 )

+# The UK microsimulation has no state/province fiscal layer, so worker child
+# results for ``country="uk"`` never emit ``state_tax_revenue_impact``. The
+# parent aggregator treats it as optional with a zero default; US results are
+# expected to supply it as a real number. All other keys remain mandatory.
 REQUIRED_BUDGET_KEYS = (
     "tax_revenue_impact",
-    "state_tax_revenue_impact",
     "benefit_spending_impact",
     "budgetary_impact",
 )
+OPTIONAL_BUDGET_KEYS = ("state_tax_revenue_impact",)
+
+
+def _as_decimal(value: float | int) -> Decimal:
+    """Convert an annual impact float to Decimal without reintroducing
+    binary-float quantisation noise. ``Decimal(str(...))`` is the canonical
+    idiom because it serialises the float to its shortest round-trippable
+    decimal form before parsing."""
+
+    return Decimal(str(value))


 def extract_annual_impact(
@@ -38,8 +52,13 @@ def extract_annual_impact(
             f"Malformed budget-window child result: missing numeric {missing}"
         )

-    state_tax_revenue_impact = budget["state_tax_revenue_impact"]
     tax_revenue_impact = budget["tax_revenue_impact"]
+    # UK worker results omit the state fiscal layer entirely; coerce to 0.0
+    # so the parent aggregator can still report federal/state splits with a
+    # uniform shape across countries.
+    state_tax_revenue_impact = budget.get("state_tax_revenue_impact")
+    if not isinstance(state_tax_revenue_impact, int | float):
+        state_tax_revenue_impact = 0.0

     return BudgetWindowAnnualImpact(
         year=simulation_year,
@@ -54,22 +73,40 @@ def extract_annual_impact(
 def sum_annual_impacts(
     annual_impacts: list[BudgetWindowAnnualImpact],
 ) -> BudgetWindowTotals:
-    totals = {
-        "taxRevenueImpact": 0,
-        "federalTaxRevenueImpact": 0,
-        "stateTaxRevenueImpact": 0,
-        "benefitSpendingImpact": 0,
-        "budgetaryImpact": 0,
+    """Sum per-year impacts using Decimal accumulators.
+
+    Binary-float addition accumulates rounding error for long budget windows
+    (10-year sums over billion-dollar baselines quickly drift by ``1e-6`` or
+    more). Accumulating in :class:`decimal.Decimal` keeps the answer exact
+    to the input precision; we cast back to ``float`` at the serialisation
+    boundary so the JSON schema stays numeric and clients that parse the
+    response as ``number`` continue to work unchanged. Clients that need
+    bit-exact accounting should request the individual per-year impacts and
+    sum them in their preferred numeric type.
+    """
+
+    totals: dict[str, Decimal] = {
+        "taxRevenueImpact": Decimal(0),
+        "federalTaxRevenueImpact": Decimal(0),
+        "stateTaxRevenueImpact": Decimal(0),
+        "benefitSpendingImpact": Decimal(0),
+        "budgetaryImpact": Decimal(0),
     }

     for annual_impact in annual_impacts:
-        totals["taxRevenueImpact"] += annual_impact.taxRevenueImpact
-        totals["federalTaxRevenueImpact"] += annual_impact.federalTaxRevenueImpact
-        totals["stateTaxRevenueImpact"] += annual_impact.stateTaxRevenueImpact
-        totals["benefitSpendingImpact"] += annual_impact.benefitSpendingImpact
-        totals["budgetaryImpact"] += annual_impact.budgetaryImpact
+        totals["taxRevenueImpact"] += _as_decimal(annual_impact.taxRevenueImpact)
+        totals["federalTaxRevenueImpact"] += _as_decimal(
+            annual_impact.federalTaxRevenueImpact
+        )
+        totals["stateTaxRevenueImpact"] += _as_decimal(
+            annual_impact.stateTaxRevenueImpact
+        )
+        totals["benefitSpendingImpact"] += _as_decimal(
+            annual_impact.benefitSpendingImpact
+        )
+        totals["budgetaryImpact"] += _as_decimal(annual_impact.budgetaryImpact)

-    return BudgetWindowTotals(**totals)
+    return BudgetWindowTotals(**{key: float(value) for key, value in totals.items()})


 def build_budget_window_result(

projects/policyengine-api-simulation/src/modal/budget_window_scheduler.py

Lines changed: 52 additions & 6 deletions
@@ -29,8 +29,23 @@
     put_batch_job_seed,
     put_batch_job_state,
 )
-
-POLL_INTERVAL_SECONDS = 0.1
+from src.modal.gateway.errors import log_and_redact_exception
+
+# Polling tuning. The runner busy-loops across child FunctionCall.get(timeout=0)
+# probes; when no child resolved we sleep before the next probe to stop the
+# Modal control-plane from getting hammered. We start aggressive (0.5s) so
+# fast child runs don't inflate end-to-end latency, then double up to 30s so a
+# sluggish child doesn't keep the parent container hot polling. A blocking
+# FunctionCall.get(timeout=...) would be even better, but its interaction with
+# max_parallel means we'd have to juggle per-year deadlines and give up early
+# termination on child failure; the exponential walk keeps the control flow
+# simple while matching Modal's recommended polling cadence.
+POLL_INTERVAL_INITIAL_SECONDS = 0.5
+POLL_INTERVAL_MAX_SECONDS = 30.0
+POLL_INTERVAL_BACKOFF_FACTOR = 2.0
+# Retained for backward compatibility with callers that imported the original
+# constant; new code should use the initial/max pair above.
+POLL_INTERVAL_SECONDS = POLL_INTERVAL_INITIAL_SECONDS


 def serialize_batch_status(state) -> dict[str, Any]:
@@ -59,10 +74,16 @@ def __init__(
         context: BudgetWindowBatchContext,
         *,
         modal_module=None,
-        poll_interval_seconds: float = POLL_INTERVAL_SECONDS,
+        poll_interval_seconds: float = POLL_INTERVAL_INITIAL_SECONDS,
+        poll_interval_max_seconds: float = POLL_INTERVAL_MAX_SECONDS,
+        poll_interval_backoff_factor: float = POLL_INTERVAL_BACKOFF_FACTOR,
     ):
         self.context = context
         self.modal = modal if modal_module is None else modal_module
+        self.poll_interval_initial_seconds = poll_interval_seconds
+        self.poll_interval_max_seconds = poll_interval_max_seconds
+        self.poll_interval_backoff_factor = poll_interval_backoff_factor
+        # Kept for tests that still read this attribute.
         self.poll_interval_seconds = poll_interval_seconds
         self.state = load_or_create_batch_state(context)
         self.child_func = self.modal.Function.from_name(
@@ -75,13 +96,22 @@ def run(self) -> dict[str, Any]:
         mark_batch_running(self.state)
         put_batch_job_state(self.state)

+        # Exponential backoff: reset on any progress, double on empty polls.
+        current_sleep = self.poll_interval_initial_seconds
+
         while self.has_pending_work():
             self.spawn_until_capacity()
             progress_made = self.poll_running_children_once()
             if self.state.status == "failed":
                 return serialize_batch_status(self.state)
             if self.state.running_years and not progress_made:
-                time.sleep(self.poll_interval_seconds)
+                time.sleep(current_sleep)
+                current_sleep = min(
+                    current_sleep * self.poll_interval_backoff_factor,
+                    self.poll_interval_max_seconds,
+                )
+            elif progress_made:
+                current_sleep = self.poll_interval_initial_seconds

         return self.complete_batch()

@@ -122,9 +152,17 @@ def poll_running_children_once(self) -> bool:
             except TimeoutError:
                 continue
             except Exception as exc:
+                redacted = log_and_redact_exception(
+                    exc,
+                    scope="budget_window_child_call",
+                    context={
+                        "batch_job_id": self.context.batch_job_id,
+                        "simulation_year": simulation_year,
+                    },
+                )
                 self.fail_batch_for_child_error(
                     simulation_year=simulation_year,
-                    error=str(exc),
+                    error=redacted,
                 )
                 return False

@@ -134,9 +172,17 @@ def poll_running_children_once(self) -> bool:
                     child_result=child_result,
                 )
             except Exception as exc:
+                redacted = log_and_redact_exception(
+                    exc,
+                    scope="budget_window_child_result_parsing",
+                    context={
+                        "batch_job_id": self.context.batch_job_id,
+                        "simulation_year": simulation_year,
+                    },
+                )
                 self.fail_batch_for_child_error(
                     simulation_year=simulation_year,
-                    error=str(exc),
+                    error=redacted,
                 )
                 return False
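
The sleep schedule produced by the new polling loop can be traced with a small simulation. `simulate_sleeps` is a hypothetical helper written for this illustration; it assumes at least one child is always running, so every poll without progress leads to a sleep, and it uses the default constants from the diff.

```python
def simulate_sleeps(progress_per_poll, initial=0.5, maximum=30.0, factor=2.0):
    """Return the sleep taken after each poll that made no progress."""
    sleeps = []
    current = initial
    for progress_made in progress_per_poll:
        if not progress_made:
            sleeps.append(current)  # stands in for time.sleep(current)
            current = min(current * factor, maximum)
        else:
            current = initial  # any progress resets the walk
    return sleeps


print(simulate_sleeps([False] * 8))
# [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0]
print(simulate_sleeps([False, False, True, False]))
# [0.5, 1.0, 0.5]
```

With these defaults the interval reaches the 30-second ceiling on the seventh consecutive empty poll, after roughly 31.5 seconds of cumulative sleeping.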

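
The implementation of `log_and_redact_exception` is not included in this excerpt. One plausible shape, sketched here purely as an assumption, is to log the full exception server-side under a short reference id and return only the exception type and that reference to the caller. The message format and the reference scheme are invented for this illustration.

```python
from __future__ import annotations

import logging
import uuid

logger = logging.getLogger("gateway.errors")


def log_and_redact_exception(
    exc: BaseException, *, scope: str, context: dict | None = None
) -> str:
    """Log the full exception and return a message safe to show to clients."""
    reference = uuid.uuid4().hex[:12]
    # The traceback and message stay in the server logs only.
    logger.error(
        "%s failed (ref=%s, context=%s)",
        scope,
        reference,
        context or {},
        exc_info=exc,
    )
    return f"{scope} failed with {type(exc).__name__} (ref {reference})"


try:
    raise ConnectionError("token=s3cr3t rejected by upstream")
except Exception as exc:
    redacted = log_and_redact_exception(
        exc,
        scope="budget_window_child_call",
        context={"simulation_year": "2026"},
    )
```

Returning a reference id lets an operator correlate a client-visible error with the detailed log entry without exposing the exception text.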
projects/policyengine-api-simulation/src/modal/budget_window_state.py

Lines changed: 35 additions & 2 deletions
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import logging
 from datetime import UTC, datetime

 import modal
@@ -16,6 +17,10 @@
     PolicyEngineBundle,
 )

+logger = logging.getLogger(__name__)
+
+_UNKNOWN_CHILD_JOB_ID = "unknown"
+
 BUDGET_WINDOW_JOB_DICT_NAME = "simulation-api-budget-window-jobs"
 BUDGET_WINDOW_JOB_SEED_DICT_NAME = "simulation-api-budget-window-job-seeds"

@@ -129,6 +134,34 @@ def mark_child_started(
     return _touch(state)


+def _existing_child_or_sentinel(
+    state: BudgetWindowBatchState, *, year: str
+) -> BatchChildJobStatus:
+    """Return the tracked child for ``year`` or synthesise a sentinel.
+
+    Callers (``mark_child_completed`` / ``mark_child_failed``) used to index
+    ``state.child_jobs[year]`` directly which would raise ``KeyError`` if
+    transition helpers were invoked out of order (e.g., after recovery from
+    a dropped ``mark_child_started`` due to a crash between spawn and seed
+    persistence). In that unusual case we'd rather surface a redacted
+    terminal state with a synthetic job id than abort the whole batch. The
+    anomaly is logged at WARNING so operators can investigate separately.
+    """
+    child = state.child_jobs.get(year)
+    if child is not None:
+        return child
+
+    logger.warning(
+        "Transitioning child state for year %s with no prior child_jobs entry;"
+        " synthesising a sentinel job id",
+        year,
+        extra={"year": year, "batch_job_id": state.batch_job_id},
+    )
+    sentinel = BatchChildJobStatus(job_id=_UNKNOWN_CHILD_JOB_ID, status="pending")
+    state.child_jobs[year] = sentinel
+    return sentinel
+
+
 def mark_child_completed(
     state: BudgetWindowBatchState,
     *,
@@ -140,7 +173,7 @@ def mark_child_completed(
     if year not in state.completed_years:
         state.completed_years.append(year)

-    child = state.child_jobs[year]
+    child = _existing_child_or_sentinel(state, year=year)
     state.child_jobs[year] = BatchChildJobStatus(
         job_id=child.job_id,
         status="complete",
@@ -160,7 +193,7 @@ def mark_child_failed(
     if year not in state.failed_years:
         state.failed_years.append(year)

-    child = state.child_jobs[year]
+    child = _existing_child_or_sentinel(state, year=year)
     state.child_jobs[year] = BatchChildJobStatus(
         job_id=child.job_id,
         status="failed",
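
The sentinel fallback can be exercised in isolation with simplified stand-ins. `ChildStatus` and `BatchState` below are hypothetical dataclasses that carry only the fields this sketch needs; the real `BatchChildJobStatus` and `BudgetWindowBatchState` models have more fields that are not shown here.

```python
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

_UNKNOWN_CHILD_JOB_ID = "unknown"


@dataclass
class ChildStatus:  # hypothetical stand-in for BatchChildJobStatus
    job_id: str
    status: str


@dataclass
class BatchState:  # hypothetical stand-in for BudgetWindowBatchState
    batch_job_id: str
    child_jobs: dict[str, ChildStatus] = field(default_factory=dict)
    completed_years: list[str] = field(default_factory=list)


def existing_child_or_sentinel(state: BatchState, *, year: str) -> ChildStatus:
    child = state.child_jobs.get(year)
    if child is not None:
        return child
    # No tracked entry: warn and fall back to a synthetic job id.
    logger.warning("No child_jobs entry for year %s; using a sentinel", year)
    sentinel = ChildStatus(job_id=_UNKNOWN_CHILD_JOB_ID, status="pending")
    state.child_jobs[year] = sentinel
    return sentinel


def mark_child_completed(state: BatchState, *, year: str) -> None:
    if year not in state.completed_years:
        state.completed_years.append(year)
    child = existing_child_or_sentinel(state, year=year)
    state.child_jobs[year] = ChildStatus(job_id=child.job_id, status="complete")


state = BatchState(
    batch_job_id="batch-1",
    child_jobs={"2026": ChildStatus(job_id="fc-123", status="running")},
)
mark_child_completed(state, year="2026")  # tracked child keeps its job id
mark_child_completed(state, year="2027")  # never started: gets the sentinel id
```

The untracked year ends up recorded as complete with the `"unknown"` job id instead of raising `KeyError`, which is the behaviour the docstring in the diff describes.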
