Skip to content

Commit fdb3021

Browse files
committed
style: format updated files with black
1 parent 69c736d commit fdb3021

5 files changed

Lines changed: 21 additions & 27 deletions

File tree

app/cache.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,7 @@ async def get(
302302
self._stats.misses += 1
303303
elapsed = (time.perf_counter() - start_time) * 1000
304304
self._update_miss_time(elapsed)
305-
self._record_cache_operation(
306-
namespace.value, "get", "miss", elapsed / 1000
307-
)
305+
self._record_cache_operation(namespace.value, "get", "miss", elapsed / 1000)
308306
logger.debug("cache_miss", key=key, namespace=namespace.value)
309307
return None
310308

@@ -362,9 +360,7 @@ async def set(
362360
namespace=namespace.value,
363361
ttl=effective_ttl,
364362
)
365-
self._record_cache_operation(
366-
namespace.value, "set", "success", 0.0
367-
)
363+
self._record_cache_operation(namespace.value, "set", "success", 0.0)
368364
return True
369365

370366
# Fall back to in-memory cache

app/text2sql_engine.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,18 @@ def validate(
358358
syntax_valid, syntax_errors = self._check_syntax(sql)
359359
if not syntax_valid:
360360
errors.extend(syntax_errors)
361-
self._record_validation_metrics(database_id, ValidationStatus.INVALID_SYNTAX)
361+
self._record_validation_metrics(
362+
database_id, ValidationStatus.INVALID_SYNTAX
363+
)
362364
return ValidationStatus.INVALID_SYNTAX, errors
363365

364366
# Security check
365367
is_safe, security_errors = self._check_security(sql)
366368
if not is_safe:
367369
errors.extend(security_errors)
368-
self._record_validation_metrics(database_id, ValidationStatus.DANGEROUS_QUERY)
370+
self._record_validation_metrics(
371+
database_id, ValidationStatus.DANGEROUS_QUERY
372+
)
369373
return ValidationStatus.DANGEROUS_QUERY, errors
370374

371375
# Schema alignment check
@@ -392,7 +396,9 @@ def _record_validation_metrics(
392396
valid=status == ValidationStatus.VALID,
393397
)
394398
if status == ValidationStatus.INVALID_SYNTAX:
395-
metrics.record_sql_syntax_error(database_id=database_id, error_type="syntax")
399+
metrics.record_sql_syntax_error(
400+
database_id=database_id, error_type="syntax"
401+
)
396402

397403
def _check_syntax(self, sql: str) -> tuple[bool, list[str]]:
398404
"""Check basic SQL syntax."""
@@ -801,11 +807,7 @@ async def generate_sql(
801807
if show_reasoning:
802808
reasoning_trace[-1].observation = "SQL validation passed"
803809

804-
if (
805-
self._enable_validation
806-
and valid_syntax
807-
and inference_result.sql
808-
):
810+
if self._enable_validation and valid_syntax and inference_result.sql:
809811
semantic_feedback = self._validator.check_semantics(
810812
natural_query=natural_query,
811813
sql=inference_result.sql,
@@ -1066,10 +1068,9 @@ def _build_prompt_cache_key(
10661068
"use_default_examples": use_default_examples,
10671069
}
10681070

1069-
return (
1070-
hashlib.sha256(json.dumps(key_data, sort_keys=True).encode())
1071-
.hexdigest()[:32]
1072-
)
1071+
return hashlib.sha256(
1072+
json.dumps(key_data, sort_keys=True).encode()
1073+
).hexdigest()[:32]
10731074

10741075
async def _generate_with_retry(
10751076
self,

db/executor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,7 @@ async def execute(
269269
self._execute_query(sql, params),
270270
timeout=timeout,
271271
)
272-
self._tracing.set_span_attribute(
273-
"db.rows_returned", result.row_count
274-
)
272+
self._tracing.set_span_attribute("db.rows_returned", result.row_count)
275273

276274
execution_time_ms = (time.perf_counter() - start_time) * 1000
277275
self._metrics.record_sql_query(

db/schema.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,9 @@ def from_dict(cls, data: dict[str, Any]) -> "TableInfo":
112112
"""Create from dictionary."""
113113
return cls(
114114
name=data.get("name", ""),
115-
columns=[
116-
ColumnInfo.from_dict(item) for item in data.get("columns", [])
117-
],
115+
columns=[ColumnInfo.from_dict(item) for item in data.get("columns", [])],
118116
foreign_keys=[
119-
ForeignKeyInfo.from_dict(item)
120-
for item in data.get("foreign_keys", [])
117+
ForeignKeyInfo.from_dict(item) for item in data.get("foreign_keys", [])
121118
],
122119
primary_keys=list(data.get("primary_keys", [])),
123120
row_count=data.get("row_count"),

tests/unit/test_text2sql_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,9 @@ def test_semantic_top_limit_warning(self) -> None:
392392
)
393393

394394
assert any("limit" in warning.lower() for warning in feedback.warnings)
395-
assert any("limit 5" in suggestion.lower() for suggestion in feedback.suggestions)
395+
assert any(
396+
"limit 5" in suggestion.lower() for suggestion in feedback.suggestions
397+
)
396398

397399

398400
# =============================================================================

0 commit comments

Comments
 (0)