Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .github/workflows/ai-agent-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,57 @@ on:
- '.github/workflows/ai-agent-ci.yml'

jobs:
lint:
name: Lint · Format
runs-on: ubuntu-latest

defaults:
run:
working-directory: apps/ai_agent

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Setup uv
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
cache-dependency-glob: 'apps/ai_agent/uv.lock'

- name: Install dependencies
run: uv sync --group dev

- name: Ruff check
run: uv run ruff check .

- name: Ruff format check
run: uv run ruff format --check .

typecheck:
name: Typecheck
runs-on: ubuntu-latest

defaults:
run:
working-directory: apps/ai_agent

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Setup uv
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
cache-dependency-glob: 'apps/ai_agent/uv.lock'

- name: Install dependencies
run: uv sync --group dev

- name: Mypy
run: uv run mypy main.py

test:
name: Test · Coverage
runs-on: ubuntu-latest
Expand Down
59 changes: 32 additions & 27 deletions apps/ai_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import Literal

import uvicorn
import weaviate
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

import weaviate
from weaviate.classes.query import Filter

app = FastAPI(title="AI Agent API")
Expand All @@ -24,6 +23,7 @@

# ── Request / response models ─────────────────────────────────────────────────


class ChatRequest(BaseModel):
message: str
conversation_id: str
Expand Down Expand Up @@ -69,16 +69,19 @@ class ProposalSummariseResponse(BaseModel):

# ── Helpers ───────────────────────────────────────────────────────────────────


def _openai_client():
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
raise HTTPException(status_code=500, detail="OPENAI_API_KEY is not configured")
from openai import OpenAI # imported lazily so missing package gives a clear error

return OpenAI(api_key=api_key)


# ── Endpoints ─────────────────────────────────────────────────────────────────


@app.get("/health")
def health_check():
return {"status": "ok"}
Expand Down Expand Up @@ -142,9 +145,9 @@ def summarise_proposal(request: ProposalSummariseRequest):
f"Description: {request.description}\n"
f"Amount: {request.amount} XLM\n\n"
"Reply with JSON only using keys: summary (a plain-English summary of "
"exactly 2 sentences), risk (one of \"low\", \"medium\", \"high\"). "
"Use \"high\" for large amounts, unclear intent, or obvious red flags; "
"\"low\" for small, well-scoped, low-impact proposals; otherwise \"medium\"."
'exactly 2 sentences), risk (one of "low", "medium", "high"). '
'Use "high" for large amounts, unclear intent, or obvious red flags; '
'"low" for small, well-scoped, low-impact proposals; otherwise "medium".'
)
response = client.chat.completions.create(
model="gpt-4o-mini",
Expand Down Expand Up @@ -172,20 +175,20 @@ def index_message(request: IndexMessageRequest):
try:
# Attempt connection to Weaviate
client = weaviate.connect_to_local()
except Exception as e:
except Exception:
raise HTTPException(status_code=503, detail="Weaviate connection failed")

try:
if not client.collections.exists("Message"):
client.collections.create(name="Message")

collection = client.collections.get("Message")

# Get embedding via OpenAI
openai_client = _openai_client()
res = openai_client.embeddings.create(input=request.content, model="text-embedding-3-small")
vector = res.data[0].embedding

# Upsert
if collection.data.exists(request.messageId):
collection.data.replace(
Expand All @@ -196,7 +199,7 @@ def index_message(request: IndexMessageRequest):
"senderId": request.senderId,
"content": request.content,
},
vector=vector
vector=vector,
)
else:
collection.data.insert(
Expand All @@ -207,49 +210,51 @@ def index_message(request: IndexMessageRequest):
"senderId": request.senderId,
"content": request.content,
},
vector=vector
vector=vector,
)
except Exception as e:
raise HTTPException(status_code=503, detail=str(e))
finally:
client.close()

return {"status": "ok"}


@app.get("/search")
def search_messages(q: str, conversationId: str):
try:
client = weaviate.connect_to_local()
except Exception as e:
except Exception:
raise HTTPException(status_code=503, detail="Weaviate connection failed")

try:
if not client.collections.exists("Message"):
return {"results": []}

collection = client.collections.get("Message")

# Get embedding for query
openai_client = _openai_client()
res = openai_client.embeddings.create(input=q, model="text-embedding-3-small")
vector = res.data[0].embedding

