Skip to content

Commit 7b1b3db

Browse files
Merge branch 'main' into refactor/select-core-rag
2 parents a1211eb + 32d394d commit 7b1b3db

15 files changed

Lines changed: 77 additions & 115 deletions

File tree

api/core/ops/aliyun_trace/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
def get_user_id_from_message_data(message_data) -> str:
2828
user_id = message_data.from_account_id
2929
if message_data.from_end_user_id:
30-
end_user_data: EndUser | None = (
31-
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
32-
)
30+
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
3331
if end_user_data is not None:
3432
user_id = end_user_data.session_id
3533
return user_id

api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
410410

411411
# Add end user data if available
412412
if trace_info.message_data.from_end_user_id:
413-
end_user_data: EndUser | None = (
414-
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
415-
)
413+
end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id)
416414
if end_user_data is not None:
417415
metadata["end_user_id"] = end_user_data.session_id
418416

api/core/ops/langfuse_trace/langfuse_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,7 @@ def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
241241

242242
user_id = message_data.from_account_id
243243
if message_data.from_end_user_id:
244-
end_user_data: EndUser | None = (
245-
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
246-
)
244+
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
247245
if end_user_data is not None:
248246
user_id = end_user_data.session_id
249247
metadata["user_id"] = user_id

api/core/ops/langsmith_trace/langsmith_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
259259
metadata["user_id"] = user_id
260260

261261
if message_data.from_end_user_id:
262-
end_user_data: EndUser | None = (
263-
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
264-
)
262+
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
265263
if end_user_data is not None:
266264
end_user_id = end_user_data.session_id
267265
metadata["end_user_id"] = end_user_id

api/core/ops/mlflow_trace/mlflow_trace.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
1010
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
1111
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
12+
from sqlalchemy import select
1213

1314
from core.ops.base_trace_instance import BaseTraceInstance
1415
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
@@ -320,7 +321,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
320321

321322
def _get_message_user_id(self, metadata: dict) -> str | None:
322323
if (end_user_id := metadata.get("from_end_user_id")) and (
323-
end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first()
324+
end_user_data := db.session.get(EndUser, end_user_id)
324325
):
325326
return end_user_data.session_id
326327

@@ -447,25 +448,11 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
447448

448449
def _get_workflow_nodes(self, workflow_run_id: str):
449450
"""Helper method to get workflow nodes"""
450-
workflow_nodes = (
451-
db.session.query(
452-
WorkflowNodeExecutionModel.id,
453-
WorkflowNodeExecutionModel.tenant_id,
454-
WorkflowNodeExecutionModel.app_id,
455-
WorkflowNodeExecutionModel.title,
456-
WorkflowNodeExecutionModel.node_type,
457-
WorkflowNodeExecutionModel.status,
458-
WorkflowNodeExecutionModel.inputs,
459-
WorkflowNodeExecutionModel.outputs,
460-
WorkflowNodeExecutionModel.created_at,
461-
WorkflowNodeExecutionModel.elapsed_time,
462-
WorkflowNodeExecutionModel.process_data,
463-
WorkflowNodeExecutionModel.execution_metadata,
464-
)
465-
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
451+
workflow_nodes = db.session.scalars(
452+
select(WorkflowNodeExecutionModel)
453+
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
466454
.order_by(WorkflowNodeExecutionModel.created_at)
467-
.all()
468-
)
455+
).all()
469456
return workflow_nodes
470457

471458
def _get_node_span_type(self, node_type: str) -> str:

api/core/ops/opik_trace/opik_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
288288
metadata["file_list"] = file_list
289289

290290
if message_data.from_end_user_id:
291-
end_user_data: EndUser | None = (
292-
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
293-
)
291+
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
294292
if end_user_data is not None:
295293
end_user_id = end_user_data.session_id
296294
metadata["end_user_id"] = end_user_id

