Skip to content

Commit b93a673

Browse files
Support Multimodal datasets
1 parent 137b84e commit b93a673

2 files changed

Lines changed: 206 additions & 29 deletions

File tree

src/together/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22

3+
34
# Session constants
45
TIMEOUT_SECS = 600
56
MAX_SESSION_LIFETIME_SECS = 180
@@ -40,6 +41,11 @@
4041
# the number of bytes in a gigabyte, used to convert bytes to GB for readable comparison
4142
NUM_BYTES_IN_GB = 2**30
4243

44+
# Multimodal limits
45+
MAX_IMAGES_PER_EXAMPLE = 10
46+
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB
47+
# Max length = Header length + base64 factor (4/3) * image bytes
48+
MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3
4349

4450
# expected columns for Parquet files
4551
PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]

src/together/utils/files.py

Lines changed: 200 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,39 @@
11
from __future__ import annotations
22

3+
import csv
34
import json
45
import os
5-
import csv
66
from pathlib import Path
77
from traceback import format_exc
88
from typing import Any, Dict, List
99

1010
from tqdm import tqdm
1111

1212
from together.constants import (
13+
JSONL_REQUIRED_COLUMNS_MAP,
14+
MAX_BASE64_IMAGE_LENGTH,
1315
MAX_FILE_SIZE_GB,
16+
MAX_IMAGES_PER_EXAMPLE,
1417
MIN_SAMPLES,
1518
NUM_BYTES_IN_GB,
1619
PARQUET_EXPECTED_COLUMNS,
17-
JSONL_REQUIRED_COLUMNS_MAP,
18-
REQUIRED_COLUMNS_MESSAGE,
1920
POSSIBLE_ROLES_CONVERSATION,
21+
REQUIRED_COLUMNS_MESSAGE,
2022
DatasetFormat,
2123
)
2224
from together.types import FilePurpose
2325

2426

27+
# MessageContent is a string or a list of dicts with 'type': 'text' or 'image_url', and 'text' or 'image_url.url'
28+
# Example: "Hello" or [
29+
# {"type": "text", "text": "Hello"},
30+
# {"type": "image_url", "image_url": {
31+
# "url": "data:image/jpeg;base64,..."
32+
# }}
33+
# ]
34+
MessageContent = str | list[dict[str, Any]]
35+
36+
2537
class InvalidFileFormatError(ValueError):
2638
"""Exception raised for invalid file formats during file checks."""
2739

@@ -70,7 +82,7 @@ def check_file(
7082

7183
if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
7284
report_dict["message"] = (
73-
f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
85+
f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB, 3)} GB."
7486
)
7587
report_dict["is_check_passed"] = False
7688
elif file_size == 0:
@@ -103,7 +115,9 @@ def check_file(
103115
return report_dict
104116

105117

106-
def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
118+
def _check_conversation_type(
119+
messages: List[Dict[str, str | int | MessageContent]], idx: int
120+
) -> None:
107121
"""Check that the conversation has correct type.
108122
109123
Args:
@@ -175,7 +189,9 @@ def _check_conversation_roles(
175189
)
176190

177191

178-
def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
192+
def _check_message_weight(
193+
message: Dict[str, str | int | MessageContent], idx: int
194+
) -> int | None:
179195
"""Check that the message has a weight with the correct type and value.
180196
181197
Args:
@@ -199,11 +215,14 @@ def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
199215
line_number=idx + 1,
200216
error_source="key_value",
201217
)
218+
return weight
219+
220+
return None
202221

203222

