Skip to content

Commit 4f1e292

Browse files
committed
updated tests to ensure eval works with reasonable subset of python operations, and classes
1 parent f1de016 commit 4f1e292

2 files changed

Lines changed: 96 additions & 11 deletions

File tree

effectful/handlers/llm/evaluation.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from types import CodeType
55
from typing import Any
66

7-
from RestrictedPython import compile_restricted, safe_globals
7+
from RestrictedPython import (
8+
Eval,
9+
Guards,
10+
compile_restricted,
11+
safe_globals,
12+
)
813

914
from effectful.ops.syntax import ObjectInterpretation, defop, implements
1015

@@ -117,9 +122,24 @@ def exec(
117122
# Build restricted globals from RestrictedPython's defaults, then layer `env` on top
118123
# (without letting callers replace the restricted builtins).
119124
rglobals = safe_globals.copy()
125+
126+
# Enable class definitions (required for Python 3)
127+
rglobals["__metaclass__"] = type
128+
rglobals["__name__"] = "restricted"
129+
130+
# Layer `env` on top (without letting callers replace the restricted builtins).
120131
for k, v in env.items():
121132
if k != "__builtins__":
122133
rglobals[k] = v
123134

135+
# Enable for loops and comprehensions
136+
rglobals["_getiter_"] = Eval.default_guarded_getiter
137+
138+
# Enable sequence unpacking in comprehensions and for loops
139+
rglobals["_iter_unpack_sequence_"] = Guards.guarded_iter_unpack_sequence
140+
rglobals["getattr"] = Guards.safer_getattr
141+
rglobals["setattr"] = Guards.guarded_setattr
142+
rglobals["_write_"] = lambda x: x
143+
124144
# Execute with locals=env so top-level defs land in `env` (like your UnsafeEvalProvider).
125145
builtins.exec(bytecode, rglobals, env)

tests/test_handlers_llm_encoding.py

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -829,16 +829,7 @@ def bar():
829829
encodable.decode(source, {})
830830

831831
@pytest.mark.parametrize(
832-
"eval_provider",
833-
[
834-
UnsafeEvalProvider,
835-
pytest.param(
836-
RestrictedEvalProvider,
837-
marks=pytest.mark.skip(
838-
reason="RestrictedPython doesn't support class definitions"
839-
),
840-
),
841-
],
832+
"eval_provider", [UnsafeEvalProvider, RestrictedEvalProvider]
842833
)
843834
def test_decode_class(self, eval_provider):
844835
from collections.abc import Callable
@@ -858,6 +849,80 @@ def greet(self):
858849
instance = decoded("World")
859850
assert instance.greet() == "Hello, World!"
860851

852+
@pytest.mark.parametrize(
853+
"eval_provider", [UnsafeEvalProvider, RestrictedEvalProvider]
854+
)
855+
def test_decode_function_with_for_loop(self, eval_provider):
856+
from collections.abc import Callable
857+
858+
encodable = type_to_encodable_type(Callable)
859+
# Test function with for loop
860+
source = """def sum_list(items):
861+
total = 0
862+
for item in items:
863+
total = total + item
864+
return total"""
865+
866+
with handler(eval_provider()):
867+
decoded = encodable.decode(source, {})
868+
assert callable(decoded)
869+
assert decoded([1, 2, 3, 4]) == 10
870+
assert decoded([5, 10]) == 15
871+
872+
@pytest.mark.parametrize(
873+
"eval_provider", [UnsafeEvalProvider, RestrictedEvalProvider]
874+
)
875+
def test_decode_function_with_list_comprehension(self, eval_provider):
876+
from collections.abc import Callable
877+
878+
encodable = type_to_encodable_type(Callable)
879+
# Test function with list comprehension
880+
source = """def double_items(items):
881+
return [x * 2 for x in items]"""
882+
883+
with handler(eval_provider()):
884+
decoded = encodable.decode(source, {})
885+
assert callable(decoded)
886+
assert decoded([1, 2, 3]) == [2, 4, 6]
887+
assert decoded([5, 10, 15]) == [10, 20, 30]
888+
889+
@pytest.mark.parametrize(
890+
"eval_provider", [UnsafeEvalProvider, RestrictedEvalProvider]
891+
)
892+
def test_decode_function_with_dict_comprehension(self, eval_provider):
893+
from collections.abc import Callable
894+
895+
encodable = type_to_encodable_type(Callable)
896+
# Test function with dict comprehension
897+
source = """def square_dict(items):
898+
return {x: x * x for x in items}"""
899+
900+
with handler(eval_provider()):
901+
decoded = encodable.decode(source, {})
902+
assert callable(decoded)
903+
assert decoded([1, 2, 3]) == {1: 1, 2: 4, 3: 9}
904+
assert decoded([5, 10]) == {5: 25, 10: 100}
905+
906+
@pytest.mark.parametrize(
907+
"eval_provider", [UnsafeEvalProvider, RestrictedEvalProvider]
908+
)
909+
def test_decode_function_with_unpacking(self, eval_provider):
910+
from collections.abc import Callable
911+
912+
encodable = type_to_encodable_type(Callable)
913+
# Test function with tuple unpacking
914+
source = """def process_pairs(pairs):
915+
results = []
916+
for a, b in pairs:
917+
results.append(a + b)
918+
return results"""
919+
920+
with handler(eval_provider()):
921+
decoded = encodable.decode(source, {})
922+
assert callable(decoded)
923+
assert decoded([(1, 2), (3, 4)]) == [3, 7]
924+
assert decoded([(10, 20)]) == [30]
925+
861926
@pytest.mark.parametrize(
862927
"eval_provider", [UnsafeEvalProvider, RestrictedEvalProvider]
863928
)

0 commit comments

Comments
 (0)