Skip to content

Commit d653238

Browse files
Add support for PersonaAgent in PersonaDialogGenerator and tests
1 parent 1531c38 commit d653238

3 files changed

Lines changed: 32 additions & 7 deletions

File tree

src/sdialog/generators.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from langchain_core.messages import HumanMessage, SystemMessage
1818

1919
from . import Dialog, Turn
20-
from .personas import Persona
20+
from .personas import Persona, PersonaAgent
2121

2222

2323
class LLMDialogOutput(BaseModel):
@@ -152,10 +152,13 @@ class PersonaDialogGenerator(DialogGenerator):
152152
:ivar persona_b: The second persona.
153153
:vartype persona_b: Persona
154154
"""
155+
_agent_a = None
156+
_agent_b = None
157+
155158
def __init__(self,
156159
model: Union[ChatOllama, str],
157-
persona_a: Persona,
158-
persona_b: Persona,
160+
persona_a: Union[Persona, PersonaAgent],
161+
persona_b: Union[Persona, PersonaAgent],
159162
dialogue_details: str = "",
160163
response_details: str = "responses SHOULD NOT be too long and wordy, should be "
161164
"approximately one utterance long",
@@ -166,9 +169,9 @@ def __init__(self,
166169
:param model: The LLM or model name to use.
167170
:type model: Union[ChatOllama, str]
168171
:param persona_a: The first persona.
169-
:type persona_a: Persona
172+
:type persona_a: Persona (or PersonaAgent)
170173
:param persona_b: The second persona.
171-
:type persona_b: Persona
174+
:type persona_b: Persona (or PersonaAgent)
172175
:param dialogue_details: Additional dialogue instructions.
173176
:type dialogue_details: str
174177
:param response_details: Instructions for response style.
@@ -177,6 +180,10 @@ def __init__(self,
177180
:type scenario: dict
178181
"""
179182

183+
if isinstance(persona_a, PersonaAgent) and isinstance(persona_b, PersonaAgent):
184+
self._agent_a = persona_a
185+
self._agent_b = persona_b
186+
180187
dialogue_details = f"""Role play as the following two characters having a conversations. The characters are defined by the personas in the following lines. You always stay in character.
181188
[[ ## BEGING FIRST PERSONA ## ]]
182189
{persona_a}
@@ -194,3 +201,12 @@ def __init__(self,
194201
super().__init__(model=model,
195202
dialogue_details=dialogue_details,
196203
scenario=scenario)
204+
205+
def generate(self, seed: int = None, id: int = None, max_iterations: int = 20):
206+
if self._agent_a and self._agent_b:
207+
return self._agent_a.dialog_with(self._agent_b,
208+
max_iterations=max_iterations,
209+
id=id,
210+
seed=seed)
211+
else:
212+
return super().generate(seed=seed, id=id)

tests/test_generators.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from sdialog.generators import DialogGenerator, PersonaDialogGenerator, LLMDialogOutput, Turn
2-
from sdialog.personas import Persona
2+
from sdialog.personas import Persona, PersonaAgent
33

44

55
MODEL = "smollm:135m"
@@ -38,3 +38,12 @@ def test_persona_dialog_generator(monkeypatch):
3838
gen = PersonaDialogGenerator(MODEL, persona_a, persona_b)
3939
dialog = gen()
4040
assert hasattr(dialog, "turns")
41+
42+
43+
def test_persona_dialog_generator_with_agents(monkeypatch):
44+
monkeypatch.setattr("sdialog.generators.ChatOllama", DummyLLM)
45+
persona_a = PersonaAgent(DummyLLM(), name="A")
46+
persona_b = PersonaAgent(DummyLLM(), name="B")
47+
gen = PersonaDialogGenerator(MODEL, persona_a, persona_b)
48+
dialog = gen()
49+
assert hasattr(dialog, "turns")

tests/test_personas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_persona_agent_init(monkeypatch):
4545
persona = Persona(name="Alice")
4646
agent = PersonaAgent(DummyLLM(), persona=persona, name="Alice")
4747
assert agent.get_name() == "Alice"
48-
assert "Role play" in agent.get_prompt()
48+
assert "role play" in agent.get_prompt().lower()
4949
agent.set_first_utterances("Hi!")
5050
assert agent.first_utterances == "Hi!"
5151
agent.clear_orchestrators()

0 commit comments

Comments
 (0)