results = collection.query.near_vector(
near_vector=vector,
limit=5,
filters=Filter.by_property("conversationId").equal(conversationId)
filters=Filter.by_property("conversationId").equal(conversationId),
)

hits = []
for obj in results.objects:
hits.append({
"messageId": obj.properties.get("messageId"),
"conversationId": obj.properties.get("conversationId"),
"senderId": obj.properties.get("senderId"),
"content": obj.properties.get("content"),
})

hits.append(
{
"messageId": obj.properties.get("messageId"),
"conversationId": obj.properties.get("conversationId"),
"senderId": obj.properties.get("senderId"),
"content": obj.properties.get("content"),
}
)

return {"results": hits}
except Exception as e:
raise HTTPException(status_code=503, detail=str(e))
Expand Down
19 changes: 19 additions & 0 deletions apps/ai_agent/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ dev = [
"pytest-mock>=3.14.0",
"pytest-cov>=5.0.0",
"pip-audit>=2.7",
"ruff>=0.4",
"mypy>=1.10",
"types-requests",
]

[dependency-groups]
Expand All @@ -27,8 +30,24 @@ dev = [
"httpx>=0.27.0",
"pytest-mock>=3.14.0",
"pip-audit>=2.7",
"ruff>=0.4",
"mypy>=1.10",
"types-requests",
]

[tool.ruff]
target-version = "py312"
line-length = 100

[tool.ruff.lint]
select = ["E", "F", "I", "W"]

[tool.mypy]
python_version = "3.12"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true

[tool.pytest.ini_options]
pythonpath = ["."]
testpaths = ["tests"]
Expand Down
2 changes: 1 addition & 1 deletion apps/ai_agent/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Shared pytest fixtures for the ai_agent test suite."""

import os
import pytest
from fastapi.testclient import TestClient

Expand All @@ -15,6 +14,7 @@ def set_openai_key(monkeypatch: pytest.MonkeyPatch) -> None:
def client() -> TestClient:
"""FastAPI TestClient for the main app."""
from main import app

return TestClient(app)


Expand Down
2 changes: 1 addition & 1 deletion apps/ai_agent/tests/test_proposals.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Unit tests for POST /proposals/summarise (issue #147)."""

import json
import pytest
from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient

from main import app

client = TestClient(app)
Expand Down
20 changes: 13 additions & 7 deletions apps/ai_agent/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Unit tests for GET /search (issue #149)."""

import pytest
from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient

from main import app

client = TestClient(app)
Expand Down Expand Up @@ -62,8 +62,10 @@ def test_returns_results_with_correct_shape():
}
mock_wv = _make_weaviate_client(exists=True, objects=[obj])

with patch("main.weaviate.connect_to_local", return_value=mock_wv), \
patch("main._openai_client", return_value=_make_openai_embedding()):
with (
patch("main.weaviate.connect_to_local", return_value=mock_wv),
patch("main._openai_client", return_value=_make_openai_embedding()),
):
response = client.get("/search", params=_BASE_PARAMS)

assert response.status_code == 200
Expand All @@ -87,8 +89,10 @@ def test_filters_by_conversation_id():
}
mock_wv = _make_weaviate_client(exists=True, objects=[obj])

with patch("main.weaviate.connect_to_local", return_value=mock_wv), \
patch("main._openai_client", return_value=_make_openai_embedding()):
with (
patch("main.weaviate.connect_to_local", return_value=mock_wv),
patch("main._openai_client", return_value=_make_openai_embedding()),
):
response = client.get("/search", params={"q": "transfer", "conversationId": "conv-xyz"})

assert response.status_code == 200
Expand All @@ -103,8 +107,10 @@ def test_close_called_on_success():
mock_wv = _make_weaviate_client(exists=True, objects=[])
mock_wv.collections.get.return_value.query.near_vector.return_value.objects = []

with patch("main.weaviate.connect_to_local", return_value=mock_wv), \
patch("main._openai_client", return_value=_make_openai_embedding()):
with (
patch("main.weaviate.connect_to_local", return_value=mock_wv),
patch("main._openai_client", return_value=_make_openai_embedding()),
):
response = client.get("/search", params=_BASE_PARAMS)

assert response.status_code == 200
Expand Down
Loading
Loading