Skip to content

Commit da491ed

Browse files
committed
added a new mcp tool
add a new mcp tool for orders by product.
1 parent 9535121 commit da491ed

4 files changed

Lines changed: 122 additions & 17 deletions

File tree

.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ CONNX_DSN=your_dsn_name
33
CONNX_USER=your_username
44
CONNX_PASS=your_password
55
CONNX_TIMEOUT=30
6+
CONNX_MAX_ROWS=1000
7+

CONNX_MCP_Sample.code-workspace

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"path": "."

connx_server.py

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@
2020
# Optional security controls
2121
CONNX_ALLOW_WRITES = os.getenv("CONNX_ALLOW_WRITES", "false").strip().lower() == "true"
2222

23+
# Result limits
24+
def _env_int(name: str, default: int, minimum: int = 1) -> int:
25+
try:
26+
value = int(os.getenv(name, str(default)))
27+
if value < minimum:
28+
return default
29+
return value
30+
except (TypeError, ValueError):
31+
return default
32+
33+
34+
MAX_RESULT_ROWS = _env_int("CONNX_MAX_ROWS", default=1000, minimum=1)
35+
2336
# Setup logging (log to stderr to avoid interfering with MCP stdout)
2437
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
2538
logger = logging.getLogger(__name__)
@@ -97,6 +110,18 @@ def _is_select_only(sql: str) -> bool:
97110
return s.startswith("select")
98111

99112

113+
def _first_keyword(sql: str) -> str:
114+
"""Return the first keyword/token of the SQL (lowercased)."""
115+
return (sql or "").lstrip().split(" ", 1)[0].lower()
116+
117+
118+
def _effective_limit(requested: Optional[int]) -> int:
119+
"""Clamp requested row limit to the configured maximum."""
120+
if requested and requested > 0:
121+
return min(requested, MAX_RESULT_ROWS)
122+
return MAX_RESULT_ROWS
123+
124+
100125
def get_connx_connection():
101126
"""Establish a connection to CONNX via pyodbc."""
102127
_assert_config()
@@ -112,28 +137,42 @@ def get_connx_connection():
112137
raise ValueError(f"Failed to connect to CONNX: {str(e)}")
113138

114139

