Skip to content

Commit f276b28

Browse files
committed
requested changes
1 parent 9a752eb commit f276b28

4 files changed

Lines changed: 60 additions & 24 deletions

File tree

openml/_api/resources/base/resources.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def get(
4343
"""
4444
...
4545

46+
@abstractmethod
47+
def supports_download_splits(self) -> bool:
48+
"""Return whether the task API implementation supports split downloads."""
49+
...
50+
4651
# Task listing (V1 only)
4752
@abstractmethod
4853
def list(

openml/_api/resources/task.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ def get(self, task_id: int) -> OpenMLTask:
200200
response = self._http.get(f"task/{task_id}", enable_cache=True)
201201
return _create_task_from_xml(response.text)
202202

203+
def supports_download_splits(self) -> bool:
204+
return True
205+
203206
def list(
204207
self,
205208
limit: int,
@@ -401,3 +404,6 @@ def list(
401404
**kwargs: Any, # noqa: ARG002
402405
) -> pd.DataFrame:
403406
raise self._not_supported(method="list")
407+
408+
def supports_download_splits(self) -> bool:
409+
return False

openml/tasks/functions.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,11 @@ def get_tasks(
134134
tasks = []
135135
for task_id in task_ids:
136136
tasks.append(
137-
get_task(task_id, download_data=download_data, download_qualities=download_qualities)
137+
get_task(
138+
task_id,
139+
download_data=download_data,
140+
download_qualities=download_qualities,
141+
)
138142
)
139143
return tasks
140144

@@ -166,8 +170,6 @@ def get_task(
166170
-------
167171
task: OpenMLTask
168172
"""
169-
from openml._api.resources.task import TaskV1API, TaskV2API
170-
171173
if not isinstance(task_id, int):
172174
raise TypeError(f"Task id should be integer, is {type(task_id)}")
173175

@@ -180,17 +182,14 @@ def get_task(
180182
):
181183
task.class_labels = dataset.retrieve_class_labels(task.target_name)
182184

183-
if (
184-
download_splits
185-
and isinstance(task, OpenMLSupervisedTask)
186-
and isinstance(openml._backend.task, TaskV1API)
187-
):
188-
task.download_split()
189-
elif download_splits and isinstance(openml._backend.task, TaskV2API):
190-
warnings.warn(
191-
"`download_splits` is not yet supported in the v2 API and will be ignored.",
192-
stacklevel=2,
193-
)
185+
if download_splits and isinstance(task, OpenMLSupervisedTask):
186+
if openml._backend.task.supports_download_splits():
187+
task.download_split()
188+
else:
189+
warnings.warn(
190+
"`download_splits` is not yet supported in the v2 API and will be ignored.",
191+
stacklevel=2,
192+
)
194193

195194
return task
196195

tests/test_tasks/test_task_functions.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
OpenMLNotAuthorizedError,
1515
OpenMLServerException,
1616
)
17-
from openml.tasks import TaskType
17+
from openml.tasks import TaskType, task
1818
from openml.testing import TestBase, create_request_response
1919

2020

@@ -29,7 +29,7 @@ def tearDown(self):
2929

3030
@pytest.mark.test_server()
3131
def test__get_estimation_procedure_list(self):
32-
estimation_procedures = openml.tasks.functions._get_estimation_procedure_list()
32+
estimation_procedures = openml._api.resources.task._get_estimation_procedure_list()
3333
assert isinstance(estimation_procedures, list)
3434
assert isinstance(estimation_procedures[0], dict)
3535
assert estimation_procedures[0]["task_type_id"] == TaskType.SUPERVISED_CLASSIFICATION
@@ -116,6 +116,13 @@ def test_list_tasks_per_type_paginate(self):
116116
assert j == task["ttid"]
117117
self._check_task(task)
118118

119+
@pytest.mark.test_server()
120+
def test__get_task(self):
121+
openml.config.set_root_cache_directory(self.static_cache_dir)
122+
with unittest.mock.patch("requests.sessions.Session.request") as mock_request:
123+
openml.tasks.get_task(1882)
124+
mock_request.assert_not_called()
125+
119126
@unittest.skip(
120127
"Please await outcome of discussion: https://github.com/openml/OpenML/issues/776",
121128
)
@@ -126,13 +133,18 @@ def test__get_task_live(self):
126133
# https://github.com/openml/openml-python/issues/378
127134
openml.tasks.get_task(34536)
128135

129-
@pytest.mark.skipif(
130-
os.getenv("OPENML_USE_LOCAL_SERVICES") == "true",
131-
reason="Pending resolution of #1657",
132-
)
136+
@pytest.mark.test_server()
137+
def test_get_task(self):
138+
with unittest.mock.patch("requests.sessions.Session.request") as mock_request:
139+
openml.tasks.get_task(1)
140+
mock_request.assert_not_called()
141+
133142
@pytest.mark.test_server()
134143
def test_get_task_lazy(self):
135-
task = openml.tasks.get_task(2, download_data=False) # anneal; crossvalidation
144+
with unittest.mock.patch("requests.sessions.Session.request") as mock_request:
145+
task = openml.tasks.get_task(2, download_data=False) # anneal; crossvalidation
146+
mock_request.assert_not_called()
147+
136148
assert isinstance(task, OpenMLTask)
137149
assert os.path.exists(
138150
os.path.join(openml.config.get_cache_directory(), "tasks", "2", "task.xml")
@@ -151,7 +163,10 @@ def test_get_task_lazy(self):
151163
)
152164
)
153165

154-
task.download_split()
166+
with unittest.mock.patch("requests.sessions.Session.request") as mock_request:
167+
task.download_split()
168+
mock_request.assert_not_called()
169+
155170
assert os.path.exists(
156171
os.path.join(
157172
openml.config.get_cache_directory(), "tasks", "2", "datasplits.arff"
@@ -177,6 +192,15 @@ def assert_and_raise(*args, **kwargs):
177192
# Now the file should no longer exist
178193
assert not os.path.exists(os.path.join(os.getcwd(), "tasks", "1", "tasks.xml"))
179194

195+
@pytest.mark.test_server()
196+
def test_get_task_with_cache(self):
197+
openml.config.set_root_cache_directory(self.static_cache_dir)
198+
with unittest.mock.patch("requests.sessions.Session.request") as mock_request:
199+
task = openml.tasks.get_task(1)
200+
mock_request.assert_not_called()
201+
202+
assert isinstance(task, OpenMLTask)
203+
180204
@pytest.mark.production_server()
181205
def test_get_task_different_types(self):
182206
self.use_production_server()
@@ -189,8 +213,10 @@ def test_get_task_different_types(self):
189213

190214
@pytest.mark.test_server()
191215
def test_download_split(self):
192-
task = openml.tasks.get_task(1) # anneal; crossvalidation
193-
split = task.download_split()
216+
with unittest.mock.patch("requests.sessions.Session.request") as mock_request:
217+
task = openml.tasks.get_task(1) # anneal; crossvalidation
218+
split = task.download_split()
219+
mock_request.assert_not_called()
194220
assert type(split) == OpenMLSplit
195221
assert os.path.exists(
196222
os.path.join(

0 commit comments

Comments
 (0)