Skip to content

Commit 6de4e76

Browse files
committed
Improve tuning tests
1 parent f951722 commit 6de4e76

2 files changed

Lines changed: 231 additions & 1 deletion

File tree

test/test_postgres_utils.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import subprocess
2+
import getpass
3+
import pytest
4+
import psycopg2
5+
from psycopg2 import OperationalError
6+
7+
# Import all functions to test
8+
from codes.tune import (
9+
_make_db_url,
10+
_check_remote_reachable,
11+
_check_postgres_running_local,
12+
_start_postgres_server_local,
13+
_initialize_postgres_local,
14+
_initialize_postgres_remote,
15+
initialize_optuna_database,
16+
)
17+
18+
19+
# --- _make_db_url tests ---
20+
def test_make_db_url_with_password():
21+
url = _make_db_url("user", "pass", "host", 5432, "db", "?sslmode=require")
22+
assert url == "postgresql+psycopg2://user:pass@host:5432/db?sslmode=require"
23+
24+
25+
def test_make_db_url_without_password():
26+
url = _make_db_url("user", "", "localhost", 5433, "test", "")
27+
assert url == "postgresql+psycopg2://user@localhost:5433/test"
28+
29+
30+
# --- _check_remote_reachable tests ---
31+
def test_check_remote_reachable_success(monkeypatch):
32+
# monkeypatch psycopg2.connect to succeed
33+
monkeypatch.setattr(
34+
psycopg2,
35+
"connect",
36+
lambda **kw: type("Conn", (), {"close": lambda self: None})(),
37+
)
38+
# should not raise
39+
_check_remote_reachable({"host": "h", "port": 1111, "user": "u", "password": "p"})
40+
41+
42+
def test_check_remote_reachable_failure(monkeypatch):
43+
# monkeypatch connect to raise OperationalError
44+
def bad_connect(**kw):
45+
raise OperationalError("cant connect")
46+
47+
monkeypatch.setattr(psycopg2, "connect", bad_connect)
48+
with pytest.raises(ConnectionError) as exc:
49+
_check_remote_reachable({"host": "h", "port": 2222})
50+
assert "Cannot reach remote Postgres" in str(exc.value)
51+
52+
53+
# --- _check_postgres_running_local tests ---
54+
def test_check_postgres_running_local_success(monkeypatch):
55+
monkeypatch.setattr(
56+
psycopg2,
57+
"connect",
58+
lambda **kw: type("Conn", (), {"close": lambda self: None})(),
59+
)
60+
_check_postgres_running_local({"postgres_config": {"host": "h", "port": 5432}})
61+
62+
63+
def test_check_postgres_running_local_failure(monkeypatch):
64+
monkeypatch.setattr(
65+
psycopg2, "connect", lambda **kw: (_ for _ in ()).throw(OperationalError())
66+
)
67+
with pytest.raises(Exception) as exc:
68+
_check_postgres_running_local({"postgres_config": {}})
69+
assert "PostgreSQL server is not running" in str(exc.value)
70+
71+
72+
# --- _start_postgres_server_local tests ---
73+
def test_start_postgres_server_local_missing_data_dir(tmp_path):
74+
cfg = {
75+
"postgres_config": {
76+
"data_dir": str(tmp_path / "nope"),
77+
"database_folder": str(tmp_path),
78+
"log_file": "log",
79+
}
80+
}
81+
with pytest.raises(Exception) as exc:
82+
_start_postgres_server_local(cfg)
83+
assert "data directory" in str(exc.value)
84+
85+
86+
def test_start_postgres_server_local_missing_pg_ctl(tmp_path):
87+
data_dir = tmp_path / "data"
88+
data_dir.mkdir()
89+
cfg = {
90+
"postgres_config": {
91+
"data_dir": str(data_dir),
92+
"database_folder": str(tmp_path),
93+
"log_file": "log",
94+
}
95+
}
96+
with pytest.raises(Exception) as exc:
97+
_start_postgres_server_local(cfg)
98+
assert "pg_ctl not found" in str(exc.value)
99+
100+
101+
def test_start_postgres_server_local_success(tmp_path, monkeypatch, capsys):
102+
# create data_dir and pg_ctl
103+
data_dir = tmp_path / "data"
104+
data_dir.mkdir()
105+
bin_dir = tmp_path / "bin"
106+
bin_dir.mkdir()
107+
pg_ctl = bin_dir / "pg_ctl"
108+
pg_ctl.write_text("")
109+
cfg = {
110+
"postgres_config": {
111+
"data_dir": str(data_dir),
112+
"database_folder": str(tmp_path),
113+
"log_file": "lf",
114+
}
115+
}
116+
# monkeypatch subprocess.run
117+
monkeypatch.setattr(subprocess, "run", lambda *args, **kw: None)
118+
_start_postgres_server_local(cfg)
119+
captured = capsys.readouterr()
120+
assert "Starting PostgreSQL server" in captured.out
121+
assert "started successfully" in captured.out
122+
123+
124+
# --- _initialize_postgres_local tests ---
125+
class FakeCursor:
126+
def __init__(self, exists):
127+
self.exists = exists
128+
129+
def execute(self, q, params=None):
130+
pass
131+
132+
def fetchone(self):
133+
return (1,) if self.exists else None
134+
135+
def close(self):
136+
pass
137+
138+
139+
class FakeConn:
140+
def __init__(self, exists):
141+
self.exists = exists
142+
self.autocommit = False
143+
144+
def cursor(self):
145+
return FakeCursor(self.exists)
146+
147+
def close(self):
148+
pass
149+
150+
151+
@pytest.mark.parametrize("exists,choice", [(False, None), (True, "u"), (True, "o")])
152+
def test_initialize_postgres_local(monkeypatch, tmp_path, exists, choice):
153+
# monkeypatch connect
154+
def fake_connect(**kw):
155+
return FakeConn(exists)
156+
157+
monkeypatch.setattr(psycopg2, "connect", fake_connect)
158+
# monkeypatch input when db exists
159+
if exists and choice:
160+
monkeypatch.setattr("builtins.input", lambda prompt: choice)
161+
cfg = {
162+
"postgres_config": {
163+
"host": "h",
164+
"port": 5432,
165+
"user": "u",
166+
"password": "p",
167+
"data_dir": "d",
168+
"database_folder": "x",
169+
"log_file": "l",
170+
}
171+
}
172+
url = _initialize_postgres_local(cfg, "mydb")
173+
assert "postgresql+psycopg2://" in url
174+
175+
176+
# --- _initialize_postgres_remote tests ---
177+
def test_initialize_postgres_remote_interactive(monkeypatch):
178+
cfg = {
179+
"postgres_config": {
180+
"mode": "remote",
181+
"host": "rh",
182+
"port": 1234,
183+
"user": "ru",
184+
"password": None,
185+
"db_name": "dbn",
186+
"sslmode": "require",
187+
}
188+
}
189+
# ensure no env var
190+
monkeypatch.delenv("PGPASSWORD", raising=False)
191+
# patch getpass
192+
monkeypatch.setattr(getpass, "getpass", lambda prompt: "pwd")
193+
# patch reachability
194+
monkeypatch.setattr(
195+
"codes.tune.postgres_fcts._check_remote_reachable", lambda conf: None
196+
)
197+
url = _initialize_postgres_remote(cfg, "ignored")
198+
assert "sslmode=require" in url
199+
200+
201+
# --- initialize_optuna_database tests ---
202+
def test_initialize_optuna_database_local(monkeypatch):
203+
cfg = {"postgres_config": {"mode": "local"}}
204+
monkeypatch.setattr(
205+
"codes.tune.postgres_fcts._check_postgres_running_local", lambda c: None
206+
)
207+
monkeypatch.setattr(
208+
"codes.tune.postgres_fcts._initialize_postgres_local", lambda c, n: "URL_LOCAL"
209+
)
210+
val = initialize_optuna_database(cfg, "sf")
211+
assert val == "URL_LOCAL"
212+
213+
214+
def test_initialize_optuna_database_remote(monkeypatch):
215+
cfg = {"postgres_config": {"mode": "remote"}}
216+
monkeypatch.setattr(
217+
"codes.tune.postgres_fcts._initialize_postgres_remote", lambda c, n: "URL_REM"
218+
)
219+
val = initialize_optuna_database(cfg, "sf")
220+
assert val == "URL_REM"
221+
222+
223+
def test_initialize_optuna_database_bad_mode():
224+
with pytest.raises(ValueError):
225+
initialize_optuna_database({"postgres_config": {"mode": "foo"}}, "sf")