115-
async def execute_query_async(query: str, params: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
140+
async def execute_query_async(
141+
query: str,
142+
params: Optional[List[Any]] = None,
143+
max_rows: Optional[int] = None
144+
) -> List[Dict[str, Any]]:
116145
"""Asynchronous execution of SELECT queries via CONNX."""
117146
loop = asyncio.get_running_loop()
118-
return await loop.run_in_executor(None, execute_query, query, params)
147+
return await loop.run_in_executor(None, execute_query, query, params, max_rows)
119148

120149

121-
def execute_query(query: str, params: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
150+
def execute_query(
151+
query: str,
152+
params: Optional[List[Any]] = None,
153+
max_rows: Optional[int] = None
154+
) -> List[Dict[str, Any]]:
122155
"""Execute SELECT query and return results as list of dicts."""
123156
conn = get_connx_connection()
124157
fp = _sql_fingerprint(query)
158+
limit = max_rows if max_rows and max_rows > 0 else MAX_RESULT_ROWS
125159
try:
126160
cursor = conn.cursor()
127-
# cursor.timeout = int(os.getenv("CONNX_TIMEOUT", "30"))
161+
# cursor.timeout = int(os.getenv("CONNX_TIMEOUT", "30"))
128162
cursor.execute(query, params or [])
129163
if cursor.description is None:
130164
# A SELECT should provide a description; if not, treat as an error.
131165
raise ValueError("Query did not return a result set (cursor.description is None).")
132166

133167
columns = [desc[0] for desc in cursor.description]
134-
rows = cursor.fetchall()
168+
rows = cursor.fetchmany(limit + 1) if limit else cursor.fetchall()
169+
truncated = len(rows) > limit if limit else False
170+
if truncated:
171+
rows = rows[:limit]
135172
results = [dict(zip(columns, row)) for row in rows]
136173
logger.info("Query OK fp=%s rows=%d", fp, len(results))
174+
if truncated:
175+
logger.info("Query truncated fp=%s limit=%d", fp, limit)
137176
return results
138177
except (pyodbc.Error, ValueError) as e:
139178
logger.error("Query failed fp=%s err=%s", fp, e)
@@ -183,7 +222,7 @@ async def query_connx(query: str) -> Dict[str, Any]:
183222
return {"error": "Only SELECT statements are allowed for query_connx."}
184223

185224
try:
186-
results = await execute_query_async(query)
225+
results = await execute_query_async(query, max_rows=MAX_RESULT_ROWS)
187226
return {"results": results, "count": len(results)}
188227
except ValueError as e:
189228
return {"error": str(e)}
@@ -201,9 +240,13 @@ async def update_connx(operation: str, query: str) -> Dict[str, Any]:
201240
if not CONNX_ALLOW_WRITES:
202241
return {"error": "Writes are disabled. Set CONNX_ALLOW_WRITES=true to enable update operations."}
203242

204-
if operation.lower() not in ["insert", "update", "delete"]:
243+
op = operation.strip().lower()
244+
if op not in ["insert", "update", "delete"]:
205245
return {"error": "Invalid operation. Must be 'insert', 'update', or 'delete'."}
206246

247+
if _first_keyword(query) != op:
248+
return {"error": f"SQL must start with {op.upper()} for this operation."}
249+
207250
if not _is_single_statement(query):
208251
return {"error": "Only a single SQL statement is allowed (no semicolons)."}
209252

@@ -236,7 +279,7 @@ async def count_customers() -> Dict[str, Any]:
236279
async def get_schema() -> Dict[str, Any]:
237280
query = "SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS"
238281
try:
239-
results = await execute_query_async(query)
282+
results = await execute_query_async(query, max_rows=MAX_RESULT_ROWS)
240283
return {"schemas": results}
241284
except ValueError as e:
242285
return {"error": str(e)}
@@ -250,7 +293,7 @@ async def get_schema_for_table(table_name: str) -> Dict[str, Any]:
250293
"WHERE TABLE_NAME = ?"
251294
)
252295
try:
253-
results = await execute_query_async(query, params=[table_name])
296+
results = await execute_query_async(query, params=[table_name], max_rows=MAX_RESULT_ROWS)
254297
return {"schemas": results}
255298
except ValueError as e:
256299
return {"error": str(e)}
@@ -447,12 +490,13 @@ async def find_customers(state: str, city: Optional[str] = None, max_rows: int =
447490
sql += " ORDER BY RTRIM(CUSTOMERNAME)"
448491

449492
try:
450-
results = await execute_query_async(sql, params=params)
493+
limit = _effective_limit(max_rows)
494+
fetch_limit = min(limit + 1, MAX_RESULT_ROWS + 1)
495+
results = await execute_query_async(sql, params=params, max_rows=fetch_limit)
451496

452-
truncated = False
453-
if max_rows and max_rows > 0 and len(results) > max_rows:
454-
results = results[:max_rows]
455-
truncated = True
497+
truncated = len(results) > limit
498+
if truncated:
499+
results = results[:limit]
456500

457501
return {"results": results, "count": len(results), "truncated": truncated}
458502
except ValueError as e:
@@ -528,7 +572,58 @@ async def get_semantic_entities() -> Dict[str, Any]:
528572
},
529573
]
530574
}
575+
576+
@mcp.tool()
577+
async def customer_orders_for_product(
578+
customer_id: str,
579+
product_name: str,
580+
max_rows: int = 50
581+
) -> Dict[str, Any]:
582+
"""
583+
Get detailed order information for a specific customer and product.
584+
585+
Args:
586+
customer_id: Customer identifier
587+
product_name: Name of the product
588+
max_rows: Maximum number of orders to return (default: 50)
589+
590+
Returns order details including dates, quantities, etc.
591+
"""
592+
sql = """
593+
SELECT
594+
o.ORDERID,
595+
o.ORDERDATE,
596+
o.PRODUCTQUANTITY,
597+
RTRIM(p.PRODUCTNAME) AS PRODUCTNAME,
598+
RTRIM(c.CUSTOMERNAME) AS CUSTOMERNAME
599+
FROM daea_Mainframe_VSAM.dbo.ORDERS_VSAM o
600+
INNER JOIN daea_Mainframe_VSAM.dbo.CUSTOMERS_VSAM c
601+
ON RTRIM(c.CUSTOMERID) = RTRIM(o.CUSTOMERID)
602+
INNER JOIN daea_Mainframe_VSAM.dbo.PRODUCTS_VSAM p
603+
ON o.PRODUCTID = p.PRODUCTID
604+
WHERE RTRIM(c.CUSTOMERID) = ?
605+
AND UPPER(RTRIM(p.PRODUCTNAME)) = UPPER(?)
606+
ORDER BY o.ORDERDATE DESC
607+
"""
608+
609+
try:
610+
limit = _effective_limit(max_rows)
611+
results = await execute_query_async(
612+
sql,
613+
params=[customer_id.strip(), product_name.strip()],
614+
max_rows=limit
615+
)
616+
617+
return {
618+
"customer_id": customer_id,
619+
"product_name": product_name,
620+
"orders": results,
621+
"count": len(results)
622+
}
623+
except ValueError as e:
624+
return {"error": str(e)}
625+
531626
# Main Entry Point
532627
if __name__ == "__main__": # pragma: no cover
533628
# FastMCP.run() manages its own event loop via anyio.run()
534-
mcp.run(transport="stdio")
629+
mcp.run(transport="stdio")

