Skip to content

Commit e12a10a

Browse files
committed
[WIP] Add anthropic-bedrock model for testing only
1 parent 9e45daf commit e12a10a

11 files changed

Lines changed: 1370 additions & 1060 deletions

File tree

.env.template

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,11 @@ version=9.0
2020
#internal_ai_app_key=""
2121
#internal_ai_token_url=""
2222
#internal_ai_base_url=""
23+
24+
#bedrock_model_id=""
25+
#bedrock_aws_region=""
26+
#bedrock_base_model_id=""
27+
28+
# Sonnet fallback for tests that require a more capable model
29+
#bedrock_sonnet_model_id=""
30+
#bedrock_sonnet_base_model_id=""

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ compat = ["six>=1.17.0"]
3636
ai = ["httpx==0.28.1", "langchain>=1.2.13", "mcp>=1.26.0", "pydantic>=2.7.4"]
3737
anthropic = ["splunk-sdk[ai]>=2.1.1", "langchain-anthropic>=1.4.0"]
3838
openai = ["splunk-sdk[ai]>=2.1.1", "langchain-openai>=1.1.12"]
39+
bedrock = ["splunk-sdk[anthropic]>=2.1.1", "langchain-aws>=0.2.0"]
3940

4041
# Treat the same as NPM's `devDependencies`
4142
[dependency-groups]
@@ -50,7 +51,7 @@ release = ["build>=1.4.2", "jinja2>=3.1.6", "sphinx>=9.1.0", "twine>=6.2.0"]
5051
lint = ["basedpyright>=1.38.4", "ruff>=0.15.8"]
5152
dev = [
5253
"rich>=14.3.3",
53-
"splunk-sdk[openai, anthropic]",
54+
"splunk-sdk[openai, anthropic, bedrock]",
5455
{ include-group = "test" },
5556
{ include-group = "lint" },
5657
{ include-group = "release" },

splunklib/ai/engines/langchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
LC_ModelRequest = Langchain_ModelRequest["InvokeContext"]
122122

123123
# Set to True to enable debugging mode.
124-
_DEBUG = False
124+
_DEBUG = True
125125

126126
# Disallow _DEBUG == True in CI.
127127
# Github actions sets the CI env var.

tests/ai_test_model.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import collections.abc
2-
from typing import override
2+
from dataclasses import dataclass
3+
from typing import Any, override
34

45
import httpx
56
from httpx import Auth, Request, Response
7+
from langchain_core.language_models import BaseChatModel
68
from pydantic import BaseModel
79

8-
from splunklib.ai import OpenAIModel
10+
from splunklib.ai import AnthropicModel, OpenAIModel
911
from splunklib.ai.model import PredefinedModel
1012

1113

@@ -18,14 +20,51 @@ class InternalAIModel(BaseModel):
1820
base_url: str
1921

2022

23+
@dataclass(frozen=True)
24+
class AnthropicBedrockModel(AnthropicModel):
25+
"""Anthropic model accessed via AWS Bedrock, for testing only."""
26+
27+
api_key: str = ""
28+
base_url: str = ""
29+
aws_region: str = ""
30+
base_model_id: str = ""
31+
32+
def _to_langchain_model(self) -> BaseChatModel:
33+
try:
34+
from langchain_aws import ChatBedrockConverse
35+
36+
kwargs: dict[str, Any] = {"model": self.model}
37+
if self.aws_region:
38+
kwargs["region_name"] = self.aws_region
39+
if self.temperature is not None:
40+
kwargs["temperature"] = self.temperature
41+
if self.model.startswith("arn:"):
42+
kwargs["provider"] = "anthropic"
43+
kwargs["base_model_id"] = (
44+
self.base_model_id or "anthropic.claude-haiku-4-5-20251001"
45+
)
46+
return ChatBedrockConverse(**kwargs)
47+
except ImportError:
48+
raise ImportError(
49+
"AWS Bedrock support is not installed.\n"
50+
+ "To enable Bedrock models, install the optional extra:\n"
51+
+ 'pip install "splunk-sdk[bedrock]"\n'
52+
+ "# or if using uv:\n"
53+
+ "uv add splunk-sdk[bedrock]"
54+
)
55+
56+
2157
class TestLLMSettings(BaseModel):
2258
# TODO: Currently we only support our internal OpenAI-compatible model,
2359
# once we are close to GA we should also support OpenAI and probably Ollama, such
2460
# that external developers can also run our test suite suite locally.
2561
internal_ai: InternalAIModel | None = None
62+
anthropic_bedrock: AnthropicBedrockModel | None = None
2663

2764

2865
async def create_model(s: TestLLMSettings) -> PredefinedModel:
66+
if s.anthropic_bedrock is not None:
67+
return s.anthropic_bedrock
2968
if s.internal_ai is not None:
3069
return await _buildInternalAIModel(
3170
token_url=s.internal_ai.token_url,
@@ -46,7 +85,7 @@ def __init__(self, token: str) -> None:
4685
@override
4786
def auth_flow(
4887
self, request: Request
49-
) -> collections.abc.Generator[Request, Response, None]:
88+
) -> collections.abc.Generator[Request, Response]:
5089
request.headers["api-key"] = self.token
5190
yield request
5291

tests/ai_testlib.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from typing import override
2+
23
from splunklib.ai.model import PredefinedModel
3-
from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
4+
from tests.ai_test_model import (
5+
AnthropicBedrockModel,
6+
InternalAIModel,
7+
TestLLMSettings,
8+
create_model,
9+
)
410
from tests.testlib import SDKTestCase
511

612

713
class AITestCase(SDKTestCase):
814
_model: PredefinedModel | None = None
15+
_sonnet_model: PredefinedModel | None = None
916

1017
@override
1118
def setUp(self) -> None:
@@ -20,6 +27,24 @@ def setUp(self) -> None:
2027

2128
@property
2229
def test_llm_settings(self) -> TestLLMSettings:
30+
bedrock_model_id: str = self.opts.kwargs.get(
31+
"bedrock_model_id", ""
32+
) # ignore: [reportUnknownVariableType]
33+
if bedrock_model_id:
34+
aws_region: str = self.opts.kwargs.get(
35+
"bedrock_aws_region", ""
36+
) # ignore: [reportUnknownVariableType]
37+
base_model_id: str = self.opts.kwargs.get(
38+
"bedrock_base_model_id", ""
39+
) # ignore: [reportUnknownVariableType]
40+
return TestLLMSettings(
41+
anthropic_bedrock=AnthropicBedrockModel(
42+
model=bedrock_model_id, # ignore: [reportUnknownVariableType]
43+
aws_region=aws_region, # ignore: [reportUnknownVariableType]
44+
base_model_id=base_model_id, # ignore: [reportUnknownVariableType]
45+
)
46+
)
47+
2348
client_id: str = self.opts.kwargs["internal_ai_client_id"]
2449
client_secret: str = self.opts.kwargs["internal_ai_client_secret"]
2550
app_key: str = self.opts.kwargs["internal_ai_app_key"]
@@ -42,3 +67,36 @@ async def model(self) -> PredefinedModel:
4267
model = await create_model(self.test_llm_settings)
4368
self._model = model
4469
return model
70+
71+
async def sonnet_model(self) -> PredefinedModel:
72+
"""Returns a Sonnet model for tests that require a more capable model.
73+
74+
Falls back to the default model if no Sonnet config is provided.
75+
"""
76+
if self._sonnet_model is not None:
77+
return self._sonnet_model
78+
79+
sonnet_model_id: str = self.opts.kwargs.get("bedrock_sonnet_model_id", "")
80+
if sonnet_model_id:
81+
aws_region: str = self.opts.kwargs.get("bedrock_aws_region", "")
82+
base_model_id: str = self.opts.kwargs.get("bedrock_sonnet_base_model_id", "")
83+
settings = TestLLMSettings(
84+
anthropic_bedrock=AnthropicBedrockModel(
85+
model=sonnet_model_id,
86+
aws_region=aws_region,
87+
base_model_id=base_model_id,
88+
)
89+
)
90+
model = await create_model(settings)
91+
self._sonnet_model = model
92+
return model
93+
94+
return await self.model()
95+
96+
@property
97+
def supports_provider_strategy(self) -> bool:
98+
"""Returns True if the configured model supports ProviderStrategy (native JSON output).
99+
100+
AnthropicBedrockModel routes through ToolStrategy instead, so it returns False.
101+
"""
102+
return self.test_llm_settings.anthropic_bedrock is None

tests/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from collections.abc import Generator
2+
3+
import pytest
4+
from langchain_core.language_models import BaseChatModel
5+
6+
from splunklib.ai.engines import langchain as lc_engine
7+
from splunklib.ai.model import PredefinedModel
8+
from tests.ai_test_model import AnthropicBedrockModel
9+
10+
_original_create_langchain_model = lc_engine._create_langchain_model # pyright: ignore[reportPrivateUsage]
11+
12+
13+
def _patched_create_langchain_model(model: PredefinedModel) -> BaseChatModel:
14+
if isinstance(model, AnthropicBedrockModel):
15+
return model._to_langchain_model() # pyright: ignore[reportPrivateUsage]
16+
return _original_create_langchain_model(model)
17+
18+
19+
@pytest.fixture(autouse=True)
20+
def _patch_langchain_model_factory(request: pytest.FixtureRequest) -> Generator[None]:
21+
if "integration/ai" not in str(request.fspath):
22+
yield
23+
return
24+
lc_engine._create_langchain_model = _patched_create_langchain_model # pyright: ignore[reportPrivateUsage]
25+
yield
26+
lc_engine._create_langchain_model = _original_create_langchain_model # pyright: ignore[reportPrivateUsage]

tests/integration/ai/test_hooks.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
before_agent,
3030
before_model,
3131
)
32-
from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage
32+
from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage, StructuredOutputMessage
3333
from splunklib.ai.middleware import AgentRequest, ModelMiddlewareHandler, ModelRequest, ModelResponse, model_middleware
3434
from tests.ai_testlib import AITestCase
3535

