Skip to content

Commit 8501b3d

Browse files
feat: first pass at carrying contextvars through async flows (#878)
* first pass at carrying contextvars through async flows * update docstring * utilize with context instead of iter token * add todo regarding log context and contextfilter
1 parent ec8cc18 commit 8501b3d

12 files changed

Lines changed: 772 additions & 225 deletions

File tree

mellea/backends/huggingface.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
instrument_generate_from_raw,
5656
start_generate_span,
5757
)
58+
from ..telemetry.context import generate_request_id, with_context
5859
from .adapters import (
5960
AdapterMixin,
6061
AdapterType,
@@ -389,71 +390,80 @@ async def _generate_from_context(
389390
span = start_generate_span(
390391
backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls
391392
)
392-
await self.do_generate_walk(action)
393393

394-
# Upsert model options.
395-
model_opts = self._simplify_and_merge(model_options)
394+
with with_context(
395+
request_id=generate_request_id(),
396+
model_id=str(getattr(self, "model_id", "unknown")),
397+
):
398+
await self.do_generate_walk(action)
396399

397-
# Requirements can be automatically rerouted to a requirement adapter.
398-
if isinstance(action, Requirement):
399-
# See docs/dev/requirement_aLoRA_rerouting.md
400-
reroute_to_alora = self.default_to_constraint_checking_alora
401-
adapter_name = "requirement_check"
400+
# Upsert model options.
401+
model_opts = self._simplify_and_merge(model_options)
402402

403-
if isinstance(action, ALoraRequirement):
404-
reroute_to_alora = True
405-
adapter_name = action.intrinsic_name
406-
alora_action = action
407-
else:
408-
assert action.description is not None, (
409-
"must have a description when generating from a requirement"
410-
)
411-
alora_action = ALoraRequirement(action.description, adapter_name)
403+
# Requirements can be automatically rerouted to a requirement adapter.
404+
if isinstance(action, Requirement):
405+
# See docs/dev/requirement_aLoRA_rerouting.md
406+
reroute_to_alora = self.default_to_constraint_checking_alora
407+
adapter_name = "requirement_check"
412408

413-
# Check if a requirement_check (or AloraRequirement specified) adapter
414-
# exists.
415-
alora_req_adapter = get_adapter_for_intrinsic(
416-
adapter_name, [AdapterType.ALORA], self._added_adapters
417-
)
418-
if alora_req_adapter is None:
419-
# Log a warning if using an AloraRequirement but no adapter fit.
420-
if reroute_to_alora and isinstance(action, ALoraRequirement):
421-
MelleaLogger.get_logger().warning(
422-
f"attempted to use an AloraRequirement but backend {self} doesn't have the specified adapter added {adapter_name}; defaulting to regular generation"
409+
if isinstance(action, ALoraRequirement):
410+
reroute_to_alora = True
411+
adapter_name = action.intrinsic_name
412+
alora_action = action
413+
else:
414+
assert action.description is not None, (
415+
"must have a description when generating from a requirement"
423416
)
424-
reroute_to_alora = False
417+
alora_action = ALoraRequirement(action.description, adapter_name)
418+
419+
# Check if a requirement_check (or AloraRequirement specified) adapter
420+
# exists.
421+
alora_req_adapter = get_adapter_for_intrinsic(
422+
adapter_name, [AdapterType.ALORA], self._added_adapters
423+
)
424+
if alora_req_adapter is None:
425+
# Log a warning if using an AloraRequirement but no adapter fit.
426+
if reroute_to_alora and isinstance(action, ALoraRequirement):
427+
MelleaLogger.get_logger().warning(
428+
f"attempted to use an AloraRequirement but backend {self} doesn't have the specified adapter added {adapter_name}; defaulting to regular generation"
429+
)
430+
reroute_to_alora = False
425431

426-
if issubclass(type(action), LLMaJRequirement):
427-
reroute_to_alora = False
432+
if issubclass(type(action), LLMaJRequirement):
433+
reroute_to_alora = False
428434

429-
if reroute_to_alora:
430-
# Keep the alora requirement handling separate for now.
435+
if reroute_to_alora:
436+
# Keep the alora requirement handling separate for now.
437+
mot = await self._generate_from_intrinsic(
438+
alora_action, ctx, model_options=model_opts
439+
)
440+
# Store span for telemetry
441+
if span is not None:
442+
mot._meta["_telemetry_span"] = span
443+
return mot, ctx.add(alora_action).add(mot)
444+
445+
elif isinstance(action, Intrinsic):
431446
mot = await self._generate_from_intrinsic(
432-
alora_action, ctx, model_options=model_opts
447+
action, ctx, model_options=model_opts
433448
)
434449
# Store span for telemetry
435450
if span is not None:
436451
mot._meta["_telemetry_span"] = span
437-
return mot, ctx.add(alora_action).add(mot)
452+
return mot, ctx.add(action).add(mot)
438453

439-
elif isinstance(action, Intrinsic):
440-
mot = await self._generate_from_intrinsic(
441-
action, ctx, model_options=model_opts
454+
mot = await self._generate_from_context_standard(
455+
action,
456+
ctx,
457+
_format=format,
458+
model_options=model_opts,
459+
tool_calls=tool_calls,
442460
)
443-
# Store span for telemetry
461+
462+
# Store span in metadata for post_processing to record telemetry
444463
if span is not None:
445464
mot._meta["_telemetry_span"] = span
446-
return mot, ctx.add(action).add(mot)
447465

448-
mot = await self._generate_from_context_standard(
449-
action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
450-
)
451-
452-
# Store span in metadata for post_processing to record telemetry
453-
if span is not None:
454-
mot._meta["_telemetry_span"] = span
455-
456-
return mot, ctx.add(action).add(mot)
466+
return mot, ctx.add(action).add(mot)
457467

458468
def _generate_with_adapter_lock(
459469
self, adapter_name: str, generate_func: Callable, *args, **kwargs

mellea/backends/litellm.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
instrument_generate_from_raw,
4646
start_generate_span,
4747
)
48+
from ..telemetry.context import generate_request_id, with_context
4849
from .backend import FormatterBackend
4950
from .model_options import ModelOption
5051
from .tools import (
@@ -166,13 +167,16 @@ async def _generate_from_context(
166167
span = start_generate_span(
167168
backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls
168169
)
169-
mot = await self._generate_from_chat_context_standard(
170-
action,
171-
ctx,
172-
_format=format,
173-
model_options=model_options,
174-
tool_calls=tool_calls,
175-
)
170+
171+
_model_id_str = str(getattr(self, "model_id", "unknown"))
172+
with with_context(request_id=generate_request_id(), model_id=_model_id_str):
173+
mot = await self._generate_from_chat_context_standard(
174+
action,
175+
ctx,
176+
_format=format,
177+
model_options=model_options,
178+
tool_calls=tool_calls,
179+
)
176180

177181
# Store span for telemetry recording in post_processing
178182
if span is not None:

mellea/backends/ollama.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
from ..telemetry.backend_instrumentation import (
3131
instrument_generate_from_context,
3232
instrument_generate_from_raw,
33+
start_generate_span,
3334
)
35+
from ..telemetry.context import generate_request_id, with_context
3436
from .backend import FormatterBackend
3537
from .model_options import ModelOption
3638
from .tools import add_tools_from_context_actions, add_tools_from_model_options
@@ -286,21 +288,22 @@ async def _generate_from_context(
286288
tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output
287289
and an updated context that includes ``action`` and the new output.
288290
"""
289-
from ..telemetry.backend_instrumentation import start_generate_span
290-
291291
# Start span without auto-closing (will be closed in post_processing)
292292
span = start_generate_span(self, action, ctx, format, tool_calls)
293293

294294
assert ctx.is_chat_context, (
295295
"The ollama backend only supports chat-like contexts."
296296
)
297-
mot = await self.generate_from_chat_context(
298-
action,
299-
ctx,
300-
_format=format,
301-
model_options=model_options,
302-
tool_calls=tool_calls,
303-
)
297+
298+
_model_id_str = str(getattr(self, "model_id", "unknown"))
299+
with with_context(request_id=generate_request_id(), model_id=_model_id_str):
300+
mot = await self.generate_from_chat_context(
301+
action,
302+
ctx,
303+
_format=format,
304+
model_options=model_options,
305+
tool_calls=tool_calls,
306+
)
304307

305308
# Store span for telemetry recording and closing in post_processing
306309
if span is not None:

mellea/backends/openai.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
from ..telemetry.backend_instrumentation import (
4848
instrument_generate_from_context,
4949
instrument_generate_from_raw,
50+
start_generate_span,
5051
)
52+
from ..telemetry.context import generate_request_id, with_context
5153
from .backend import FormatterBackend
5254
from .model_options import ModelOption
5355
from .tools import (
@@ -357,8 +359,6 @@ async def _generate_from_context(
357359
tuple[ModelOutputThunk[C], Context]: A thunk holding the (lazy) model output
358360
and an updated context that includes ``action`` and the new output.
359361
"""
360-
from ..telemetry.backend_instrumentation import start_generate_span
361-
362362
assert ctx.is_chat_context, NotImplementedError(
363363
"The Openai backend only supports chat-like contexts."
364364
)
@@ -372,13 +372,15 @@ async def _generate_from_context(
372372
backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls
373373
)
374374

375-
result = await self.generate_from_chat_context(
376-
action,
377-
ctx,
378-
_format=format,
379-
model_options=model_options,
380-
tool_calls=tool_calls,
381-
)
375+
_model_id_str = str(getattr(self, "model_id", "unknown"))
376+
with with_context(request_id=generate_request_id(), model_id=_model_id_str):
377+
result = await self.generate_from_chat_context(
378+
action,
379+
ctx,
380+
_format=format,
381+
model_options=model_options,
382+
tool_calls=tool_calls,
383+
)
382384
# Store span in ModelOutputThunk for later use in post_processing
383385
mot, new_ctx = result
384386
if span is not None:

mellea/backends/watsonx.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
instrument_generate_from_raw,
4949
start_generate_span,
5050
)
51+
from ..telemetry.context import generate_request_id, with_context
5152
from .backend import FormatterBackend
5253
from .model_options import ModelOption
5354
from .tools import (
@@ -304,13 +305,16 @@ async def _generate_from_context(
304305
span = start_generate_span(
305306
backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls
306307
)
307-
mot = await self.generate_from_chat_context(
308-
action,
309-
ctx,
310-
_format=format,
311-
model_options=model_options,
312-
tool_calls=tool_calls,
313-
)
308+
309+
_model_id_str = str(getattr(self, "model_id", "unknown"))
310+
with with_context(request_id=generate_request_id(), model_id=_model_id_str):
311+
mot = await self.generate_from_chat_context(
312+
action,
313+
ctx,
314+
_format=format,
315+
model_options=model_options,
316+
tool_calls=tool_calls,
317+
)
314318

315319
# Store span in metadata for post_processing to record telemetry
316320
if span is not None:

mellea/core/utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
except ImportError:
3737
_OTEL_AVAILABLE = False
3838

39+
from ..telemetry.context import _CONTEXT_VARS as _telemetry_vars, MelleaContextFilter
40+
3941
# ---------------------------------------------------------------------------
4042
# Per-task/coroutine context fields (safe for asyncio — each Task gets its own copy)
4143
# ---------------------------------------------------------------------------
@@ -376,13 +378,22 @@ def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
376378
# Static extra fields (constructor-level)
377379
log_record.update(self._extra)
378380

379-
# Dynamic context fields — prefer record attributes (set by
380-
# ContextFilter) but fall back to ContextVar storage so the
381-
# formatter works standalone without a filter attached.
381+
# Dynamic context fields — prefer record attributes (set by ContextFilter /
382+
# MelleaContextFilter) but fall back to ContextVar storage so the formatter
383+
# works standalone without filters attached.
382384
context_fields: dict[str, Any] = _log_context.get()
383385
for key, value in context_fields.items():
384386
log_record[key] = getattr(record, key, value)
385387

388+
# Telemetry context fields (session_id, request_id, model_id, sampling_iteration).
389+
# MelleaContextFilter stamps these onto the record before formatters run; read
390+
# them back off the record here so they appear in JSON output. Fall back to the
391+
# ContextVar directly so the formatter still works without the filter attached.
392+
for key, var in _telemetry_vars.items():
393+
value = getattr(record, key, var.get())
394+
if value is not None:
395+
log_record.setdefault(key, value)
396+
386397
return log_record
387398

388399
def format(self, record: logging.LogRecord) -> str:
@@ -521,6 +532,9 @@ def get_logger() -> logging.Logger:
521532
logger.addFilter(ContextFilter())
522533
logger.addFilter(OtelTraceFilter())
523534

535+
# Inject telemetry context fields (session_id, request_id, etc.)
536+
logger.addFilter(MelleaContextFilter())
537+
524538
# Only set default level if user hasn't already configured it
525539
if logger.level == logging.NOTSET:
526540
logger.setLevel(MelleaLogger._resolve_log_level())

0 commit comments

Comments
 (0)