Skip to content

Commit 3c2efef

Browse files
committed
fix tests
1 parent a00f4ee commit 3c2efef

3 files changed

Lines changed: 40 additions & 14 deletions

File tree

main.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

test/test_postgres_utils.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
import subprocess
21
import getpass
3-
import pytest
2+
import subprocess
3+
44
import psycopg2
5+
import pytest
56
from psycopg2 import OperationalError
67

78
# Import all functions to test
89
from codes.tune import (
9-
_make_db_url,
10-
_check_remote_reachable,
1110
_check_postgres_running_local,
12-
_start_postgres_server_local,
11+
_check_remote_reachable,
1312
_initialize_postgres_local,
1413
_initialize_postgres_remote,
14+
_make_db_url,
15+
_start_postgres_server_local,
1516
initialize_optuna_database,
1617
)
1718

@@ -194,6 +195,36 @@ def test_initialize_postgres_remote_interactive(monkeypatch):
194195
monkeypatch.setattr(
195196
"codes.tune.postgres_fcts._check_remote_reachable", lambda conf: None
196197
)
198+
199+
# fake psycopg2 connection/cursor to avoid real network calls
200+
def fake_connect(**kwargs):
201+
class FakeCursor:
202+
def execute(self, query, params=None):
203+
# accept any SQL (CREATE/DROP/SELECT)
204+
self._last_query = query
205+
206+
def fetchone(self):
207+
# simulate "schema does not exist"
208+
return None
209+
210+
def close(self):
211+
pass
212+
213+
class FakeConn:
214+
def __init__(self):
215+
self.autocommit = False
216+
217+
def cursor(self):
218+
return FakeCursor()
219+
220+
def close(self):
221+
pass
222+
223+
return FakeConn()
224+
225+
# patch the connect used inside the module under test
226+
monkeypatch.setattr("codes.tune.postgres_fcts.psycopg2.connect", fake_connect)
227+
197228
url = _initialize_postgres_remote(cfg, "ignored")
198229
assert "sslmode=require" in url
199230

test/test_tuning_pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self, num, state, start, complete=None):
125125
self.state = state
126126
self.datetime_start = start
127127
self.datetime_complete = complete
128+
self.user_attrs = {}
128129

129130

130131
class FakeStudy:
@@ -144,16 +145,16 @@ def test_maybe_set_runtime_threshold_not_enough():
144145
t1 = FakeTrial(
145146
0,
146147
TrialState.COMPLETE,
147-
datetime.utcnow() - timedelta(seconds=5),
148-
datetime.utcnow(),
148+
datetime.now() - timedelta(seconds=5),
149+
datetime.now(),
149150
)
150151
study = FakeStudy([t1])
151152
maybe_set_runtime_threshold(study, warmup_target=2)
152153
assert "runtime_threshold" not in study.user_attrs
153154

154155

155156
def test_maybe_set_runtime_threshold_enough():
156-
now = datetime.utcnow()
157+
now = datetime.now()
157158
trials = []
158159
for i in range(3):
159160
trials.append(

0 commit comments

Comments
 (0)