@@ -127,7 +127,10 @@ async def after_agent_hook(resp: AgentResponse) -> None:
127127
person = resp.structured_output
128128
assert type(person) is Person
129129
assert person.name.lower() == "stefan"
130-
assert len(resp.messages) == 2
130+
# ProviderStrategy: 2 messages (human + AI).
131+
# ToolStrategy: 3 messages (human + AI tool_use + StructuredOutputMessage).
132+
uses_tool_strategy = any(isinstance(m, StructuredOutputMessage) for m in resp.messages)
133+
assert len(resp.messages) == (3 if uses_tool_strategy else 2)
131134

132135
@after_agent
133136
async def after_async_agent_hook(resp: AgentResponse) -> None:
@@ -137,7 +140,10 @@ async def after_async_agent_hook(resp: AgentResponse) -> None:
137140
person = resp.structured_output
138141
assert type(person) is Person
139142
assert person.name.lower() == "stefan"
140-
assert len(resp.messages) == 2
143+
# ProviderStrategy: 2 messages (human + AI).
144+
# ToolStrategy: 3 messages (human + AI tool_use + StructuredOutputMessage).
145+
uses_tool_strategy = any(isinstance(m, StructuredOutputMessage) for m in resp.messages)
146+
assert len(resp.messages) == (3 if uses_tool_strategy else 2)
141147

