
Commit 0a6dfc4

Merge pull request #45 from robin-janssen/improve-unit-tests
Improve unit tests
2 parents 66622ed + 1768e8d commit 0a6dfc4

32 files changed

Lines changed: 5471 additions & 3102 deletions

.github/workflows/ci.yaml

Lines changed: 37 additions & 29 deletions
@@ -21,32 +21,31 @@ jobs:
 
     strategy:
       matrix:
-        python-version: ['3.10']
+        python-version: ["3.10"]
 
     steps:
       - name: Check out the repository
         uses: actions/checkout@v4
         with:
-          fetch-depth: 0 # Fetch all history to support dependency checks
+          fetch-depth: 0 # Fetch all history to support dependency checks
           ref: ${{ github.head_ref || github.ref }}
 
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
-      - name: Install Poetry
+      - name: Install uv
         run: |
-          curl -sSL https://install.python-poetry.org | python3 -
-          echo "$HOME/.local/bin" >> $GITHUB_PATH
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
 
-      - name: Install Poetry Plugin for Export
+      - name: Setup Python environment with uv
         run: |
-          poetry self add poetry-plugin-export
-
-      - name: Install project with dev dependencies
-        run: |
-          poetry install --with dev
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+          uv pip install --group dev
 
       - name: Check for dependency changes
         id: deps
@@ -66,8 +65,8 @@ jobs:
 
           echo "Changed files: $CHANGED_FILES"
 
-          # Check if pyproject.toml or poetry.lock has changed
-          if echo "$CHANGED_FILES" | grep -qE '(^|/)pyproject\.toml$|(^|/)poetry\.lock$'; then
+          # Check if pyproject.toml has changed
+          if echo "$CHANGED_FILES" | grep -qE '(^|/)pyproject\.toml$'; then
             echo "dependencies-changed=true" >> $GITHUB_ENV
           else
             echo "dependencies-changed=false" >> $GITHUB_ENV
@@ -76,7 +75,8 @@
       - name: Generate requirements.txt
         if: env.dependencies-changed == 'true'
         run: |
-          poetry export -f requirements.txt --output requirements.txt --without-hashes
+          source .venv/bin/activate
+          uv pip freeze > requirements.txt
 
       - name: Commit and push updated requirements.txt
         if: |
@@ -92,58 +92,66 @@
 
       - name: Run Black (auto-reformat)
         run: |
-          poetry run black .
+          source .venv/bin/activate
+          black .
 
       - name: Run isort (auto-reformat)
         run: |
-          poetry run isort .
+          source .venv/bin/activate
+          isort .
 
       - name: Run pytest and generate coverage report
         run: |
-          poetry run pytest --cov-report=term-missing:skip-covered --cov=codes test/ --cov-report=xml:coverage.xml
+          source .venv/bin/activate
+          pytest --cov-report=term-missing:skip-covered --cov=codes test/ --cov-report=xml:coverage.xml
 
       - name: Upload results to Codecov
         uses: codecov/codecov-action@v4
         with:
-          token: ${{ secrets.CODECOV_TOKEN }}
-          files: ./coverage.xml # Path to the coverage XML file, adjust if necessary
-          fail_ci_if_error: true # Optional, ensures the CI fails if Codecov upload fails
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: ./coverage.xml # Path to the coverage XML file, adjust if necessary
+          fail_ci_if_error: true # Optional, ensures the CI fails if Codecov upload fails
 
   docs:
     if: github.ref == 'refs/heads/main' || (github.event_name == 'pull_request' && github.event.pull_request.base.ref == 'main')
     runs-on: ubuntu-latest
 
     strategy:
       matrix:
-        python-version: ['3.10']
+        python-version: ["3.10"]
 
     steps:
       - name: Check out the repository
         uses: actions/checkout@v4
         with:
-          fetch-depth: 0 # Fetch all history if necessary
+          fetch-depth: 0 # Fetch all history if necessary
 
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
-      - name: Install Poetry
+      - name: Install uv
         run: |
-          curl -sSL https://install.python-poetry.org | python3 -
-          echo "$HOME/.local/bin" >> $GITHUB_PATH
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
 
-      - name: Install project with dev dependencies
+      - name: Setup Python environment with uv
         run: |
-          poetry install --with dev
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+          uv pip install --group dev
 
       - name: Generate API Documentation with Sphinx
         run: |
-          poetry run sphinx-apidoc -o docs/ codes
+          source .venv/bin/activate
+          sphinx-apidoc -o docs/ codes
 
       - name: Build HTML with Sphinx
         run: |
-          poetry run sphinx-build -b html docs/ docs/_build
+          source .venv/bin/activate
+          sphinx-build -b html docs/ docs/_build
 
       - name: Deploy Sphinx API docs to gh-pages
         uses: peaceiris/actions-gh-pages@v4
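Note: the workflow's dependency check now keys only on pyproject.toml; the poetry.lock branch of the grep is gone. As a hedged illustration only — the helper below is not part of the commit — the shell pattern '(^|/)pyproject\.toml$' behaves like this Python predicate over the list of changed paths:

    import re

    # Same pattern as the workflow's grep: pyproject.toml at the repo root
    # or in any subdirectory, matched against one path at a time.
    _PYPROJECT_RE = re.compile(r"(^|/)pyproject\.toml$")

    def dependencies_changed(changed_files: list[str]) -> bool:
        """Return True if any changed path is a pyproject.toml file."""
        return any(_PYPROJECT_RE.search(path) for path in changed_files)

    print(dependencies_changed(["codes/train/__init__.py"]))         # False
    print(dependencies_changed(["some_pkg/pyproject.toml", "a.py"]))  # True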

.python-version

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+3.10

AGENT.md

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# CODES Benchmark - Agent Guide
+
+## Build/Test Commands
+- Install dependencies: `poetry install` or `pip install -r requirements.txt`
+- Run tests: `pytest test/` or `python -m pytest test/`
+- Run single test: `pytest test/test_data.py::test_function_name`
+- Run training: `python run_training.py --config config.yaml`
+- Run evaluation: `python run_eval.py --config config.yaml`
+- Run hyperparameter tuning: `python run_tuning.py --config config.yaml`
+
+## Architecture & Structure
+- `codes/` - Main package with 4 modules: benchmark, surrogates, train, tune, utils
+- `codes/surrogates/` - Neural network surrogate models (NN, DeepONet, NeuralODE, Polynomial)
+- `codes/benchmark/` - Core benchmarking logic and evaluation metrics
+- `codes/train/` - Parallel/sequential training infrastructure with task queues
+- `codes/utils/` - Data handling, config management, progress bars, seeding
+- `datasets/` - Training data in HDF5 format, `trained/` - Model checkpoints, `results/` - Benchmark outputs
+- Uses PyTorch, torchode for ODE solving, Optuna for hyperparameter tuning, PostgreSQL for study storage
+
+## Code Style & Conventions
+- Python 3.10+, type hints required (e.g., `str | None`, `int | None`)
+- Use `__all__` in `__init__.py` files for explicit exports
+- Function docstrings with Args/Returns sections
+- Classes use PascalCase, functions/variables use snake_case
+- Threading with explicit locks, use DummyLock() for single-threaded contexts
+- Config-driven architecture with YAML files, use `read_yaml_config()` utility
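To make the style bullets above concrete, here is a minimal, hypothetical sketch (not taken from the repository) of a function written to those conventions — union type hints, an Args/Returns docstring, and snake_case naming:

    def scale_timesteps(timesteps: list[float], factor: float | None = None) -> list[float]:
        """Scale a list of timesteps by a constant factor.

        Args:
            timesteps (list[float]): Timesteps to rescale.
            factor (float | None): Scaling factor; defaults to 1.0 when None.

        Returns:
            list[float]: The rescaled timesteps.
        """
        factor = 1.0 if factor is None else factor
        return [t * factor for t in timesteps]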

codes/benchmark/bench_fcts.py

Lines changed: 0 additions & 31 deletions
@@ -1019,37 +1019,6 @@ def load_losses(model_identifier: str):
     )
 
 
-# def compare_MAE(metrics: dict, config: dict) -> None:
-#     """
-#     Compare the MAE of different surrogate models over the course of training.
-
-#     Args:
-#         metrics (dict): dictionary containing the benchmark metrics for each surrogate model.
-#         config (dict): Configuration dictionary.
-
-#     Returns:
-#         None
-#     """
-#     MAE = []
-#     labels = []
-#     train_durations = []
-#     device = config["devices"]
-#     device = device[0] if isinstance(device, list) else device
-
-#     for surr_name, _ in metrics.items():
-#         training_id = config["training_id"]
-#         surrogate_class = get_surrogate(surr_name)
-#         n_timesteps = metrics[surr_name]["timesteps"].shape[0]
-#         n_quantities = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
-#         model_config = get_model_config(surr_name, config)
-#         model = surrogate_class(device, n_quantities, n_timesteps, model_config)
-#         model_identifier = f"{surr_name.lower()}_main"
-#         model.load(training_id, surr_name, model_identifier=model_identifier)
-#         MAE.append(model.MAE)
-#         labels.append(surr_name)
-#         train_durations.append(model.train_duration)
-
-
 def compare_relative_errors(metrics: dict[str, dict], config: dict) -> None:
     """
     Compare the relative errors over time for different surrogate models.

codes/surrogates/DeepONet/deeponet.py

Lines changed: 17 additions & 11 deletions
@@ -359,11 +359,15 @@ def create_dataloader(
         shuffle: bool,
         dataset_params: np.ndarray | None,
         params_in_branch: bool,
-        num_workers: int = 0,  # will usually stay 0 here
+        num_workers: int = 0,
         pin_memory: bool = True,
     ):
         n_samples, n_timesteps, n_quantities = data.shape
 
+        if pin_memory:
+            if "cuda" not in self.device:
+                pin_memory = False
+
         branch = np.repeat(data[:, 0, :], n_timesteps, axis=0)  # (total, n_q)
         trunk = np.tile(timesteps.reshape(1, -1), (n_samples, 1)).reshape(
             -1, 1
@@ -419,21 +423,23 @@ def prepare_data(
             dataset_train,
             timesteps,
             batch_size,
-            True,
+            shuffle=shuffle,
             dataset_params=dataset_train_params,
             params_in_branch=self.config.params_branch,
             num_workers=nw,
         )
 
-        test_loader = self.create_dataloader(
-            dataset_test,
-            timesteps,
-            batch_size,
-            False,
-            dataset_params=dataset_test_params,
-            params_in_branch=self.config.params_branch,
-            num_workers=nw,
-        )
+        test_loader = None
+        if dataset_test is not None:
+            test_loader = self.create_dataloader(
+                dataset_test,
+                timesteps,
+                batch_size,
+                False,
+                dataset_params=dataset_test_params,
+                params_in_branch=self.config.params_branch,
+                num_workers=nw,
+            )
 
         val_loader = None
         if dataset_val is not None:

codes/surrogates/FCNN/fcnn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@ def create_dataloader(
310310
n_samples, n_timesteps, n_quantities = data.shape
311311
total = n_samples * n_timesteps
312312

313+
if pin_memory:
314+
if "cuda" not in self.device:
315+
pin_memory = False
316+
313317
init_states = data[:, 0, :] # (n_samples, n_quantities)
314318
rep_init = np.repeat(
315319
init_states[:, None, :], n_timesteps, 1

codes/surrogates/LatentNeuralODE/latent_neural_ode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ def create_dataloader(
132132
num_workers: int = 0,
133133
pin_memory: bool = True,
134134
):
135+
if pin_memory:
136+
if "cuda" not in self.device:
137+
pin_memory = False
138+
135139
data_t = torch.from_numpy(data).float()
136140
t_t = torch.from_numpy(timesteps).float()
137141
if dataset_params is not None:

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def create_dataloader(
9292
num_workers: int = 0,
9393
pin_memory: bool = True,
9494
):
95+
if pin_memory:
96+
if "cuda" not in self.device:
97+
pin_memory = False
98+
9599
data_t = torch.from_numpy(data).float()
96100
t_t = torch.from_numpy(timesteps).float()
97101
params_t = (

codes/train/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
train_and_save_model,
55
create_task_list_for_surrogate,
66
worker,
7+
DummyLock,
78
)
89

910
__all__ = [
@@ -12,4 +13,5 @@
1213
"train_and_save_model",
1314
"create_task_list_for_surrogate",
1415
"worker",
16+
"DummyLock",
1517
]
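DummyLock is now re-exported from codes.train; per AGENT.md it stands in for a real lock in single-threaded contexts. Its implementation is not shown in this diff, but a minimal no-op context manager along the following lines would satisfy that usage (the helper function below is purely illustrative):

    import threading

    class DummyLock:
        """No-op drop-in for threading.Lock in single-threaded code paths."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            return False  # never suppress exceptions

        def acquire(self, *args, **kwargs):
            return True

        def release(self):
            return None

    def append_result(results: list, value, lock) -> None:
        """Append under a lock; works with DummyLock or threading.Lock."""
        with lock:
            results.append(value)

    append_result([], 1, DummyLock())       # single-threaded path
    append_result([], 1, threading.Lock())  # multi-threaded path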

codes/tune/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,15 @@
1111
maybe_set_runtime_threshold,
1212
training_run,
1313
)
14-
from .postgres_fcts import _make_db_url, initialize_optuna_database
14+
from .postgres_fcts import (
15+
_make_db_url,
16+
initialize_optuna_database,
17+
_check_postgres_running_local,
18+
_start_postgres_server_local,
19+
_check_remote_reachable,
20+
_initialize_postgres_local,
21+
_initialize_postgres_remote,
22+
)
1523
from .tune_utils import (
1624
build_study_names,
1725
copy_config,
@@ -37,4 +45,9 @@
3745
"yes_no",
3846
"_make_db_url",
3947
"initialize_optuna_database",
48+
"_check_postgres_running_local",
49+
"_start_postgres_server_local",
50+
"_check_remote_reachable",
51+
"_initialize_postgres_local",
52+
"_initialize_postgres_remote",
4053
]
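With the expanded __all__, the lower-level Postgres helpers can be imported straight from codes.tune (for example in tests) instead of reaching into the private postgres_fcts module. Their signatures are not part of this diff, so only the import path is sketched here:

    # Assumes the codes package is installed (e.g. `uv pip install -e .`).
    from codes.tune import (
        _check_postgres_running_local,
        _make_db_url,
        initialize_optuna_database,
    )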
