Skip to content

Commit 1aefe82

Browse files
committed
add dataset count to collaborative document
1 parent 585da03 commit 1aefe82

6 files changed

Lines changed: 725 additions & 8 deletions

File tree

DataSpace/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@
278278
"search.documents.usecase_document": "usecase",
279279
"search.documents.aimodel_document": "aimodel",
280280
"search.documents.collaborative_document": "collaborative",
281+
"search.documents.publisher_document.OrganizationPublisherDocument": "organization_publisher",
282+
"search.documents.publisher_document.UserPublisherDocument": "user_publisher",
283+
"search.documents.publisher_document": "publisher",
281284
}
282285

283286

api/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
search_aimodel,
1717
search_collaborative,
1818
search_dataset,
19+
search_publisher,
1920
search_unified,
2021
search_usecase,
2122
trending_datasets,
@@ -55,6 +56,7 @@
5556
name="search_collaborative",
5657
),
5758
path("search/unified/", search_unified.UnifiedSearch.as_view(), name="search_unified"),
59+
path("search/publisher/", search_publisher.SearchPublisher.as_view(), name="search_publisher"),
5860
path(
5961
"aimodels/<model_id>/",
6062
aimodel_detail.AIModelDetailView.as_view(),

api/views/search_publisher.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
from typing import Any, Dict, List, Tuple
2+
3+
import structlog
4+
from elasticsearch_dsl import Q as ESQ
5+
from elasticsearch_dsl import Search
6+
from rest_framework import serializers
7+
from rest_framework.permissions import AllowAny
8+
from rest_framework.response import Response
9+
10+
from api.utils.telemetry_utils import trace_method, track_metrics
11+
from api.views.paginated_elastic_view import PaginatedElasticSearchAPIView
12+
from search.documents import OrganizationPublisherDocument, UserPublisherDocument
13+
14+
logger = structlog.get_logger(__name__)
15+
16+
17+
class PublisherDocumentSerializer(serializers.Serializer):
    """Serializer for Publisher document (both Organization and User).

    Read-only serializer applied to raw Elasticsearch hit dicts returned by
    ``SearchPublisher.perform_search``. A single shape is used for both
    publisher kinds; fields that exist on only one kind are declared with
    ``required=False`` so hits from the other index still validate.
    """

    id = serializers.CharField()
    name = serializers.CharField()
    description = serializers.CharField()
    publisher_type = serializers.CharField()  # 'organization' or 'user'
    logo = serializers.CharField(required=False)
    slug = serializers.CharField(required=False)
    created = serializers.DateTimeField(required=False)
    modified = serializers.DateTimeField(required=False)

    # Counts
    published_datasets_count = serializers.IntegerField()
    published_usecases_count = serializers.IntegerField()
    members_count = serializers.IntegerField(required=False)  # Only for organizations
    contributed_sectors_count = serializers.IntegerField()

    # Organization specific fields
    homepage = serializers.CharField(required=False)
    contact_email = serializers.CharField(required=False)
    organization_types = serializers.CharField(required=False)
    github_profile = serializers.CharField(required=False)
    linkedin_profile = serializers.CharField(required=False)
    twitter_profile = serializers.CharField(required=False)
    location = serializers.CharField(required=False)

    # User specific fields
    bio = serializers.CharField(required=False)
    profile_picture = serializers.CharField(required=False)
    username = serializers.CharField(required=False)
    email = serializers.CharField(required=False)
    first_name = serializers.CharField(required=False)
    last_name = serializers.CharField(required=False)
    full_name = serializers.CharField(required=False)

    # Search fields
    sectors = serializers.ListField(required=False)
56+
57+
class SearchPublisher(PaginatedElasticSearchAPIView):
    """API view for searching publishers (organizations and users).

    Queries the organization-publisher and user-publisher Elasticsearch
    indices in a single multi-index search, supporting free-text querying,
    term filters, faceted aggregations, several sort modes, and pagination.
    Anonymous access is allowed.
    """

    serializer_class = PublisherDocumentSerializer
    permission_classes = [AllowAny]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.logger = structlog.get_logger(__name__)

    def get_document_classes(self) -> List[Any]:
        """Return the document classes to search."""
        return [OrganizationPublisherDocument, UserPublisherDocument]

    def get_index_names(self) -> List[str]:
        """Get the index names for publisher search.

        Resolved from ``settings.ELASTICSEARCH_INDEX_NAMES`` with sensible
        fallbacks so the view still works if the settings entries are absent.
        """
        # Imported locally to avoid a settings import at module load time.
        from DataSpace import settings

        org_index = settings.ELASTICSEARCH_INDEX_NAMES.get(
            "search.documents.publisher_document.OrganizationPublisherDocument",
            "organization_publisher",
        )
        user_index = settings.ELASTICSEARCH_INDEX_NAMES.get(
            "search.documents.publisher_document.UserPublisherDocument", "user_publisher"
        )
        return [org_index, user_index]

    @trace_method(name="build_query", attributes={"component": "publisher_search"})
    def build_query(self, query: str) -> ESQ:
        """Build the Elasticsearch query for publisher search.

        An empty query matches everything; otherwise a boosted multi-field
        ``bool``/``should`` query is built (name > description/bio/sectors >
        remaining identity fields), each clause with ``fuzziness="AUTO"``.
        """
        if not query:
            return ESQ("match_all")

        # Multi-field search with boosting
        queries = [
            ESQ(
                "multi_match",
                query=query,
                fields=["name^3", "full_name^3"],  # Boost name fields
                fuzziness="AUTO",
            ),
            ESQ(
                "multi_match",
                query=query,
                fields=["description^2", "bio^2"],  # Boost description/bio
                fuzziness="AUTO",
            ),
            ESQ(
                "multi_match",
                query=query,
                fields=["sectors^2"],  # Boost sectors
                fuzziness="AUTO",
            ),
            ESQ(
                "multi_match",
                query=query,
                fields=[
                    "username",
                    "email",
                    "location",
                    "organization_types",
                    "first_name",
                    "last_name",
                ],
                fuzziness="AUTO",
            ),
        ]

        # At least one clause must match; matching more clauses raises score.
        return ESQ("bool", should=queries, minimum_should_match=1)

    @trace_method(name="apply_filters", attributes={"component": "publisher_search"})
    def apply_filters(self, search: Search, filters: Dict[str, str]) -> Search:
        """Apply filters to the search query.

        Supported keys: ``publisher_type`` (exact term), ``sectors``
        (comma-separated list matched against ``sectors.raw``),
        ``organization_types`` (exact term), ``location`` (fuzzy match).
        Unknown keys are ignored.
        """

        if "publisher_type" in filters:
            # Filter by publisher type (organization or user)
            search = search.filter("term", publisher_type=filters["publisher_type"])

        if "sectors" in filters:
            # Filter by sectors; multiple values arrive comma-joined from get()
            filter_values = filters["sectors"].split(",")
            search = search.filter("terms", **{"sectors.raw": filter_values})

        if "organization_types" in filters:
            # Filter by organization types
            search = search.filter("term", organization_types=filters["organization_types"])

        if "location" in filters:
            # Filter by location (fuzzy match)
            search = search.filter("match", location=filters["location"])

        return search

    @trace_method(name="build_aggregations", attributes={"component": "publisher_search"})
    def build_aggregations(self, search: Search) -> Search:
        """Build aggregations for faceted search.

        Note: ``search.aggs.bucket`` mutates the Search object in place; the
        object is returned for call-chaining consistency.
        """

        # Publisher type aggregation
        search.aggs.bucket("publisher_type", "terms", field="publisher_type")

        # Sectors aggregation
        search.aggs.bucket("sectors", "terms", field="sectors.raw", size=50)

        # Organization types aggregation
        search.aggs.bucket("organization_types", "terms", field="organization_types", size=20)

        # Location aggregation (top 20 locations)
        search.aggs.bucket("locations", "terms", field="location.raw", size=20)

        return search

    @trace_method(name="apply_sorting", attributes={"component": "publisher_search"})
    def apply_sorting(self, search: Search, sort_by: str) -> Search:
        """Apply sorting to the search query.

        Recognized values: ``alphabetical``, ``datasets_count``,
        ``usecases_count``, ``total_contributions`` (script-computed sum),
        ``members_count`` (organizations only), ``recent``. Anything else
        falls through to relevance-score ordering.
        """

        if sort_by == "alphabetical":
            search = search.sort("name.raw")
        elif sort_by == "datasets_count":
            search = search.sort({"published_datasets_count": {"order": "desc"}})
        elif sort_by == "usecases_count":
            search = search.sort({"published_usecases_count": {"order": "desc"}})
        elif sort_by == "total_contributions":
            # Sort by total datasets + usecases, computed per-document via script
            search = search.sort(
                {
                    "_script": {
                        "type": "number",
                        "script": {
                            "source": "doc['published_datasets_count'].value + doc['published_usecases_count'].value"
                        },
                        "order": "desc",
                    }
                }
            )
        elif sort_by == "members_count":
            # Only applicable to organizations
            search = search.sort({"members_count": {"order": "desc"}})
        elif sort_by == "recent":
            search = search.sort({"created": {"order": "desc"}})
        else:
            # Default: relevance score
            pass

        return search

    @trace_method(name="perform_search", attributes={"component": "publisher_search"})
    def perform_search(
        self,
        query: str,
        filters: Dict[str, str],
        page: int,
        size: int,
        sort_by: str = "relevance",
    ) -> Tuple[List[Dict[str, Any]], int, Dict[str, Any]]:
        """Perform the publisher search.

        Returns ``(results, total, aggregations)``. Each result is the raw
        hit dict augmented with ``_score`` and ``_index``; aggregations are
        flattened to ``{facet: {bucket_key: doc_count}}``. On Elasticsearch
        failure the error is logged and empty results are returned rather
        than propagating the exception.
        """

        # Get index names
        index_names = self.get_index_names()

        if not index_names:
            return [], 0, {}

        # Create multi-index search
        search = Search(index=index_names)

        # Build and apply query
        q = self.build_query(query)
        search = search.query(q)

        # Apply filters
        search = self.apply_filters(search, filters)

        # Apply sorting
        search = self.apply_sorting(search, sort_by)

        # Build aggregations
        search = self.build_aggregations(search)

        # Pagination (page/size are validated by the caller to be >= 1)
        start = (page - 1) * size
        search = search[start : start + size]

        # Execute search; treat backend failures as empty results (best-effort)
        try:
            response = search.execute()
        except Exception as e:
            self.logger.error("publisher_search_error", error=str(e), exc_info=True)
            return [], 0, {}

        # Process results
        results = []
        for hit in response:
            result = hit.to_dict()
            result["_score"] = hit.meta.score
            result["_index"] = hit.meta.index
            results.append(result)

        # Process aggregations
        aggregations: Dict[str, Any] = {}
        if hasattr(response, "aggregations"):
            aggs_dict = response.aggregations.to_dict()

            for agg_name in ["publisher_type", "sectors", "organization_types", "locations"]:
                if agg_name in aggs_dict:
                    aggregations[agg_name] = {}
                    for bucket in aggs_dict[agg_name]["buckets"]:
                        aggregations[agg_name][bucket["key"]] = bucket["doc_count"]

        total = response.hits.total.value if hasattr(response.hits.total, "value") else len(results)

        return results, total, aggregations

    @trace_method(name="get", attributes={"component": "publisher_search"})
    @track_metrics(name="publisher_search")
    def get(self, request: Any) -> Response:
        """Handle GET request and return search results.

        Query params: ``query``, ``page``, ``size``, ``sort``; all other
        params are treated as filters (multi-valued params are comma-joined).
        Returns 400 for non-integer ``page``/``size``, 500 on unexpected
        errors (details logged, not exposed to the client).
        """
        try:
            query: str = request.GET.get("query", "")
            # Validate pagination params explicitly: previously a non-numeric
            # value raised ValueError into the broad handler below, turning a
            # client error into a 500 response.
            try:
                page: int = int(request.GET.get("page", 1))
                size: int = int(request.GET.get("size", 10))
            except (TypeError, ValueError):
                return Response(
                    {"error": "'page' and 'size' must be integers."}, status=400
                )
            # Clamp to sane minimums: page <= 0 would make the slice offset
            # negative and produce an invalid Elasticsearch request.
            page = max(page, 1)
            size = max(size, 1)

            sort_by: str = request.GET.get("sort", "relevance")

            # Handle filters: everything that is not a reserved param is a filter;
            # repeated params are comma-joined (apply_filters splits them back).
            filters: Dict[str, str] = {}
            for key, values in request.GET.lists():
                if key not in ["query", "page", "size", "sort"]:
                    if len(values) > 1:
                        filters[key] = ",".join(values)
                    else:
                        filters[key] = values[0]

            # Perform search
            results, total, aggregations = self.perform_search(query, filters, page, size, sort_by)

            # Serialize results
            serializer = self.serializer_class(results, many=True)

            return Response(
                {
                    "results": serializer.data,
                    "total": total,
                    "page": page,
                    "size": size,
                    "aggregations": aggregations,
                }
            )

        except Exception as e:
            self.logger.error("publisher_search_error", error=str(e), exc_info=True)
            return Response({"error": "An internal error has occurred."}, status=500)

0 commit comments

Comments
 (0)