Skip to content

Commit f902b40

Browse files
committed
add file support for TCI
- updated run method with files and validation - added tests for files - added a FileInput type for validation help
1 parent 367d2bd commit f902b40

3 files changed

Lines changed: 151 additions & 3 deletions

File tree

src/together/resources/code_interpreter.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import Dict, Literal, Optional
3+
from typing import Any, Dict, List, Literal, Optional
4+
from pydantic import ValidationError
45

56
from together.abstract import api_requestor
67
from together.together_response import TogetherResponse
78
from together.types import TogetherClient, TogetherRequest
8-
from together.types.code_interpreter import ExecuteResponse
9+
from together.types.code_interpreter import ExecuteResponse, FileInput
910

1011

1112
class CodeInterpreter:
@@ -19,16 +20,22 @@ def run(
1920
code: str,
2021
language: Literal["python"],
2122
session_id: Optional[str] = None,
23+
files: Optional[List[Dict[str, Any]]] = None,
2224
) -> ExecuteResponse:
23-
"""Execute a code snippet.
25+
"""Execute a code snippet, optionally with files.
2426
2527
Args:
2628
code (str): Code snippet to execute
2729
language (str): Programming language for the code to execute. Currently only supports Python.
2830
session_id (str, optional): Identifier of the current session. Used to make follow-up calls.
31+
files (List[Dict], optional): Files to upload to the session before executing the code.
2932
3033
Returns:
3134
ExecuteResponse: Object containing execution results and outputs
35+
36+
Raises:
37+
ValidationError: If any dictionary in the `files` list does not conform to the
38+
required structure or types.
3239
"""
3340
requestor = api_requestor.APIRequestor(
3441
client=self._client,
@@ -42,6 +49,21 @@ def run(
4249
if session_id is not None:
4350
data["session_id"] = session_id
4451

52+
if files is not None:
53+
serialized_files = []
54+
try:
55+
for file_dict in files:
56+
# Validate the dictionary by creating a FileInput instance
57+
validated_file = FileInput(**file_dict)
58+
# Serialize the validated model back to a dict for the API call
59+
serialized_files.append(validated_file.model_dump())
60+
except ValidationError as e:
61+
raise ValueError(f"Invalid file input format: {e}") from e
62+
except TypeError as e:
63+
raise ValueError(f"Invalid file input: Each item in 'files' must be a dictionary. Error: {e}") from e
64+
65+
data["files"] = serialized_files
66+
4567
# Use absolute URL to bypass the /v1 prefix
4668
response, _, _ = requestor.request(
4769
options=TogetherRequest(

src/together/types/code_interpreter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66

77
from together.types.endpoints import TogetherJSONModel
88

9+
class FileInput(TogetherJSONModel):
10+
"""File input to be uploaded to the code interpreter session."""
11+
12+
name: str = Field(description="The name of the file.")
13+
encoding: Literal["string", "base64"] = Field(
14+
description="Encoding of the file content. Use 'string' for text files and 'base64' for binary files."
15+
)
16+
content: str = Field(description="The content of the file, encoded as specified.")
17+
918

1019
class InterpreterOutput(TogetherJSONModel):
1120
"""Base class for interpreter output types."""
@@ -40,6 +49,7 @@ class ExecuteResponse(TogetherJSONModel):
4049

4150

4251
__all__ = [
52+
"FileInput",
4353
"InterpreterOutput",
4454
"ExecuteResponseData",
4555
"ExecuteResponse",

tests/unit/test_code_interpreter.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import pytest
4+
from pydantic import ValidationError
35

46
from together.resources.code_interpreter import CodeInterpreter
57
from together.together_response import TogetherResponse
@@ -326,3 +328,117 @@ def test_code_interpreter_session_management(mocker):
326328

327329
# Second call should have session_id
328330
assert calls[1][1]["options"].params["session_id"] == "new_session"
331+
332+
333+
def test_code_interpreter_run_with_files(mocker):
334+
335+
mock_requestor = mocker.MagicMock()
336+
response_data = {
337+
"data": {
338+
"session_id": "test_session_files",
339+
"status": "success",
340+
"outputs": [{"type": "stdout", "data": "File content read"}],
341+
}
342+
}
343+
mock_headers = {
344+
"cf-ray": "test-ray-id-files",
345+
"x-ratelimit-remaining": "98",
346+
"x-hostname": "test-host",
347+
"x-total-time": "42.0",
348+
}
349+
mock_response = TogetherResponse(data=response_data, headers=mock_headers)
350+
mock_requestor.request.return_value = (mock_response, None, None)
351+
mocker.patch(
352+
"together.abstract.api_requestor.APIRequestor", return_value=mock_requestor
353+
)
354+
355+
# Create code interpreter instance
356+
client = mocker.MagicMock()
357+
interpreter = CodeInterpreter(client)
358+
359+
# Define files
360+
files_to_upload = [
361+
{"name": "test.txt", "encoding": "string", "content": "Hello from file!"},
362+
{"name": "image.png", "encoding": "base64", "content": "aW1hZ2UgZGF0YQ=="},
363+
]
364+
365+
# Test run method with files (passing list of dicts)
366+
response = interpreter.run(
367+
code='with open("test.txt") as f: print(f.read())',
368+
language="python",
369+
files=files_to_upload, # Pass the list of dictionaries directly
370+
)
371+
372+
# Verify the response
373+
assert isinstance(response, ExecuteResponse)
374+
assert response.data.session_id == "test_session_files"
375+
assert response.data.status == "success"
376+
assert len(response.data.outputs) == 1
377+
assert response.data.outputs[0].type == "stdout"
378+
379+
# Verify API request includes files (expected_files_payload remains the same)
380+
mock_requestor.request.assert_called_once_with(
381+
options=mocker.ANY,
382+
stream=False,
383+
)
384+
request_options = mock_requestor.request.call_args[1]["options"]
385+
assert request_options.method == "POST"
386+
assert request_options.url == "/tci/execute"
387+
expected_files_payload = [
388+
{"name": "test.txt", "encoding": "string", "content": "Hello from file!"},
389+
{"name": "image.png", "encoding": "base64", "content": "aW1hZ2UgZGF0YQ=="},
390+
]
391+
assert request_options.params == {
392+
"code": 'with open("test.txt") as f: print(f.read())',
393+
"language": "python",
394+
"files": expected_files_payload,
395+
}
396+
397+
def test_code_interpreter_run_with_invalid_file_dict_structure(mocker):
398+
"""Test that run raises ValueError for missing keys in file dict."""
399+
client = mocker.MagicMock()
400+
interpreter = CodeInterpreter(client)
401+
402+
invalid_files = [
403+
{"name": "test.txt", "content": "Missing encoding"} # Missing 'encoding'
404+
]
405+
406+
with pytest.raises(ValueError, match="Invalid file input format"):
407+
interpreter.run(
408+
code="print('test')",
409+
language="python",
410+
files=invalid_files,
411+
)
412+
413+
def test_code_interpreter_run_with_invalid_file_dict_encoding(mocker):
414+
"""Test that run raises ValueError for invalid encoding value."""
415+
client = mocker.MagicMock()
416+
interpreter = CodeInterpreter(client)
417+
418+
invalid_files = [
419+
{"name": "test.txt", "encoding": "utf-8", "content": "Invalid encoding"} # Invalid 'encoding' value
420+
]
421+
422+
with pytest.raises(ValueError, match="Invalid file input format"):
423+
interpreter.run(
424+
code="print('test')",
425+
language="python",
426+
files=invalid_files,
427+
)
428+
429+
def test_code_interpreter_run_with_invalid_file_list_item(mocker):
430+
"""Test that run raises ValueError for non-dict item in files list."""
431+
client = mocker.MagicMock()
432+
interpreter = CodeInterpreter(client)
433+
434+
invalid_files = [
435+
{"name": "good.txt", "encoding": "string", "content": "Good"},
436+
"not a dictionary" # Invalid item type
437+
]
438+
439+
with pytest.raises(ValueError, match="Invalid file input: Each item in 'files' must be a dictionary"):
440+
interpreter.run(
441+
code="print('test')",
442+
language="python",
443+
files=invalid_files,
444+
)

0 commit comments

Comments
 (0)