Skip to content

Commit 3ea4238

Browse files
committed
add domain to aimodel
1 parent e7f45a1 commit 3ea4238

8 files changed

Lines changed: 34 additions & 4 deletions

File tree

api/admin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class AIModelAdmin(admin.ModelAdmin):
5959
list_filter = (
6060
"provider",
6161
"model_type",
62+
"domain",
6263
"status",
6364
"is_public",
6465
"is_active",
@@ -85,7 +86,7 @@ class AIModelAdmin(admin.ModelAdmin):
8586
"Schema",
8687
{"fields": ("input_schema", "output_schema"), "classes": ("collapse",)},
8788
),
88-
("Metadata", {"fields": ("tags", "metadata"), "classes": ("collapse",)}),
89+
("Metadata", {"fields": ("tags", "domain", "metadata"), "classes": ("collapse",)}),
8990
("Status & Visibility", {"fields": ("status", "is_public", "is_active")}),
9091
(
9192
"Performance Metrics",

api/models/AIModel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
EndpointAuthType,
1313
EndpointHTTPMethod,
1414
HFModelClass,
15+
PromptDomain,
1516
)
1617

1718
User = get_user_model()
@@ -107,6 +108,13 @@ class AIModel(models.Model):
107108
tags = models.ManyToManyField("api.Tag", blank=True)
108109
sectors = models.ManyToManyField("api.Sector", blank=True, related_name="ai_models")
109110
geographies = models.ManyToManyField("api.Geography", blank=True, related_name="ai_models")
111+
domain = models.CharField(
112+
max_length=200,
113+
choices=PromptDomain.choices,
114+
blank=True,
115+
null=True,
116+
help_text="Domain or category (e.g., healthcare, education, legal)",
117+
)
110118
metadata = models.JSONField(
111119
default=dict,
112120
blank=True,

api/schema/aimodel_schema.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""GraphQL schema for AI Model."""
22

3-
# mypy: disable-error-code=union-attr
3+
# mypy: disable-error-code="union-attr,misc"
44

55
import datetime
66
from typing import List, Optional
@@ -27,6 +27,7 @@
2727
AIModelVersionOrder,
2828
EndpointAuthTypeEnum,
2929
EndpointHTTPMethodEnum,
30+
PromptDomainEnum,
3031
TypeAIModel,
3132
TypeAIModelVersion,
3233
TypeModelEndpoint,
@@ -95,6 +96,7 @@ class CreateAIModelInput:
9596
tags: Optional[List[str]] = None
9697
sectors: Optional[List[str]] = None
9798
geographies: Optional[List[int]] = None
99+
domain: Optional[PromptDomainEnum] = None
98100
metadata: Optional[strawberry.scalars.JSON] = None
99101
is_public: bool = False
100102

@@ -119,6 +121,7 @@ class UpdateAIModelInput:
119121
tags: Optional[List[str]] = None
120122
sectors: Optional[List[str]] = None
121123
geographies: Optional[List[int]] = None
124+
domain: Optional[PromptDomainEnum] = None
122125
metadata: Optional[strawberry.scalars.JSON] = None
123126
is_public: Optional[bool] = None
124127
is_active: Optional[bool] = None
@@ -441,6 +444,7 @@ def create_ai_model(
441444
supported_languages=supported_languages,
442445
input_schema=input_schema,
443446
output_schema=output_schema,
447+
domain=input.domain if input.domain else None,
444448
metadata=metadata,
445449
is_public=input.is_public,
446450
status="REGISTERED",
@@ -518,6 +522,8 @@ def update_ai_model(
518522
model.input_schema = input.input_schema
519523
if input.output_schema is not None:
520524
model.output_schema = input.output_schema
525+
if input.domain is not None:
526+
model.domain = input.domain
521527
if input.metadata is not None:
522528
model.metadata = input.metadata
523529
if input.is_public is not None:

api/types/type_aimodel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
EndpointAuthType,
2727
EndpointHTTPMethod,
2828
HFModelClass,
29+
PromptDomain,
2930
)
3031
from authorization.types import TypeUser
3132

@@ -41,6 +42,7 @@
4142
AIModelFrameworkEnum = strawberry.enum(AIModelFramework) # type: ignore
4243
HFModelClassEnum = strawberry.enum(HFModelClass) # type: ignore
4344
AIModelLifecycleStageEnum = strawberry.enum(AIModelLifecycleStage) # type: ignore
45+
PromptDomainEnum = strawberry.enum(PromptDomain) # type: ignore
4446

4547

4648
@strawberry.type
@@ -83,6 +85,7 @@ class AIModelFilter:
8385
status: Optional[AIModelStatusEnum]
8486
model_type: Optional[AIModelTypeEnum]
8587
provider: Optional[AIModelProviderEnum]
88+
domain: Optional[PromptDomainEnum]
8689
is_public: Optional[bool]
8790
is_active: Optional[bool]
8891

@@ -123,6 +126,7 @@ class TypeAIModel(BaseType):
123126
supported_languages: Optional[strawberry.scalars.JSON]
124127
input_schema: Optional[strawberry.scalars.JSON]
125128
output_schema: Optional[strawberry.scalars.JSON]
129+
domain: Optional[PromptDomainEnum]
126130
metadata: Optional[strawberry.scalars.JSON]
127131
status: AIModelStatusEnum
128132
is_public: bool

api/views/aimodel_detail.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""API view for AI Model detail."""
22

3+
import logging
34
from typing import Any, Dict, List, Optional
45

5-
import logging
66
from rest_framework import serializers, status
77
from rest_framework.permissions import AllowAny
88
from rest_framework.request import Request
@@ -11,9 +11,9 @@
1111

1212
from api.models.AIModel import AIModel, ModelEndpoint
1313

14-
1514
logger = logging.getLogger(__name__)
1615

16+
1717
class ModelEndpointSerializer(serializers.ModelSerializer):
1818
"""Serializer for Model Endpoint."""
1919

@@ -59,6 +59,7 @@ class Meta:
5959
"tags",
6060
"sectors",
6161
"geographies",
62+
"domain",
6263
"metadata",
6364
"status",
6465
"is_public",

api/views/search_aimodel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class Meta:
103103
"is_individual_model",
104104
"has_active_endpoints",
105105
"endpoint_count",
106+
"domain",
106107
"version_count",
107108
"lifecycle_stage",
108109
"all_providers",
@@ -154,6 +155,7 @@ def get_searchable_and_aggregations(self) -> Tuple[List[str], Dict[str, str]]:
154155
"supports_streaming": "terms",
155156
"lifecycle_stage": "terms",
156157
"all_providers": "terms",
158+
"domain": "terms",
157159
}
158160

