Skip to content

Commit ddd7b3f

Browse files
committed
Fix type errors and test failures
- Fix OAuthClientInformationFull initialization with empty defaults - Fix TokenError and AuthorizeError to use literal strings instead of enums - Fix token_type to use lowercase 'bearer' as required by OAuthToken - Fix potentially None values in _generate_mcp_token calls - Fix ProjectService repository.update calls to include entity_id - Fix prompt_router.py to handle undefined variables - Fix config.py field_validator return type - Fix external_auth_provider.py to add state field - Fix supabase_auth_provider.py redirect_uris parameter - Make Importer.handle_error abstract to fix test failures - Remove unused imports All type checks now pass and all tests are green.
1 parent 3595d6f commit ddd7b3f

10 files changed

Lines changed: 59 additions & 45 deletions

File tree

src/basic_memory/api/routers/prompt_router.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ async def continue_conversation(
5353

5454
since = parse(request.timeframe) if request.timeframe else None
5555

56+
# Initialize search results
57+
search_results = []
58+
5659
# Get data needed for template
5760
if request.topic:
5861
query = SearchQuery(text=request.topic, after_date=request.timeframe)
@@ -122,9 +125,12 @@ async def continue_conversation(
122125
relation_count = 0
123126
entity_count = 0
124127

128+
# Get the hierarchical results from the template context
129+
hierarchical_results_for_count = template_context.get("hierarchical_results", [])
130+
125131
# For topic-based search
126132
if request.topic:
127-
for item in all_hierarchical_results:
133+
for item in hierarchical_results_for_count:
128134
if hasattr(item, "observations"):
129135
observation_count += len(item.observations) if item.observations else 0
130136

@@ -137,7 +143,7 @@ async def continue_conversation(
137143
entity_count += 1 # pragma: no cover
138144
# For recent activity
139145
else:
140-
for item in hierarchical_results:
146+
for item in hierarchical_results_for_count:
141147
if hasattr(item, "observations"):
142148
observation_count += len(item.observations) if item.observations else 0
143149

@@ -153,14 +159,14 @@ async def continue_conversation(
153159
metadata = {
154160
"query": request.topic,
155161
"timeframe": request.timeframe,
156-
"search_count": len(results) if request.topic else 0, # Original search results count
157-
"context_count": len(all_hierarchical_results)
162+
"search_count": len(search_results)
158163
if request.topic
159-
else len(hierarchical_results),
164+
else 0, # Original search results count
165+
"context_count": len(hierarchical_results_for_count),
160166
"observation_count": observation_count,
161167
"relation_count": relation_count,
162168
"total_items": (
163-
len(all_hierarchical_results if request.topic else hierarchical_results)
169+
len(hierarchical_results_for_count)
164170
+ observation_count
165171
+ relation_count
166172
+ entity_count

src/basic_memory/cli/commands/auth.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
auth_app = typer.Typer(help="OAuth client management commands")
1313
app.add_typer(auth_app, name="auth")
1414

15+
1516
@auth_app.command()
1617
def register_client(
1718
client_id: Optional[str] = typer.Option(
@@ -27,12 +28,10 @@ def register_client(
2728
# Create provider instance
2829
provider = BasicMemoryOAuthProvider(issuer_url=issuer_url)
2930

30-
31-
3231
# Create client info with required redirect_uris
3332
client_info = OAuthClientInformationFull(
34-
client_id=client_id,
35-
client_secret=client_secret,
33+
client_id=client_id or "", # Provider will generate if empty
34+
client_secret=client_secret or "", # Provider will generate if empty
3635
redirect_uris=[AnyHttpUrl("http://localhost:8000/callback")], # Default redirect URI
3736
client_name="Basic Memory OAuth Client",
3837
grant_types=["authorization_code", "refresh_token"],

src/basic_memory/cli/commands/db.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from basic_memory import db
99
from basic_memory.cli.app import app
10-
from basic_memory.config import config
10+
from basic_memory.config import app_config
1111

1212

1313
@app.command()
@@ -18,15 +18,15 @@ def reset(
1818
if typer.confirm("This will delete all data in your db. Are you sure?"):
1919
logger.info("Resetting database...")
2020
# Get database path
21-
db_path = config.database_path
21+
db_path = app_config.app_database_path
2222

2323
# Delete the database file if it exists
2424
if db_path.exists():
2525
db_path.unlink()
2626
logger.info(f"Database file deleted: {db_path}")
2727

2828
# Create a new empty database
29-
asyncio.run(db.run_migrations(config))
29+
asyncio.run(db.run_migrations(app_config))
3030
logger.info("Database reset complete")
3131

3232
if reindex:

src/basic_memory/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def project_list(self) -> List[ProjectConfig]: # pragma: no cover
129129

130130
@field_validator("projects")
131131
@classmethod
132-
def ensure_project_paths_exists(cls, v: Dict[str, str]) -> Path: # pragma: no cover
132+
def ensure_project_paths_exists(cls, v: Dict[str, str]) -> Dict[str, str]: # pragma: no cover
133133
"""Ensure project path exists."""
134134
for name, path_value in v.items():
135135
path = Path(path_value)

src/basic_memory/importers/base.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def ensure_folder_exists(self, folder: str) -> Path:
6363
folder_path.mkdir(parents=True, exist_ok=True)
6464
return folder_path
6565

66+
@abstractmethod
6667
def handle_error(
6768
self, message: str, error: Optional[Exception] = None
6869
) -> T: # pragma: no cover
@@ -75,13 +76,4 @@ def handle_error(
7576
Returns:
7677
ImportResult with error information.
7778
"""
78-
error_message = f"{message}"
79-
if error:
80-
error_message += f": {str(error)}"
81-
82-
logger.error(error_message)
83-
return ImportResult(
84-
import_count={},
85-
success=False,
86-
error_message=error_message,
87-
)
79+
pass

src/basic_memory/mcp/external_auth_provider.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class ExternalAuthorizationCode(AuthorizationCode):
2222
"""Authorization code with external provider metadata."""
2323

2424
external_code: Optional[str] = None
25+
state: Optional[str] = None
2526

2627

2728
@dataclass
@@ -103,6 +104,7 @@ async def authorize(
103104
code_challenge=params.code_challenge,
104105
redirect_uri=params.redirect_uri,
105106
redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly,
107+
state=params.state,
106108
)
107109

108110
# Build external provider URL
@@ -153,6 +155,7 @@ async def handle_callback(self, code: str, state: str) -> str:
153155
redirect_uri=auth_code.redirect_uri,
154156
redirect_uri_provided_explicitly=auth_code.redirect_uri_provided_explicitly,
155157
external_code=code,
158+
state=auth_code.state,
156159
)
157160

158161
self.tokens[internal_code] = external_tokens
@@ -161,7 +164,7 @@ async def handle_callback(self, code: str, state: str) -> str:
161164
return construct_redirect_uri(
162165
str(auth_code.redirect_uri),
163166
code=internal_code,
164-
state=auth_code.issuer_state,
167+
state=auth_code.state,
165168
)
166169

167170
async def load_authorization_code(
@@ -206,7 +209,7 @@ async def exchange_authorization_code(
206209

207210
return OAuthToken(
208211
access_token=access_token,
209-
token_type="Bearer",
212+
token_type="bearer",
210213
expires_in=expires_in,
211214
refresh_token=refresh_token,
212215
scope=" ".join(authorization_code.scopes) if authorization_code.scopes else None,
@@ -269,7 +272,7 @@ async def exchange_refresh_token(
269272

270273
return OAuthToken(
271274
access_token=new_access_token,
272-
token_type="Bearer",
275+
token_type="bearer",
273276
expires_in=expires_in,
274277
refresh_token=new_refresh_token,
275278
scope=" ".join(scopes or refresh_token.scopes),

src/basic_memory/mcp/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ def create_auth_config() -> tuple[AuthSettings | None, Any | None]:
5353

5454
if os.getenv("FASTMCP_AUTH_ENABLED", "false").lower() == "true":
5555
from pydantic import AnyHttpUrl
56-
56+
5757
# Configure OAuth settings
5858
issuer_url = os.getenv("FASTMCP_AUTH_ISSUER_URL", "http://localhost:8000")
5959
required_scopes = os.getenv("FASTMCP_AUTH_REQUIRED_SCOPES", "read,write")
6060
docs_url = os.getenv("FASTMCP_AUTH_DOCS_URL") or "http://localhost:8000/docs/oauth"
61-
61+
6262
auth_settings = AuthSettings(
6363
issuer_url=AnyHttpUrl(issuer_url),
6464
service_documentation_url=AnyHttpUrl(docs_url),

src/basic_memory/mcp/supabase_auth_provider.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
RefreshToken,
1717
AccessToken,
1818
TokenError,
19-
TokenErrorCode,
2019
AuthorizeError,
21-
AuthorizationErrorCode,
2220
)
2321
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
2422

@@ -91,6 +89,7 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul
9189
return OAuthClientInformationFull(
9290
client_id=client_id,
9391
client_secret="", # Supabase handles secrets
92+
redirect_uris=[], # Supabase handles redirect URIs
9493
)
9594

9695
return None
@@ -150,7 +149,7 @@ async def handle_supabase_callback(self, code: str, state: str) -> str:
150149
auth_request = self.pending_auth_codes.get(state)
151150
if not auth_request:
152151
raise AuthorizeError(
153-
error=AuthorizationErrorCode.INVALID_REQUEST,
152+
error="invalid_request",
154153
error_description="Invalid state parameter",
155154
)
156155

@@ -170,7 +169,7 @@ async def handle_supabase_callback(self, code: str, state: str) -> str:
170169

171170
if not token_response.is_success:
172171
raise AuthorizeError(
173-
error=AuthorizationErrorCode.SERVER_ERROR,
172+
error="server_error",
174173
error_description="Failed to exchange code with Supabase",
175174
)
176175

@@ -233,18 +232,16 @@ async def exchange_authorization_code(
233232
# Get stored Supabase tokens
234233
token_data = self.mcp_to_supabase_tokens.get(authorization_code.code)
235234
if not token_data:
236-
raise TokenError(
237-
error=TokenErrorCode.INVALID_GRANT, error_description="Invalid authorization code"
238-
)
235+
raise TokenError(error="invalid_grant", error_description="Invalid authorization code")
239236

240237
supabase_tokens = token_data["supabase_tokens"]
241238
user = token_data["user"]
242239

243240
# Generate MCP tokens that wrap Supabase tokens
244241
access_token = self._generate_mcp_token(
245242
client_id=client.client_id,
246-
user_id=user.get("id"),
247-
email=user.get("email"),
243+
user_id=user.get("id", ""),
244+
email=user.get("email", ""),
248245
scopes=authorization_code.scopes,
249246
supabase_access_token=supabase_tokens["access_token"],
250247
)
@@ -275,7 +272,7 @@ async def exchange_authorization_code(
275272

276273
return OAuthToken(
277274
access_token=access_token,
278-
token_type="Bearer",
275+
token_type="bearer",
279276
expires_in=supabase_tokens.get("expires_in", 3600),
280277
refresh_token=refresh_token,
281278
scope=" ".join(authorization_code.scopes) if authorization_code.scopes else None,
@@ -320,7 +317,7 @@ async def exchange_refresh_token(
320317

321318
if not token_response.is_success:
322319
raise TokenError(
323-
error=TokenErrorCode.INVALID_GRANT,
320+
error="invalid_grant",
324321
error_description="Failed to refresh with Supabase",
325322
)
326323

@@ -340,8 +337,8 @@ async def exchange_refresh_token(
340337
# Generate new MCP tokens
341338
new_access_token = self._generate_mcp_token(
342339
client_id=client.client_id,
343-
user_id=user_data.get("id"),
344-
email=user_data.get("email"),
340+
user_id=user_data.get("id", ""),
341+
email=user_data.get("email", ""),
345342
scopes=scopes or refresh_token.scopes,
346343
supabase_access_token=supabase_tokens["access_token"],
347344
)
@@ -370,7 +367,7 @@ async def exchange_refresh_token(
370367

371368
return OAuthToken(
372369
access_token=new_access_token,
373-
token_type="Bearer",
370+
token_type="bearer",
374371
expires_in=supabase_tokens.get("expires_in", 3600),
375372
refresh_token=new_refresh_token,
376373
scope=" ".join(scopes or refresh_token.scopes),

src/basic_memory/services/project_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,14 @@ async def update_project( # pragma: no cover
225225

226226
# Update in database
227227
project.path = resolved_path
228-
await self.repository.update(project)
228+
await self.repository.update(project.id, project)
229229

230230
logger.info(f"Updated path for project '{name}' to {resolved_path}")
231231

232232
# Update active status if provided
233233
if is_active is not None:
234234
project.is_active = is_active
235-
await self.repository.update(project)
235+
await self.repository.update(project.id, project)
236236
logger.info(f"Set active status for project '{name}' to {is_active}")
237237

238238
# If project was made inactive and it was the default, we need to pick a new default

tests/importers/test_importer_base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ async def import_data(self, source_data, destination_folder: str, **kwargs):
2626
except Exception as e:
2727
return self.handle_error("Test import failed", e)
2828

29+
def handle_error(self, message: str, error=None) -> ImportResult:
30+
"""Implement the abstract handle_error method."""
31+
import logging
32+
33+
logger = logging.getLogger(__name__)
34+
35+
error_message = f"{message}"
36+
if error:
37+
error_message += f": {str(error)}"
38+
39+
logger.error(error_message)
40+
return ImportResult(
41+
import_count={},
42+
success=False,
43+
error_message=error_message,
44+
)
45+
2946

3047
@pytest.fixture
3148
def mock_markdown_processor():

0 commit comments

Comments
 (0)