test/test_tuning_pipeline.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def basic_params():
6565
"batch_size": {"type": "int", "low": 10, "high": 20, "step": 5},
6666
"learning_rate": {"type": "float", "low": 0.001, "high": 0.01, "log": True},
6767
"activation": {"choices": ["relu", "tanh", "identity"]},
68+
"loss_function": {"choices": ["mse", "smoothl1"]},
6869
}
6970

7071

@@ -86,6 +87,7 @@ def test_make_optuna_params_basic(basic_params):
8687
"batch_size": basic_params["batch_size"],
8788
"learning_rate": basic_params["learning_rate"],
8889
"activation": basic_params["activation"],
90+
"loss_function": basic_params["loss_function"],
8991
},
9092
)
9193
# int param should equal low
@@ -95,6 +97,9 @@ def test_make_optuna_params_basic(basic_params):
9597
# activation returns first choice
9698
expected_cls = MODULE_REGISTRY[basic_params["activation"]["choices"][0]]
9799
assert isinstance(out["activation"], expected_cls)
100+
# loss_function returns first choice
101+
expected_cls = MODULE_REGISTRY[basic_params["loss_function"]["choices"][0]]
102+
assert isinstance(out["loss_function"], expected_cls)
98103

99104

100105
def test_make_optuna_params_conditional(conditional_params):
@@ -293,7 +298,7 @@ def predict(self, loader, leave_log=False):
293298
"codes.tune.optuna_fcts.make_optuna_params", lambda trial, params: {}
294299
)
295300
monkeypatch.setattr(
296-
"codes.tune.optuna_fcts.measure_inference_time", lambda m, l: [1.0, 2.0, 3.0]
301+
"codes.tune.optuna_fcts.measure_inference_time", lambda m, dur: [1.0, 2.0, 3.0]
297302
)
298303
import torch
299304

0 commit comments

Comments
 (0)