159161
return searchable_fields, aggregations
@@ -208,6 +210,7 @@ def add_filters(self, filters: Dict[str, str], search: Search) -> Search:
208210
"supported_languages",
209211
"lifecycle_stage",
210212
"all_providers",
213+
"domain",
211214
]:
212215
# Handle single or multi-value filters
213216
filter_values = filters[filter_key].split(",")

dataspace_sdk/resources/aimodels.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def search(
1717
status: Optional[str] = None,
1818
model_type: Optional[str] = None,
1919
provider: Optional[str] = None,
20+
domain: Optional[str] = None,
2021
sort: Optional[str] = None,
2122
page: int = 1,
2223
page_size: int = 10,
@@ -32,6 +33,7 @@ def search(
3233
status: Filter by status (ACTIVE, INACTIVE, etc.)
3334
model_type: Filter by model type (LLM, VISION, etc.)
3435
provider: Filter by provider (OPENAI, ANTHROPIC, etc.)
36+
domain: Filter by domain (HEALTHCARE, EDUCATION, etc.)
3537
sort: Sort order (recent, alphabetical)
3638
page: Page number (1-indexed)
3739
page_size: Number of results per page
@@ -58,6 +60,8 @@ def search(
5860
params["model_type"] = model_type
5961
if provider:
6062
params["provider"] = provider
63+
if domain:
64+
params["domain"] = domain
6165
if sort:
6266
params["sort"] = sort
6367

@@ -94,6 +98,7 @@ def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]:
9498
displayName
9599
description
96100
modelType
101+
domain
97102
status
98103
isPublic
99104
createdAt
@@ -207,6 +212,7 @@ def list_all(
207212
displayName
208213
description
209214
modelType
215+
domain
210216
status
211217
isPublic
212218
createdAt

search/documents/aimodel_document.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class AIModelDocument(Document):
5454

5555
# Model configuration
5656
model_type = fields.KeywordField()
57+
domain = fields.KeywordField()
5758

5859
# Status and visibility
5960
status = fields.KeywordField()

0 commit comments

Comments
 (0)