|
6 | 6 | from PIL import Image |
7 | 7 |
|
8 | 8 | from effectful.handlers.llm.encoding import Encodable |
| 9 | +from effectful.handlers.llm.evaluation import UnsafeEvalProvider |
| 10 | +from effectful.ops.semantics import handler |
9 | 11 | from effectful.ops.types import Operation, Term |
10 | 12 |
|
11 | 13 |
|
@@ -718,3 +720,125 @@ class Person(pydantic.BaseModel): |
718 | 720 | assert decoded_from_model == person |
719 | 721 | assert isinstance(decoded_from_model, Person) |
720 | 722 | assert isinstance(decoded_from_model.address, Address) |
| 723 | + |
| 724 | + |
| 725 | +class TestCallableEncodable: |
| 726 | + """Tests for CallableEncodable - encoding/decoding callables as source code.""" |
| 727 | + |
| 728 | + def test_encode_decode_function(self): |
| 729 | + from collections.abc import Callable |
| 730 | + |
| 731 | + def add(a: int, b: int) -> int: |
| 732 | + return a + b |
| 733 | + |
| 734 | + encodable = Encodable.define(Callable, {}) |
| 735 | + encoded = encodable.encode(add) |
| 736 | + assert isinstance(encoded, str) |
| 737 | + assert "def add" in encoded |
| 738 | + assert "return a + b" in encoded |
| 739 | + |
| 740 | + with handler(UnsafeEvalProvider()): |
| 741 | + decoded = encodable.decode(encoded) |
| 742 | + assert callable(decoded) |
| 743 | + assert decoded(2, 3) == 5 |
| 744 | + assert decoded.__name__ == "add" |
| 745 | + |
| 746 | + def test_decode_lambda(self): |
| 747 | + from collections.abc import Callable |
| 748 | + |
| 749 | + # Lambdas should work if defined in a way that inspect.getsource can find them |
| 750 | + # Note: lambdas defined inline may not always have retrievable source |
| 751 | + encodable = Encodable.define(Callable, {}) |
| 752 | + |
| 753 | + # Test decoding a lambda from source string |
| 754 | + lambda_source = "f = lambda x: x * 2" |
| 755 | + with handler(UnsafeEvalProvider()): |
| 756 | + decoded = encodable.decode(lambda_source) |
| 757 | + assert callable(decoded) |
| 758 | + assert decoded(5) == 10 |
| 759 | + |
| 760 | + def test_decode_with_env(self): |
| 761 | + from collections.abc import Callable |
| 762 | + |
| 763 | + # Test decoding a function that uses env variables |
| 764 | + encodable = Encodable.define(Callable, {"factor": 3}) |
| 765 | + source = """def multiply(x): |
| 766 | + return x * factor""" |
| 767 | + |
| 768 | + with handler(UnsafeEvalProvider()): |
| 769 | + decoded = encodable.decode(source) |
| 770 | + assert callable(decoded) |
| 771 | + assert decoded(4) == 12 |
| 772 | + |
| 773 | + def test_encode_non_callable_raises(self): |
| 774 | + from collections.abc import Callable |
| 775 | + |
| 776 | + encodable = Encodable.define(Callable, {}) |
| 777 | + with pytest.raises(TypeError, match="Expected callable"): |
| 778 | + encodable.encode("not a callable", {}) |
| 779 | + |
| 780 | + def test_encode_builtin_raises(self): |
| 781 | + from collections.abc import Callable |
| 782 | + |
| 783 | + encodable = Encodable.define(Callable, {}) |
| 784 | + # Built-in functions don't have source code |
| 785 | + with pytest.raises(RuntimeError, match="Source code of callable .* not found"): |
| 786 | + with handler(UnsafeEvalProvider()): |
| 787 | + encodable.encode(len) |
| 788 | + |
| 789 | + def test_decode_no_callable_raises(self): |
| 790 | + from collections.abc import Callable |
| 791 | + |
| 792 | + encodable = Encodable.define(Callable, {}) |
| 793 | + # Source code that defines no callable |
| 794 | + source = "x = 42" |
| 795 | + with pytest.raises(ValueError, match="exactly one callable"): |
| 796 | + with handler(UnsafeEvalProvider()): |
| 797 | + encodable.decode(source) |
| 798 | + |
| 799 | + def test_decode_multiple_callables_raises(self): |
| 800 | + from collections.abc import Callable |
| 801 | + |
| 802 | + encodable = Encodable.define(Callable, {}) |
| 803 | + # Source code that defines multiple callables |
| 804 | + source = """def foo(): |
| 805 | + return 1 |
| 806 | +
|
| 807 | +def bar(): |
| 808 | + return 2""" |
| 809 | + with pytest.raises(ValueError, match="exactly one callable"): |
| 810 | + with handler(UnsafeEvalProvider()): |
| 811 | + encodable.decode(source) |
| 812 | + |
| 813 | + def test_decode_class(self): |
| 814 | + from collections.abc import Callable |
| 815 | + |
| 816 | + encodable = Encodable.define(Callable, {}) |
| 817 | + # Classes are callable, decode should work with class definitions |
| 818 | + source = """class Greeter: |
| 819 | + def __init__(self, name): |
| 820 | + self.name = name |
| 821 | +
|
| 822 | + def greet(self): |
| 823 | + return f"Hello, {self.name}!\"""" |
| 824 | + |
| 825 | + with handler(UnsafeEvalProvider()): |
| 826 | + decoded = encodable.decode(source) |
| 827 | + assert callable(decoded) |
| 828 | + instance = decoded("World") |
| 829 | + assert instance.greet() == "Hello, World!" |
| 830 | + |
| 831 | + def test_roundtrip(self): |
| 832 | + from collections.abc import Callable |
| 833 | + |
| 834 | + def greet(name: str) -> str: |
| 835 | + return f"Hello, {name}!" |
| 836 | + |
| 837 | + encodable = Encodable.define(Callable, {}) |
| 838 | + with handler(UnsafeEvalProvider()): |
| 839 | + encoded = encodable.encode(greet) |
| 840 | + decoded = encodable.decode(encoded) |
| 841 | + |
| 842 | + assert callable(decoded) |
| 843 | + assert decoded("Alice") == "Hello, Alice!" |
| 844 | + assert decoded.__name__ == "greet" |
0 commit comments