diff --git a/src/openstack_mcp_server/prompts/__init__.py b/src/openstack_mcp_server/prompts/__init__.py index e69de29..e27b30e 100644 --- a/src/openstack_mcp_server/prompts/__init__.py +++ b/src/openstack_mcp_server/prompts/__init__.py @@ -0,0 +1,26 @@ +from fastmcp import FastMCP + + +def register_prompt(mcp: FastMCP): + """ + Register Openstack MCP prompts. + """ + + @mcp.prompt() + def get_servers_by_security_group(security_group_name: str) -> str: + """ + Get servers associated with a specific security group. + + :param security_group_name: The name of the security group to filter servers by. + """ + return ( + f"Find all compute servers that have the security group " + f"'{security_group_name}' attached.\n\n" + f"Steps:\n" + f"1. Call get_servers to list all servers.\n" + f"2. Check each server's security_groups field.\n" + f"3. Return only the servers where security_groups contains " + f"an entry with name '{security_group_name}'.\n" + f"4. For each matching server, show the server name, ID, " + f"status, and the full list of its security groups." + ) diff --git a/src/openstack_mcp_server/server.py b/src/openstack_mcp_server/server.py index f71e788..a8be598 100644 --- a/src/openstack_mcp_server/server.py +++ b/src/openstack_mcp_server/server.py @@ -2,6 +2,7 @@ from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware from fastmcp.server.middleware.logging import LoggingMiddleware +from openstack_mcp_server.prompts import register_prompt from openstack_mcp_server.tools import register_tool @@ -13,7 +14,7 @@ def serve(transport: str, **kwargs): register_tool(mcp) # resister_resources(mcp) - # register_prompt(mcp) + register_prompt(mcp) # Add middlewares mcp.add_middleware(ErrorHandlingMiddleware()) diff --git a/tests/prompts/test_network_prompts.py b/tests/prompts/test_network_prompts.py new file mode 100644 index 0000000..a18fc09 --- /dev/null +++ b/tests/prompts/test_network_prompts.py @@ -0,0 +1,31 @@ +from unittest.mock import MagicMock + +from fastmcp import FastMCP + +from openstack_mcp_server.prompts import register_prompt + + +class TestPrompts: + """Test cases for MCP prompts.""" + + def test_get_servers_by_security_group_prompt_registered(self): + """Test that the prompt is registered with the MCP instance.""" + mcp = MagicMock() + register_prompt(mcp) + mcp.prompt.assert_called() + + def test_get_servers_by_security_group_prompt_content(self): + """Test that the prompt returns expected content.""" + mcp = FastMCP("test") + register_prompt(mcp) + + prompts = mcp._prompt_manager._prompts + assert "get_servers_by_security_group" in prompts + + prompt_obj = prompts["get_servers_by_security_group"] + assert prompt_obj.fn is not None + + result = prompt_obj.fn(security_group_name="my-sg") + assert "my-sg" in result + assert "get_servers" in result + assert "security_groups" in result