Skip to content

Commit 7b360e6

Browse files
authored
add mypy-based validation, drop syntactic check on return type (#536)
* add mypy-based validation, drop syntactic check on return type * resurrects tests for test_handlers_llm_encoding.py * switched to type error instead of value error * neatened up type checking tests * added mypy to llm dependencies * switched to use typing_extensions TypeAliasType * moved type checking to an operation * refined exception guard and switched to any instead of skipping bindings with un-representable types * refactored code to construct asts, added more systematic tests, and checks * updated deps of llm submodule * llm tests all pass * minor bug * ruff formatting * updated imports to include sys.modules * restricted sys.modules in imports and suppressed warnings on importing untyped modules * added ignore to ast.FunctionDef * updated to use ruff to clean up generated code and avoid redundant mypy lag * updated codeadapt fixture * added --unsafe-fixes to ruff invocation * updated prompt * format notebook * format notebook * switched from ruff to autoflake * fix for __init__ * fixed test with __init__ returning None * final fixes * added type_checking
1 parent c2491e6 commit 7b360e6

12 files changed

Lines changed: 2825 additions & 888 deletions

docs/source/llm.ipynb

Lines changed: 853 additions & 853 deletions
Large diffs are not rendered by default.

effectful/handlers/llm/encoding.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def _format_callable_type(callable_type: type[Callable]) -> str:
272272

273273
if param_types is ...:
274274
params_str = "..."
275-
elif isinstance(param_types, (list, tuple)):
275+
elif isinstance(param_types, list | tuple):
276276
params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types)
277277
else:
278278
params_str = str(param_types)
@@ -366,14 +366,6 @@ def _validate_signature_callable(
366366
"decode() requires synthesized function to have a return type annotation"
367367
)
368368

369-
expected_name = getattr(expected_return, "__name__", str(expected_return))
370-
actual_name = getattr(actual_return, "__name__", str(actual_return))
371-
if expected_name != actual_name:
372-
raise ValueError(
373-
f"decode() expected function with return type {expected_name}, "
374-
f"got {actual_name}"
375-
)
376-
377369

378370
@dataclass
379371
class CallableEncodable(Encodable[Callable, SynthesizedFunction]):
@@ -455,6 +447,11 @@ def decode(self, encoded_value: SynthesizedFunction) -> Callable:
455447
# Validate signature from AST before execution
456448
_validate_signature_ast(last_stmt, self.expected_params)
457449

450+
# Type-check with mypy; pass original module_code so mypy sees exact source
451+
evaluation.type_check(
452+
module, self.ctx, self.expected_params, self.expected_return
453+
)
454+
458455
# Compile and execute
459456
# https://docs.python.org/3/library/functions.html#exec
460457
g: MutableMapping[str, Any] = {}
@@ -620,7 +617,7 @@ def _encodable_callable(
620617

621618
# Ellipsis means any params, skip param validation
622619
expected_params: list[type] | None = None
623-
if param_types is not ... and isinstance(param_types, (list, tuple)):
620+
if param_types is not ... and isinstance(param_types, list | tuple):
624621
expected_params = list(param_types)
625622

626623
return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return)

0 commit comments

Comments
 (0)