Skip to content

Commit 9bd86a3

Browse files
committed
add support to call model with a specific version
1 parent 3ee95ea commit 9bd86a3

2 files changed

Lines changed: 42 additions & 8 deletions

File tree

api/views/aimodel_execution.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,21 @@ def call_aimodel(request: Request, model_id: str) -> Response:
6565
)
6666

6767
parameters = request.data.get("parameters", {})
68+
version_id = request.data.get("version_id")
6869

69-
# Get the primary version and provider
70-
primary_version = model.versions.filter(is_latest=True).first()
71-
if not primary_version:
72-
primary_version = model.versions.first()
70+
# Get the version - either specific version or primary (latest)
71+
if version_id:
72+
primary_version = model.versions.filter(id=version_id).first()
73+
if not primary_version:
74+
return Response(
75+
{"error": f"Version with ID {version_id} not found for this model"},
76+
status=status.HTTP_400_BAD_REQUEST,
77+
)
78+
else:
79+
# Fall back to primary (latest) version
80+
primary_version = model.versions.filter(is_latest=True).first()
81+
if not primary_version:
82+
primary_version = model.versions.first()
7383

7484
if not primary_version:
7585
return Response(

dataspace_sdk/resources/aimodels.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,11 @@ def delete_model(self, model_id: str) -> Dict[str, Any]:
337337
return self.delete(f"/api/aimodels/{model_id}/")
338338

339339
def call_model(
340-
self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None
340+
self,
341+
model_id: str,
342+
input_text: str,
343+
parameters: Optional[Dict[str, Any]] = None,
344+
version_id: Optional[int] = None,
341345
) -> Dict[str, Any]:
342346
"""
343347
Call an AI model with input text using the appropriate client (API or HuggingFace).
@@ -346,6 +350,7 @@ def call_model(
346350
model_id: UUID of the AI model
347351
input_text: Input text to process
348352
parameters: Optional parameters for the model call (temperature, max_tokens, etc.)
353+
version_id: Optional specific version ID to call (defaults to primary/latest version)
349354
350355
Returns:
351356
Dictionary containing model response:
@@ -358,13 +363,24 @@ def call_model(
358363
...
359364
}
360365
"""
366+
payload: Dict[str, Any] = {
367+
"input_text": input_text,
368+
"parameters": parameters or {},
369+
}
370+
if version_id is not None:
371+
payload["version_id"] = version_id
372+
361373
return self.post(
362374
f"/api/aimodels/{model_id}/call/",
363-
json_data={"input_text": input_text, "parameters": parameters or {}},
375+
json_data=payload,
364376
)
365377

366378
def call_model_async(
367-
self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None
379+
self,
380+
model_id: str,
381+
input_text: str,
382+
parameters: Optional[Dict[str, Any]] = None,
383+
version_id: Optional[int] = None,
368384
) -> Dict[str, Any]:
369385
"""
370386
Call an AI model asynchronously (returns task ID for long-running operations).
@@ -373,6 +389,7 @@ def call_model_async(
373389
model_id: UUID of the AI model
374390
input_text: Input text to process
375391
parameters: Optional parameters for the model call
392+
version_id: Optional specific version ID to call (defaults to primary/latest version)
376393
377394
Returns:
378395
Dictionary containing task information:
@@ -382,9 +399,16 @@ def call_model_async(
382399
"model_id": str
383400
}
384401
"""
402+
payload: Dict[str, Any] = {
403+
"input_text": input_text,
404+
"parameters": parameters or {},
405+
}
406+
if version_id is not None:
407+
payload["version_id"] = version_id
408+
385409
return self.post(
386410
f"/api/aimodels/{model_id}/call-async/",
387-
json_data={"input_text": input_text, "parameters": parameters or {}},
411+
json_data=payload,
388412
)
389413

390414
# ==================== Version Management ====================

0 commit comments

Comments
 (0)