Skip to content

Commit ffc5b86

Browse files
committed
add collaborative to unified search
1 parent 2394980 commit ffc5b86

1 file changed

Lines changed: 67 additions & 23 deletions

File tree

api/views/search_unified.py

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414

1515
from api.models import Dataset, Geography, Metadata, UseCase
1616
from api.models.AIModel import AIModel
17+
from api.models.Collaborative import Collaborative
1718
from api.utils.telemetry_utils import trace_method
1819
from DataSpace import settings
19-
from search.documents import AIModelDocument, DatasetDocument, UseCaseDocument
20+
from search.documents import (
21+
AIModelDocument,
22+
CollaborativeDocument,
23+
DatasetDocument,
24+
UseCaseDocument,
25+
)
2026

2127
logger = structlog.get_logger(__name__)
2228

@@ -25,7 +31,7 @@ class UnifiedSearchResultSerializer(serializers.Serializer):
2531
"""Serializer for unified search results."""
2632

2733
id = serializers.CharField()
28-
type = serializers.CharField() # 'dataset', 'usecase', or 'aimodel'
34+
type = serializers.CharField() # 'dataset', 'usecase', 'aimodel', or 'collaborative'
2935
title = serializers.CharField()
3036
description = serializers.CharField()
3137
slug = serializers.CharField(required=False)
@@ -102,6 +108,12 @@ def _get_index_names(self, types_list: List[str]) -> List[str]:
102108
)
103109
index_names.append(aimodel_index)
104110

111+
if "collaborative" in types_list:
112+
collaborative_index = settings.ELASTICSEARCH_INDEX_NAMES.get(
113+
"search.documents.collaborative_document", "collaborative"
114+
)
115+
index_names.append(collaborative_index)
116+
105117
return index_names
106118

107119
def _build_unified_query(self, query: str) -> ESQ:
@@ -172,6 +184,43 @@ def _build_unified_query(self, query: str) -> ESQ:
172184
]
173185
)
174186

187+
# Collaborative nested fields
188+
common_queries.extend(
189+
[
190+
ESQ(
191+
"nested",
192+
path="datasets",
193+
query=ESQ(
194+
"multi_match",
195+
query=query,
196+
fields=["datasets.title", "datasets.description"],
197+
fuzziness="AUTO",
198+
),
199+
ignore_unmapped=True,
200+
),
201+
ESQ(
202+
"nested",
203+
path="use_cases",
204+
query=ESQ(
205+
"multi_match",
206+
query=query,
207+
fields=["use_cases.title", "use_cases.summary"],
208+
fuzziness="AUTO",
209+
),
210+
ignore_unmapped=True,
211+
),
212+
ESQ(
213+
"nested",
214+
path="contributors",
215+
query=ESQ(
216+
"match",
217+
**{"contributors.name": {"query": query, "fuzziness": "AUTO"}},
218+
),
219+
ignore_unmapped=True,
220+
),
221+
]
222+
)
223+
175224
# Organization and user (common across types)
176225
common_queries.extend(
177226
[
@@ -187,9 +236,7 @@ def _build_unified_query(self, query: str) -> ESQ:
187236
ESQ(
188237
"nested",
189238
path="user",
190-
query=ESQ(
191-
"match", **{"user.name": {"query": query, "fuzziness": "AUTO"}}
192-
),
239+
query=ESQ("match", **{"user.name": {"query": query, "fuzziness": "AUTO"}}),
193240
ignore_unmapped=True,
194241
),
195242
]
@@ -209,9 +256,7 @@ def _apply_filters(self, search: Search, filters: Dict[str, str]) -> Search:
209256

210257
if "geographies" in filters:
211258
filter_values = filters["geographies"].split(",")
212-
filter_values = Geography.get_geography_names_with_descendants(
213-
filter_values
214-
)
259+
filter_values = Geography.get_geography_names_with_descendants(filter_values)
215260
search = search.filter("terms", **{"geographies.raw": filter_values})
216261

217262
if "status" in filters:
@@ -233,6 +278,8 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]:
233278
result["type"] = "usecase"
234279
elif "aimodel" in index_name:
235280
result["type"] = "aimodel"
281+
elif "collaborative" in index_name:
282+
result["type"] = "collaborative"
236283
else:
237284
result["type"] = "unknown"
238285

@@ -256,6 +303,11 @@ def _normalize_result(self, hit: Any) -> Dict[str, Any]:
256303
result["created"] = result["created_at"]
257304
if "updated_at" in result:
258305
result["modified"] = result["updated_at"]
306+
elif result["type"] == "collaborative":
307+
if "summary" in result:
308+
result["description"] = result.get("summary", "")
309+
if "title" not in result:
310+
result["title"] = ""
259311
else: # dataset
260312
if "title" not in result:
261313
result["title"] = ""
@@ -323,6 +375,8 @@ def perform_unified_search(
323375
aggregations["types"]["usecase"] = bucket["doc_count"]
324376
elif "aimodel" in index_name:
325377
aggregations["types"]["aimodel"] = bucket["doc_count"]
378+
elif "collaborative" in index_name:
379+
aggregations["types"]["collaborative"] = bucket["doc_count"]
326380

327381
# Process other aggregations
328382
for agg_name in ["tags", "sectors", "geographies", "status"]:
@@ -331,11 +385,7 @@ def perform_unified_search(
331385
for bucket in aggs_dict[agg_name]["buckets"]:
332386
aggregations[agg_name][bucket["key"]] = bucket["doc_count"]
333387

334-
total = (
335-
response.hits.total.value
336-
if hasattr(response.hits.total, "value")
337-
else len(results)
338-
)
388+
total = response.hits.total.value if hasattr(response.hits.total, "value") else len(results)
339389

340390
return results, total, aggregations
341391

@@ -347,7 +397,7 @@ def get(self, request: Any) -> Response:
347397
page: int = int(request.GET.get("page", 1))
348398
size: int = int(request.GET.get("size", 10))
349399
entity_types: str = request.GET.get(
350-
"types", "dataset,usecase,aimodel"
400+
"types", "dataset,usecase,aimodel,collaborative"
351401
) # Which entity types to search
352402

353403
# Parse entity types
@@ -383,9 +433,7 @@ def get(self, request: Any) -> Response:
383433
self.logger.error("unified_search_error", error=str(e), exc_info=True)
384434
return Response({"error": "An internal error has occurred."}, status=500)
385435

386-
def _build_aggregations(
387-
self, results: List[Dict[str, Any]]
388-
) -> Dict[str, Dict[str, int]]:
436+
def _build_aggregations(self, results: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
389437
"""Build aggregations from results."""
390438
aggregations: Dict[str, Dict[str, int]] = {
391439
"types": {},
@@ -398,19 +446,15 @@ def _build_aggregations(
398446
for result in results:
399447
# Count by type
400448
result_type = result.get("type", "unknown")
401-
aggregations["types"][result_type] = (
402-
aggregations["types"].get(result_type, 0) + 1
403-
)
449+
aggregations["types"][result_type] = aggregations["types"].get(result_type, 0) + 1
404450

405451
# Count by tags
406452
for tag in result.get("tags", []):
407453
aggregations["tags"][tag] = aggregations["tags"].get(tag, 0) + 1
408454

409455
# Count by sectors
410456
for sector in result.get("sectors", []):
411-
aggregations["sectors"][sector] = (
412-
aggregations["sectors"].get(sector, 0) + 1
413-
)
457+
aggregations["sectors"][sector] = aggregations["sectors"].get(sector, 0) + 1
414458

415459
# Count by geographies
416460
for geography in result.get("geographies", []):

0 commit comments

Comments
 (0)