142148
async with Agent(
143149
model=(await self.model()),
@@ -159,8 +165,14 @@ async def after_async_agent_hook(resp: AgentResponse) -> None:
159165
]
160166
)
161167

162-
response = result.final_message.content.strip().lower().replace(".", "")
163-
assert '{"name":"stefan"}' == response
168+
# With ProviderStrategy the final message is plain JSON text.
169+
# With ToolStrategy the structured output is in result.structured_output.
170+
person = result.structured_output
171+
if person is not None:
172+
assert person.name.lower() == "stefan"
173+
else:
174+
response = result.final_message.content.strip().lower().replace(".", "")
175+
assert '{"name":"stefan"}' == response
164176
assert hook_calls == 4
165177

166178
@pytest.mark.asyncio

tests/integration/ai/test_middleware.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,9 @@ class NicknameGeneratorInput(BaseModel):
358358
Agent(
359359
model=await self.model(),
360360
system_prompt=(
361-
"You are a helpful assistant that generates nicknames. A valid "
362-
+ "nickname consists of the provided name suffixed with '-zilla.'"
361+
"You are a helpful assistant that generates nicknames. "
362+
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
363+
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
363364
),
364365
service=self.service,
365366
name="NicknameGeneratorAgent",
@@ -406,15 +407,16 @@ async def test_middleware(
406407
first_response = await handler(request)
407408
second_response = await handler(request)
408409
assert isinstance(first_response.result, SubagentTextResult)
409-
assert second_response == first_response
410+
assert isinstance(second_response.result, SubagentTextResult)
410411
return second_response
411412

412413
async with (
413414
Agent(
414415
model=await self.model(),
415416
system_prompt=(
416-
"You are a helpful assistant that generates nicknames. A valid "
417-
+ "nickname consists of the provided name suffixed with '-zilla.'"
417+
"You are a helpful assistant that generates nicknames. "
418+
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
419+
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
418420
),
419421
service=self.service,
420422
name="NicknameGeneratorAgent",
@@ -472,8 +474,9 @@ async def test_middleware(
472474
Agent(
473475
model=await self.model(),
474476
system_prompt=(
475-
"You are a helpful assistant that generates nicknames. A valid "
476-
+ "nickname consists of the provided name suffixed with '-zilla.'"
477+
"You are a helpful assistant that generates nicknames. "
478+
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
479+
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
477480
),
478481
service=self.service,
479482
name="NicknameGeneratorAgent",
@@ -601,8 +604,9 @@ async def test_middleware(
601604
Agent(
602605
model=await self.model(),
603606
system_prompt=(
604-
"You are a helpful assistant that generates nicknames. A valid "
605-
+ "nickname consists of the provided name suffixed with '-zilla.'"
607+
"You are a helpful assistant that generates nicknames. "
608+
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
609+
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
606610
),
607611
service=self.service,
608612
name="NicknameGeneratorAgent",
@@ -777,8 +781,9 @@ async def mutating_middleware(
777781
Agent(
778782
model=await self.model(),
779783
system_prompt=(
780-
"You are a helpful assistant that generates nicknames. A valid "
781-
"nickname consists of the provided name suffixed with '-zilla.'"
784+
"You are a helpful assistant that generates nicknames. "
785+
"The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
786+
"For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
782787
),
783788
service=self.service,
784789
name="NicknameGeneratorAgent",
@@ -796,7 +801,14 @@ async def mutating_middleware(
796801
result = await supervisor.invoke(
797802
[HumanMessage(content="Generate a nickname for Bob")]
798803
)
799-
assert "Alice-zilla" in result.final_message.content
804+
# The middleware mutated the arg to "Alice", so the subagent must have
805+
# received "Alice" and returned "Alice-zilla". Check the subagent message.
806+
subagent_msg = next(
807+
(m for m in result.messages if isinstance(m, SubagentMessage)), None
808+
)
809+
assert subagent_msg is not None
810+
assert isinstance(subagent_msg.result, SubagentTextResult)
811+
assert "Alice-zilla" in subagent_msg.result.content
800812

801813
@pytest.mark.asyncio
802814
async def test_model_middleware_structured_output(self) -> None:

0 commit comments

Comments
 (0)