api/core/ops/ops_trace_manager.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,10 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
420420
:param tracing_provider: tracing provider
421421
:return:
422422
"""
423-
trace_config_data: TraceAppConfig | None = (
424-
db.session.query(TraceAppConfig)
423+
trace_config_data: TraceAppConfig | None = db.session.scalar(
424+
select(TraceAppConfig)
425425
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
426-
.first()
426+
.limit(1)
427427
)
428428

429429
if not trace_config_data:
@@ -463,7 +463,7 @@ def get_ops_trace_instance(
463463
if isinstance(app_id, str) and app_id.startswith("tenant-"):
464464
return None
465465

466-
app: App | None = db.session.query(App).where(App.id == app_id).first()
466+
app = db.session.get(App, app_id)
467467

468468
if app is None:
469469
return None
@@ -537,7 +537,7 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider:
537537
except KeyError:
538538
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
539539

540-
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
540+
app_config: App | None = db.session.get(App, app_id)
541541
if not app_config:
542542
raise ValueError("App not found")
543543
app_config.tracing = json.dumps(
@@ -555,7 +555,7 @@ def get_app_tracing_config(cls, app_id: str):
555555
:param app_id: app id
556556
:return:
557557
"""
558-
app: App | None = db.session.query(App).where(App.id == app_id).first()
558+
app: App | None = db.session.get(App, app_id)
559559
if not app:
560560
raise ValueError("App not found")
561561
if not app.tracing:
@@ -883,7 +883,7 @@ def message_trace(self, message_id: str | None, **kwargs):
883883
inputs = message_data.message
884884

885885
# get message file data
886-
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
886+
message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1))
887887
file_list = []
888888
if message_file_data and message_file_data.url is not None:
889889
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
@@ -972,8 +972,8 @@ def moderation_trace(self, message_id, timer, **kwargs):
972972
# get workflow_app_log_id
973973
workflow_app_log_id = None
974974
if message_data.workflow_run_id:
975-
workflow_app_log_data = (
976-
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
975+
workflow_app_log_data = db.session.scalar(
976+
select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1)
977977
)
978978
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
979979

@@ -1015,8 +1015,8 @@ def suggested_question_trace(self, message_id, timer, **kwargs):
10151015
# get workflow_app_log_id
10161016
workflow_app_log_id = None
10171017
if message_data.workflow_run_id:
1018-
workflow_app_log_data = (
1019-
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
1018+
workflow_app_log_data = db.session.scalar(
1019+
select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1)
10201020
)
10211021
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
10221022

@@ -1171,7 +1171,7 @@ def tool_trace(self, message_id, timer, **kwargs):
11711171
metadata["node_execution_id"] = node_execution_id
11721172

11731173
file_url = ""
1174-
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
1174+
message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1))
11751175
if message_file_data:
11761176
message_file_id = message_file_data.id if message_file_data else None
11771177
type = message_file_data.type

api/core/ops/weave_trace/weave_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
245245
attributes["user_id"] = user_id
246246

247247
if message_data.from_end_user_id:
248-
end_user_data: EndUser | None = (
249-
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
250-
)
248+
end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id)
251249
if end_user_data is not None:
252250
end_user_id = end_user_data.session_id
253251
attributes["end_user_id"] = end_user_id

api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
4545
end_user_data = MagicMock(spec=EndUser)
4646
end_user_data.session_id = "session_id"
4747

48-
mock_query = MagicMock()
49-
mock_query.where.return_value.first.return_value = end_user_data
50-
5148
mock_session = MagicMock()
52-
mock_session.query.return_value = mock_query
49+
mock_session.get.return_value = end_user_data
5350

5451
from core.ops.aliyun_trace.utils import db
5552

@@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
6360
message_data.from_account_id = "account_id"
6461
message_data.from_end_user_id = "end_user_id"
6562

66-
mock_query = MagicMock()
67-
mock_query.where.return_value.first.return_value = None
68-
6963
mock_session = MagicMock()
70-
mock_session.query.return_value = mock_query
64+
mock_session.get.return_value = None
7165

7266
from core.ops.aliyun_trace.utils import db
7367

api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
365365
mock_end_user = MagicMock(spec=EndUser)
366366
mock_end_user.session_id = "session-id-123"
367367

368-
mock_query = MagicMock()
369-
mock_query.where.return_value.first.return_value = mock_end_user
370-
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query)
368+
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
371369

372370
trace_instance.add_trace = MagicMock()
373371
trace_instance.add_generation = MagicMock()

0 commit comments

Comments
 (0)