Skip to content

Commit 06d5891

Browse files
committed
Update to v1.2.0
1 parent a790b72 commit 06d5891

6 files changed

Lines changed: 79 additions & 22 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# YANA - Yet Another Network Agent
1+
# YANA Yet Another Network Agent
22

33
[![Version](https://img.shields.io/badge/version-1.2-1a1a2e)](https://github.com/pdudotdev/YANA/releases/tag/v1.2.0)
44
![License](https://img.shields.io/badge/license-GPLv3-1a1a2e)
@@ -52,7 +52,7 @@ Run your tests with any framework. When something fails, YANA investigates - it
5252

5353
**Step 1 - Install and ingest:**
5454
```bash
55-
sudo apt install git make python3.12-venv
55+
sudo apt install git make python3.12-venv -y
5656
cd ~ && git clone https://github.com/pdudotdev/YANA
5757
cd YANA && make setup
5858
```

input_models/models.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
"""Pydantic input models for MCP tools."""
2+
import ipaddress
23
import json
34
import re
45
from typing import Literal, Optional
56

67
from pydantic import BaseModel, Field, field_validator, model_validator
78

89
_VRF_RE = re.compile(r'^[a-zA-Z0-9_-]{1,32}$')
10+
_DEVICE_RE = re.compile(r'^[A-Za-z0-9_-]{1,64}$')
911

1012

1113
class BaseParamsModel(BaseModel):
12-
"""Base class with JSON string parsing and VRF validation."""
14+
"""Base class with JSON string parsing, device and VRF validation."""
1315

1416
@model_validator(mode='before')
1517
@classmethod
@@ -22,6 +24,15 @@ def parse_string_input(cls, v):
2224
raise ValueError(f"Could not parse params as JSON: {v!r}") from e
2325
return v
2426

27+
@field_validator('device', mode='before', check_fields=False)
28+
@classmethod
29+
def _validate_device(cls, v):
30+
if v is None:
31+
return v
32+
if not _DEVICE_RE.match(str(v)):
33+
raise ValueError(f"device must be alphanumeric with underscores/dashes, max 64 chars. Got: {v!r}")
34+
return v
35+
2536
@field_validator('vrf', mode='before', check_fields=False)
2637
@classmethod
2738
def _validate_vrf(cls, v):
@@ -64,10 +75,21 @@ class IntentQuery(BaseParamsModel):
6475

6576
class TracerouteInput(BaseParamsModel):
6677
device: str = Field(..., description="Device name from inventory")
67-
destination: str = Field(..., description="Destination IP address")
68-
source: Optional[str] = Field(None, description="Source IP address (forces traceroute to use this interface)")
78+
destination: str = Field(..., description="Destination IP address", max_length=45)
79+
source: Optional[str] = Field(None, description="Source IP address (forces traceroute to use this interface)", max_length=45)
6980
vrf: str | None = Field(None, description="Optional VRF name")
7081

82+
@field_validator('destination', 'source', mode='before')
83+
@classmethod
84+
def _validate_ip(cls, v):
85+
if v is None:
86+
return v
87+
try:
88+
ipaddress.ip_address(str(v).strip())
89+
except ValueError:
90+
raise ValueError(f"Invalid IP address: {v!r}")
91+
return str(v).strip()
92+
7193

7294
class KBQuery(BaseParamsModel):
7395
query: str = Field(..., description="Search question for the network knowledge base", max_length=500)

platforms/platform_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@
124124
},
125125
"tools": {
126126
# count=1 limits to one probe per hop so the command terminates (default is continuous)
127-
"traceroute": "/tool/traceroute count=1",
127+
"traceroute": {"default": "/tool/traceroute count=1", "vrf": "/tool/traceroute count=1 routing-table={vrf}"},
128128
},
129129
},
130130

testing/automated/test_suite.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
from pydantic import ValidationError
99

1010
from input_models.models import (
11-
DeviceListQuery, IntentQuery, InterfacesQuery, KBQuery,
12-
OspfQuery, TracerouteInput,
11+
DeviceListQuery, IntentQuery, KBQuery,
12+
OspfQuery, RoutingQuery, TracerouteInput,
1313
)
14-
from platforms.platform_map import PLATFORM_MAP, _apply_vrf, get_action
14+
from platforms.platform_map import PLATFORM_MAP, _apply_vrf
1515
from tools.intent import query_intent
1616
from tools.inventory_tool import list_devices
1717
from tools.operational import traceroute
1818
from tools.ospf import get_ospf
19+
from tools.routing import get_routing
1920
from tools.status import get_status
2021
from transport.ssh import _build_cli, execute_ssh
2122

@@ -177,6 +178,15 @@ async def test_traceroute_source_appended():
177178
assert "source 192.168.1.1" in result["_command"]
178179

179180

181+
async def test_routing_full_stack_ios():
182+
with patch("transport.execute_ssh", new_callable=AsyncMock) as mock_ssh:
183+
mock_ssh.return_value = "O 10.0.0.0/24 [110/20] via 10.0.0.1"
184+
result = await get_routing(RoutingQuery(device="R1", query="ip_route"))
185+
assert result["device"] == "R1"
186+
assert result["cli_style"] == "ios"
187+
assert result["_command"] == "show ip route vrf VRF1"
188+
189+
180190
# ── Inventory & device listing ───────────────────────────────────────────────
181191

182192
def test_device_lookup_returns_dict(monkeypatch):
@@ -201,6 +211,10 @@ async def test_status_structure():
201211
mi.exists.return_value = False
202212
result = await get_status()
203213
assert set(result.keys()) == {"inventory", "intent", "chromadb"}
214+
assert result["inventory"]["device_count"] == 6
215+
assert result["inventory"]["source"] == "network_json"
216+
assert result["intent"]["source"] == "unavailable"
217+
assert result["chromadb"]["available"] is False
204218

205219

206220
async def test_intent_single_device_filter(tmp_path):
@@ -238,7 +252,22 @@ async def test_kb_basic_query():
238252

239253
# ── Adversarial input ────────────────────────────────────────────────────────
240254

241-
@pytest.mark.parametrize("bad_device", ["; rm -rf /", "' OR 1=1 --", "a" * 1000, ""])
242-
async def test_adversarial_device_name(bad_device):
243-
result = await get_ospf(OspfQuery(device=bad_device, query="neighbors"))
244-
assert "error" in result
255+
_BAD_DEVICES = [
256+
"; rm -rf /", "' OR 1=1 --", "a" * 65, "",
257+
"R1 extra", "R1\nshow run", "R1;cmd", "$(reboot)",
258+
"R1\x00cmd", "R1&& cat /etc/shadow",
259+
]
260+
261+
_VALID_DEVICES = ["R1", "my_device", "core-rtr_01", "A" * 64]
262+
263+
264+
@pytest.mark.parametrize("bad_device", _BAD_DEVICES)
265+
def test_adversarial_device_name_blocked(bad_device):
266+
with pytest.raises(ValidationError):
267+
OspfQuery(device=bad_device, query="neighbors")
268+
269+
270+
@pytest.mark.parametrize("good_device", _VALID_DEVICES)
271+
def test_valid_device_names_accepted(good_device):
272+
q = OspfQuery(device=good_device, query="neighbors")
273+
assert q.device == good_device

tools/operational.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ async def traceroute(params: TracerouteInput) -> dict:
6060
command = f"{base_cmd} {params.destination}"
6161
if params.source:
6262
command += f" source {params.source}"
63-
if cli_style == "ios":
63+
if cli_style in ("ios", "eos", "vyos"):
6464
command += " probe 1 timeout 2"
65+
elif cli_style == "junos":
66+
command += " wait 1"
6567

6668
return await execute_command(params.device, command, timeout_ops=SSH_TIMEOUT_OPS_LONG)

tools/rag.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""RAG tool: search the network knowledge base."""
22
import asyncio
33
import logging
4+
import threading
45

56
from input_models.models import KBQuery
67
from tools import CHROMA_DIR
@@ -15,19 +16,22 @@
1516
# the device tools (get_ospf, get_interfaces) from loading.
1617
_embeddings = None
1718
_vectorstore = None
19+
_init_lock = threading.Lock()
1820

1921

2022
def _get_vectorstore():
2123
global _embeddings, _vectorstore
2224
if _vectorstore is None:
23-
from langchain_huggingface import HuggingFaceEmbeddings
24-
from langchain_chroma import Chroma
25-
_embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL)
26-
_vectorstore = Chroma(
27-
persist_directory=_CHROMA_DIR,
28-
embedding_function=_embeddings,
29-
collection_name=_COLLECTION,
30-
)
25+
with _init_lock:
26+
if _vectorstore is None:
27+
from langchain_huggingface import HuggingFaceEmbeddings
28+
from langchain_chroma import Chroma
29+
_embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL)
30+
_vectorstore = Chroma(
31+
persist_directory=_CHROMA_DIR,
32+
embedding_function=_embeddings,
33+
collection_name=_COLLECTION,
34+
)
3135
return _vectorstore
3236

3337

0 commit comments

Comments
 (0)