Skip to content

Commit 8741728

Browse files
authored
Restrict subagent names (#722)
1 parent 3490b18 commit 8741728

2 files changed

Lines changed: 73 additions & 1 deletion

File tree

splunklib/ai/engines/langchain.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import logging
1717
import os
18+
import string
1819
import uuid
1920
from collections.abc import Awaitable, Callable, Sequence
2021
from dataclasses import asdict, dataclass
@@ -1413,11 +1414,22 @@ def _denormalize_tool_name(name: str) -> str:
14131414
return name
14141415

14151416

1417+
def _is_agent_name_valid(name: str) -> bool:
1418+
AGENT_NAME_ALLOWED_CHARS = string.ascii_letters + string.digits + "_-"
1419+
if not (1 <= len(name) <= 128):
1420+
return False
1421+
1422+
return set(name).issubset(AGENT_NAME_ALLOWED_CHARS)
1423+
1424+
14161425
def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
14171426
if not agent.name:
14181427
raise AssertionError("Agent must have a name to be used by other Agents")
14191428

1420-
# TODO: restrict subagent names
1429+
if not _is_agent_name_valid(agent.name):
1430+
raise AssertionError(
1431+
"Agent name is invalid, must contain only letters, numbers, '_' or '-' and have max 128 characters"
1432+
)
14211433

14221434
async def invoke_agent(
14231435
message: HumanMessage, thread_id: str | None

tests/integration/ai/test_agent.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,66 @@ async def test_duplicated_subagent_name(self) -> None:
440440
):
441441
pass
442442

443+
@pytest.mark.asyncio
444+
async def test_subagent_with_invalid_name(self) -> None:
445+
pytest.importorskip("langchain_openai")
446+
447+
async with (
448+
Agent(
449+
model=(await self.model()),
450+
system_prompt="",
451+
service=self.service,
452+
name="invalid name",
453+
) as subagent_invalid,
454+
Agent(
455+
model=(await self.model()),
456+
system_prompt="",
457+
service=self.service,
458+
name="invalid@name",
459+
) as subagent_invalid2,
460+
Agent(
461+
model=(await self.model()),
462+
system_prompt="",
463+
service=self.service,
464+
name="a" * 129,
465+
) as subagent_too_long,
466+
):
467+
with pytest.raises(
468+
AssertionError,
469+
match="Agent name is invalid",
470+
):
471+
async with Agent(
472+
model=(await self.model()),
473+
system_prompt="",
474+
service=self.service,
475+
agents=[subagent_invalid],
476+
):
477+
pass
478+
479+
with pytest.raises(
480+
AssertionError,
481+
match="Agent name is invalid",
482+
):
483+
async with Agent(
484+
model=(await self.model()),
485+
system_prompt="",
486+
service=self.service,
487+
agents=[subagent_invalid2],
488+
):
489+
pass
490+
491+
with pytest.raises(
492+
AssertionError,
493+
match="Agent name is invalid",
494+
):
495+
async with Agent(
496+
model=(await self.model()),
497+
system_prompt="",
498+
service=self.service,
499+
agents=[subagent_too_long],
500+
):
501+
pass
502+
443503
@pytest.mark.asyncio
444504
async def test_subagent_soft_failure_with_invalid_args(self) -> None:
445505
pytest.importorskip("langchain_openai")

0 commit comments

Comments
 (0)