tests/test_server.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,13 @@ def test_execute_query_success_returns_list_of_dicts(self, mock_get_conn):
132132
fake_conn.cursor.return_value = fake_cursor
133133

134134
fake_cursor.description = [("ID",), ("NAME",)]
135-
fake_cursor.fetchall.return_value = [(1, "Alice"), (2, "Bob")]
135+
fake_cursor.fetchmany.return_value = [(1, "Alice"), (2, "Bob")]
136136

137137
results = mod.execute_query("SELECT ID, NAME FROM T WHERE ID > ?", params=[0])
138138

139139
self.assertEqual(results, [{"ID": 1, "NAME": "Alice"}, {"ID": 2, "NAME": "Bob"}])
140140
fake_cursor.execute.assert_called_once_with("SELECT ID, NAME FROM T WHERE ID > ?", [0])
141+
fake_cursor.fetchmany.assert_called_once()
141142
fake_conn.close.assert_called_once()
142143

143144
@patch(f"{MODULE_UNDER_TEST}.get_connx_connection")
@@ -262,6 +263,12 @@ async def test_update_connx_rejects_invalid_operation_when_writes_enabled(self):
262263
self.assertIn("error", out)
263264
self.assertIn("invalid operation", out["error"].lower())
264265

266+
@patch(f"{MODULE_UNDER_TEST}.CONNX_ALLOW_WRITES", True)
267+
async def test_update_connx_requires_sql_keyword_match(self):
268+
out = await mod.update_connx("update", "DELETE FROM T")
269+
self.assertIn("error", out)
270+
self.assertIn("must start with update", out["error"].lower())
271+
265272
@patch(f"{MODULE_UNDER_TEST}.CONNX_ALLOW_WRITES", True)
266273
async def test_update_connx_rejects_semicolons_when_writes_enabled(self):
267274
out = await mod.update_connx("update", "UPDATE T SET A=1; UPDATE T SET A=2")
@@ -450,4 +457,4 @@ async def test_count_entities_known_entity_calls_db(self):
450457

451458

452459
if __name__ == "__main__":
453-
unittest.main(verbosity=2)
460+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)