Skip to content

Commit 5f00d95

Browse files
committed
Small style changes
1 parent 7357926 commit 5f00d95

3 files changed

Lines changed: 91 additions & 68 deletions

File tree

src/together/resources/finetune.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
from together.utils import log_warn_once, normalize_key
3030

3131

32+
AVAILABLE_TRAINING_METHODS = {
33+
TrainingMethodSFT().method,
34+
TrainingMethodDPO().method,
35+
}
36+
37+
3238
def createFinetuneRequest(
3339
model_limits: FinetuneTrainingLimits,
3440
training_file: str,
@@ -105,10 +111,6 @@ def createFinetuneRequest(
105111
if weight_decay is not None and (weight_decay < 0):
106112
raise ValueError("Weight decay should be non-negative")
107113

108-
AVAILABLE_TRAINING_METHODS = {
109-
TrainingMethodSFT().method,
110-
TrainingMethodDPO().method,
111-
}
112114
if training_method not in AVAILABLE_TRAINING_METHODS:
113115
raise ValueError(
114116
f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"

src/together/utils/files.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ def check_file(
9696
return report_dict
9797

9898

99-
def validate_messages(
100-
messages: List[Dict[str, str | bool]], idx: int
101-
) -> None:
99+
def validate_messages(messages: List[Dict[str, str | bool]], idx: int) -> None:
102100
"""Validate the messages column."""
103101
if not isinstance(messages, list):
104102
raise InvalidFileFormatError(
@@ -227,7 +225,6 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
227225
line_number=idx + 1,
228226
error_source="key_value",
229227
)
230-
231228

232229
validate_messages(example["preferred_output"], idx)
233230
validate_messages(example["non_preferred_output"], idx)

tests/unit/test_preference_openai.py

Lines changed: 84 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from together.constants import MIN_SAMPLES
66
from together.utils.files import check_file
77

8-
# Test data for preference OpenAI format
8+
99
_TEST_PREFERENCE_OPENAI_CONTENT = [
1010
{
1111
"input": {
@@ -70,16 +70,25 @@ def test_check_jsonl_valid_preference_openai(tmp_path: Path):
7070
assert report["has_min_samples"]
7171

7272

73-
# Define test cases for missing fields
7473
MISSING_FIELDS_TEST_CASES = [
7574
pytest.param("input", "Missing input field", id="missing_input"),
76-
pytest.param("preferred_output", "Missing preferred_output field", id="missing_preferred_output"),
77-
pytest.param("non_preferred_output", "Missing non_preferred_output field", id="missing_non_preferred_output"),
75+
pytest.param(
76+
"preferred_output",
77+
"Missing preferred_output field",
78+
id="missing_preferred_output",
79+
),
80+
pytest.param(
81+
"non_preferred_output",
82+
"Missing non_preferred_output field",
83+
id="missing_non_preferred_output",
84+
),
7885
]
7986

8087

8188
@pytest.mark.parametrize("field_to_remove, description", MISSING_FIELDS_TEST_CASES)
82-
def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, field_to_remove, description):
89+
def test_check_jsonl_invalid_preference_openai_missing_fields(
90+
tmp_path: Path, field_to_remove, description
91+
):
8392
"""Test missing required fields in OpenAI preference format."""
8493
file = tmp_path / f"invalid_preference_openai_missing_{field_to_remove}.jsonl"
8594
content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT]
@@ -95,69 +104,58 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi
95104
assert not report["is_check_passed"], f"Test should fail when {description}"
96105

97106

98-
# Define test cases for structural issues
99107
STRUCTURAL_ISSUE_TEST_CASES = [
100108
pytest.param(
101109
"empty_messages",
102110
lambda item: item.update({"input": {"messages": []}}),
103111
"Empty messages array",
104-
id="empty_messages"
112+
id="empty_messages",
105113
),
106114
pytest.param(
107115
"missing_role_preferred",
108116
lambda item: item.update(
109117
{"preferred_output": [{"content": "Missing role field"}]}
110118
),
111119
"Missing role in preferred_output",
112-
id="missing_role_preferred"
120+
id="missing_role_preferred",
113121
),
114122
pytest.param(
115123
"missing_role_non_preferred",
116124
lambda item: item.update(
117125
{"non_preferred_output": [{"content": "Missing role field"}]}
118126
),
119127
"Missing role in non_preferred_output",
120-
id="missing_role_non_preferred"
128+
id="missing_role_non_preferred",
121129
),
122130
pytest.param(
123131
"missing_content_preferred",
124-
lambda item: item.update(
125-
{"preferred_output": [{"role": "assistant"}]}
126-
),
132+
lambda item: item.update({"preferred_output": [{"role": "assistant"}]}),
127133
"Missing content in preferred_output",
128-
id="missing_content_preferred"
134+
id="missing_content_preferred",
129135
),
130136
pytest.param(
131137
"missing_content_non_preferred",
132-
lambda item: item.update(
133-
{"non_preferred_output": [{"role": "assistant"}]}
134-
),
138+
lambda item: item.update({"non_preferred_output": [{"role": "assistant"}]}),
135139
"Missing content in non_preferred_output",
136-
id="missing_content_non_preferred"
140+
id="missing_content_non_preferred",
137141
),
138142
pytest.param(
139143
"wrong_output_format_preferred",
140-
lambda item: item.update(
141-
{"preferred_output": "Not an array but a string"}
142-
),
144+
lambda item: item.update({"preferred_output": "Not an array but a string"}),
143145
"Wrong format for preferred_output",
144-
id="wrong_output_format_preferred"
146+
id="wrong_output_format_preferred",
145147
),
146148
pytest.param(
147149
"wrong_output_format_non_preferred",
148-
lambda item: item.update(
149-
{"non_preferred_output": "Not an array but a string"}
150-
),
150+
lambda item: item.update({"non_preferred_output": "Not an array but a string"}),
151151
"Wrong format for non_preferred_output",
152-
id="wrong_output_format_non_preferred"
152+
id="wrong_output_format_non_preferred",
153153
),
154154
pytest.param(
155155
"missing_content",
156-
lambda item: item.update(
157-
{"input": {"messages": [{"role": "user"}]}}
158-
),
156+
lambda item: item.update({"input": {"messages": [{"role": "user"}]}}),
159157
"Missing content in messages",
160-
id="missing_content"
158+
id="missing_content",
161159
),
162160
pytest.param(
163161
"multiple_preferred_outputs",
@@ -170,7 +168,7 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi
170168
}
171169
),
172170
"Multiple messages in preferred_output",
173-
id="multiple_preferred_outputs"
171+
id="multiple_preferred_outputs",
174172
),
175173
pytest.param(
176174
"multiple_non_preferred_outputs",
@@ -183,88 +181,114 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi
183181
}
184182
),
185183
"Multiple messages in non_preferred_output",
186-
id="multiple_non_preferred_outputs"
184+
id="multiple_non_preferred_outputs",
187185
),
188186
pytest.param(
189187
"empty_preferred_output",
190188
lambda item: item.update({"preferred_output": []}),
191189
"Empty preferred_output array",
192-
id="empty_preferred_output"
190+
id="empty_preferred_output",
193191
),
194192
pytest.param(
195193
"empty_non_preferred_output",
196194
lambda item: item.update({"non_preferred_output": []}),
197195
"Empty non_preferred_output array",
198-
id="empty_non_preferred_output"
196+
id="empty_non_preferred_output",
199197
),
200198
pytest.param(
201199
"non_string_content_in_messages",
202-
lambda item: item.update({"input": {"messages": [{"role": "user", "content": 123}]}}),
200+
lambda item: item.update(
201+
{"input": {"messages": [{"role": "user", "content": 123}]}}
202+
),
203203
"Non-string content in messages",
204-
id="non_string_content_in_messages"
204+
id="non_string_content_in_messages",
205205
),
206206
pytest.param(
207207
"invalid_role_in_messages",
208-
lambda item: item.update({"input": {"messages": [{"role": "invalid_role", "content": "Hello"}]}}),
208+
lambda item: item.update(
209+
{"input": {"messages": [{"role": "invalid_role", "content": "Hello"}]}}
210+
),
209211
"Invalid role in messages",
210-
id="invalid_role_in_messages"
212+
id="invalid_role_in_messages",
211213
),
212214
pytest.param(
213215
"non_alternating_roles",
214-
lambda item: item.update({"input": {"messages": [
215-
{"role": "user", "content": "Hello"},
216-
{"role": "user", "content": "How are you?"}
217-
]}}),
216+
lambda item: item.update(
217+
{
218+
"input": {
219+
"messages": [
220+
{"role": "user", "content": "Hello"},
221+
{"role": "user", "content": "How are you?"},
222+
]
223+
}
224+
}
225+
),
218226
"Non-alternating roles in messages",
219-
id="non_alternating_roles"
227+
id="non_alternating_roles",
220228
),
221229
pytest.param(
222230
"invalid_weight_type",
223-
lambda item: item.update({"input": {"messages": [
224-
{"role": "user", "content": "Hello", "weight": "not_an_integer"}
225-
]}}),
231+
lambda item: item.update(
232+
{
233+
"input": {
234+
"messages": [
235+
{"role": "user", "content": "Hello", "weight": "not_an_integer"}
236+
]
237+
}
238+
}
239+
),
226240
"Invalid weight type",
227-
id="invalid_weight_type"
241+
id="invalid_weight_type",
228242
),
229243
pytest.param(
230244
"invalid_weight_value",
231-
lambda item: item.update({"input": {"messages": [
232-
{"role": "user", "content": "Hello", "weight": 2}
233-
]}}),
245+
lambda item: item.update(
246+
{"input": {"messages": [{"role": "user", "content": "Hello", "weight": 2}]}}
247+
),
234248
"Invalid weight value",
235-
id="invalid_weight_value"
249+
id="invalid_weight_value",
236250
),
237251
pytest.param(
238252
"non_dict_message",
239-
lambda item: item.update({"input": {"messages": [
240-
"Not a dictionary"
241-
]}}),
253+
lambda item: item.update({"input": {"messages": ["Not a dictionary"]}}),
242254
"Non-dictionary message",
243-
id="non_dict_message"
255+
id="non_dict_message",
244256
),
245257
pytest.param(
246258
"non_dict_input",
247259
lambda item: item.update({"input": "Not a dictionary"}),
248260
"Non-dictionary input",
249-
id="non_dict_input"
261+
id="non_dict_input",
250262
),
251263
pytest.param(
252264
"missing_messages_in_input",
253265
lambda item: item.update({"input": {}}),
254266
"Missing messages in input",
255-
id="missing_messages_in_input"
267+
id="missing_messages_in_input",
256268
),
257269
pytest.param(
258270
"non_assistant_role_in_preferred",
259-
lambda item: item.update({"preferred_output": [{"role": "user", "content": "This should be assistant"}]}),
271+
lambda item: item.update(
272+
{
273+
"preferred_output": [
274+
{"role": "user", "content": "This should be assistant"}
275+
]
276+
}
277+
),
260278
"Non-assistant role in preferred output",
261-
id="non_assistant_role_in_preferred"
279+
id="non_assistant_role_in_preferred",
262280
),
263281
pytest.param(
264282
"non_assistant_role_in_non_preferred",
265-
lambda item: item.update({"non_preferred_output": [{"role": "user", "content": "This should be assistant"}]}),
283+
lambda item: item.update(
284+
{
285+
"non_preferred_output": [
286+
{"role": "user", "content": "This should be assistant"}
287+
]
288+
}
289+
),
266290
"Non-assistant role in non-preferred output",
267-
id="non_assistant_role_in_non_preferred"
291+
id="non_assistant_role_in_non_preferred",
268292
),
269293
]
270294

0 commit comments

Comments
 (0)