204223
def _check_message_role(
205-
message: Dict[str, str | bool], previous_role: str | None, idx: int
206-
) -> str | bool:
224+
message: Dict[str, str | int | MessageContent], previous_role: str | None, idx: int
225+
) -> str:
207226
"""Check that the message has correct roles.
208227
209228
Args:
@@ -217,6 +236,14 @@ def _check_message_role(
217236
Raises:
218237
InvalidFileFormatError: If the message role is invalid.
219238
"""
239+
if not isinstance(message["role"], str):
240+
raise InvalidFileFormatError(
241+
message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
242+
f"Role must be a string. Found {type(message['role'])}",
243+
line_number=idx + 1,
244+
error_source="key_value",
245+
)
246+
220247
if message["role"] not in POSSIBLE_ROLES_CONVERSATION:
221248
raise InvalidFileFormatError(
222249
message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
@@ -234,8 +261,132 @@ def _check_message_role(
234261
return message["role"]
235262

236263

264+
def _check_message_content(
265+
message_content: str | int | MessageContent, role: str, idx: int
266+
) -> tuple[bool, int]:
267+
"""Check that the message content has the correct type.
268+
Message content can be either a) a string or b) an OpenAI-style multimodal list of content items
269+
Example:
270+
a) "Hello", or
271+
b) [
272+
{"type": "text", "text": "Hello"},
273+
{"type": "image_url", "image_url": {
274+
"url": "data:image/jpeg;base64,..."
275+
}}
276+
]
277+
278+
Args:
279+
message: The message to check.
280+
idx: Line number in the file.
281+
282+
Returns:
283+
tuple[bool, int]: A tuple with message is multimodal and the number of images in the message content.
284+
"""
285+
# Text-only message content
286+
if isinstance(message_content, str):
287+
return False, 0
288+
289+
# Multimodal message content
290+
if isinstance(message_content, list):
291+
num_images = 0
292+
for item in message_content:
293+
if not isinstance(item, dict):
294+
raise InvalidFileFormatError(
295+
"The dataset is malformed, the `content` field must be a list of dicts.",
296+
line_number=idx + 1,
297+
error_source="key_value",
298+
)
299+
if "type" not in item:
300+
raise InvalidFileFormatError(
301+
"The dataset is malformed, the `content` field must be a list of dicts with a `type` field.",
302+
line_number=idx + 1,
303+
error_source="key_value",
304+
)
305+
306+
if item["type"] == "text":
307+
if "text" not in item or not isinstance(item["text"], str):
308+
raise InvalidFileFormatError(
309+
"The dataset is malformed, the `text` field must be present in the `content` item field and be"
310+
f" a string. Got '{item.get('text')!r}' instead.",
311+
line_number=idx + 1,
312+
error_source="key_value",
313+
)
314+
elif item["type"] == "image_url":
315+
if role != "user":
316+
raise InvalidFileFormatError(
317+
"The dataset is malformed, only user messages can contain images.",
318+
line_number=idx + 1,
319+
error_source="key_value",
320+
)
321+
322+
if "image_url" not in item or not isinstance(item["image_url"], dict):
323+
raise InvalidFileFormatError(
324+
"The dataset is malformed, the `image_url` field must be present in the `content` field and "
325+
f"be a dictionary. Got {item.get('image_url')!r} instead.",
326+
line_number=idx + 1,
327+
error_source="key_value",
328+
)
329+
330+
image_data = item["image_url"].get("url")
331+
if not image_data or not isinstance(image_data, str):
332+
raise InvalidFileFormatError(
333+
"The dataset is malformed, the `url` field must be present in the `image_url` field and be "
334+
f"a string. Got {image_data!r} instead.",
335+
line_number=idx + 1,
336+
error_source="key_value",
337+
)
338+
339+
if not any(
340+
image_data.startswith(f"data:image/{fmt};base64,")
341+
for fmt in ["jpeg", "png", "webp"]
342+
):
343+
raise InvalidFileFormatError(
344+
"The dataset is malformed, the `url` field must be either a JPEG, PNG or WEBP base64-encoded "
345+
"image in 'data:image/<format>;base64,<base64_encoded_image>' format. "
346+
f"Got '{image_data[:100]}...' instead.",
347+
line_number=idx + 1,
348+
)
349+
350+
if len(image_data) > MAX_BASE64_IMAGE_LENGTH:
351+
raise InvalidFileFormatError(
352+
"The dataset is malformed, the `url` field must contain base64-encoded image "
353+
f"that is less than 10MB, found ~{len(image_data) * 3 // 4} bytes.",
354+
line_number=idx + 1,
355+
error_source="key_value",
356+
)
357+
358+
num_images += 1
359+
else:
360+
raise InvalidFileFormatError(
361+
"The dataset is malformed, the `type` field must be either 'text' or 'image_url'. "
362+
f"Got {item['type']!r}.",
363+
line_number=idx + 1,
364+
error_source="key_value",
365+
)
366+
367+
if num_images > MAX_IMAGES_PER_EXAMPLE:
368+
raise InvalidFileFormatError(
369+
f"The dataset is malformed, the `content` field must contain at most "
370+
f"{MAX_IMAGES_PER_EXAMPLE} images, found {num_images}.",
371+
line_number=idx + 1,
372+
error_source="key_value",
373+
)
374+
375+
# We still consider text-only messages in such format as multimodal, even if they don't have any images
376+
# included - so we can process datasets with rather sparse images (i.e. not in each sample) consistently.
377+
return True, num_images
378+
379+
raise InvalidFileFormatError(
380+
message=f"Invalid content type on line {idx + 1} of the input file. Found {type(message_content)}",
381+
line_number=idx + 1,
382+
error_source="key_value",
383+
)
384+
385+
237386
def validate_messages(
238-
messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True
387+
messages: List[Dict[str, str | int | MessageContent]],
388+
idx: int,
389+
require_assistant_role: bool = True,
239390
) -> None:
240391
"""Validate the messages column.
241392
@@ -249,15 +400,45 @@ def validate_messages(
249400
"""
250401
_check_conversation_type(messages, idx)
251402

252-
has_weights = any("weight" in message for message in messages)
253403
previous_role = None
254404
assistant_role_exists = False
255405

406+
messages_are_multimodal: bool | None = None
407+
total_number_of_images = 0
408+
256409
for message in messages:
257-
if has_weights:
258-
_check_message_weight(message, idx)
410+
message_weight = _check_message_weight(message, idx)
259411
previous_role = _check_message_role(message, previous_role, idx)
260412
assistant_role_exists |= previous_role == "assistant"
413+
is_multimodal, number_of_images = _check_message_content(
414+
message["content"], role=previous_role, idx=idx
415+
)
416+
# Multimodal validation
417+
if number_of_images > 0 and message_weight is not None and message_weight != 0:
418+
raise InvalidFileFormatError(
419+
"Messages with images cannot have non-zero weights.",
420+
line_number=idx + 1,
421+
error_source="key_value",
422+
)
423+
if messages_are_multimodal is None:
424+
# Detect the format of the messages in the conversation.
425+
messages_are_multimodal = is_multimodal
426+
elif messages_are_multimodal != is_multimodal:
427+
# Due to the format limitation, we cannot mix multimodal and text only messages in the same sample.
428+
raise InvalidFileFormatError(
429+
"Messages in the conversation must be either all in multimodal or all intext only format.",
430+
line_number=idx + 1,
431+
error_source="key_value",
432+
)
433+
total_number_of_images += number_of_images
434+
435+
if total_number_of_images > MAX_IMAGES_PER_EXAMPLE:
436+
raise InvalidFileFormatError(
437+
f"The dataset is malformed, the `messages` must contain at most {MAX_IMAGES_PER_EXAMPLE} images. "
438+
f"Found {total_number_of_images} images.",
439+
line_number=idx + 1,
440+
error_source="key_value",
441+
)
261442

262443
_check_conversation_roles(require_assistant_role, assistant_role_exists, idx)
263444

@@ -347,12 +528,7 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
347528
error_source="key_value",
348529
)
349530

350-
if not isinstance(example[key][0]["content"], str):
351-
raise InvalidFileFormatError(
352-
message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.",
353-
line_number=idx + 1,
354-
error_source="key_value",
355-
)
531+
_check_message_content(example[key][0]["content"], role="assistant", idx=idx)
356532

357533

358534
def _check_utf8(file: Path) -> Dict[str, Any]:
@@ -454,8 +630,7 @@ def _check_csv(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
454630
report_dict["load_csv"] = False
455631
if idx < 0:
456632
report_dict["message"] = (
457-
"Unable to decode file. "
458-
"File may be empty or in an unsupported format. "
633+
"Unable to decode file. File may be empty or in an unsupported format. "
459634
)
460635
else:
461636
report_dict["message"] = (
@@ -542,13 +717,10 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
542717
)
543718
else:
544719
for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
545-
if not isinstance(json_line[column], str):
546-
raise InvalidFileFormatError(
547-
message=f'Invalid value type for "{column}" key on line {idx + 1}. '
548-
f"Expected string. Found {type(json_line[column])}.",
549-
line_number=idx + 1,
550-
error_source="key_value",
551-
)
720+
role = "assistant" if column in {"completion"} else "user"
721+
_check_message_content(
722+
json_line[column], role=role, idx=idx
723+
)
552724

553725
if dataset_format is None:
554726
dataset_format = current_format
@@ -578,8 +750,7 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
578750
report_dict["load_json"] = False
579751
if idx < 0:
580752
report_dict["message"] = (
581-
"Unable to decode file. "
582-
"File may be empty or in an unsupported format. "
753+
"Unable to decode file. File may be empty or in an unsupported format. "
583754
)
584755
else:
585756
report_dict["message"] = (

0 commit comments

Comments
 (0)