Skip to content

Commit e3644d9

Browse files
committed
refactor: modify structure according to new setup
1 parent 318d7cc commit e3644d9

12 files changed

Lines changed: 106 additions & 202737 deletions

File tree

guild.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
- model: nvae-mixture-logistic
44
sourcecode:
55
- '*.py'
6-
- 'vae/splits/*.json'
76
- guild.yml
87
- exclude:
98
dir:

update_splits.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

vae/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from vae import problem
22
from vae import datastream
33
from vae import architecture
4+
from vae import metrics
45

6+
from vae.log_examples import log_examples
57
from vae.train import train
Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
import numpy as np
2-
import pandas as pd
31
from datastream import Datastream
42

5-
from vae import problem, splits
3+
from vae import problem
64

75

86
def evaluate_datastreams():
9-
evaluate_datasets = problem.evaluate_datasets()
7+
datasets = problem.datasets()
108
return {
11-
split_name: Datastream(dataset)
9+
split_name: Datastream(dataset).take(256)
1210
for split_name, dataset in dict(
13-
gradient=evaluate_datasets['train'],
14-
compare=evaluate_datasets['compare'],
11+
gradient=datasets['train'],
12+
compare=datasets['compare'],
1513
).items()
1614
}
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
import numpy as np
21
from datastream import Datastream
32

4-
from vae.datastream import (
5-
evaluate_datastreams, augmenter
6-
)
3+
from vae import problem
4+
from vae.datastream import augmenter
75

86

9-
def GradientDatastream():
7+
def GradientDatastream():
108
augmenter_ = augmenter()
119
return (
12-
evaluate_datastreams()['gradient']
10+
Datastream(problem.datasets()['gradient'])
1311
.map(lambda example: example.augment(augmenter_))
1412
)

vae/log_examples.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import numpy as np
2+
import torch
3+
from workflow.torch import module_eval
4+
5+
from vae import architecture
6+
7+
8+
def log_examples(description, trainer, model):
9+
def log_examples_(engine, logger, event_name):
10+
n_examples = 5
11+
indices = np.random.choice(
12+
len(engine.state.output['predictions']),
13+
n_examples,
14+
replace=False,
15+
)
16+
17+
logger.writer.add_images(
18+
f'{description}/predictions',
19+
np.stack([
20+
np.concatenate([
21+
np.array(
22+
engine.state.output['examples'][index]
23+
.representation()
24+
),
25+
np.array(
26+
engine.state.output['predictions'][index]
27+
.representation()
28+
),
29+
], axis=0) / 255
30+
for index in indices
31+
]),
32+
trainer.state.epoch,
33+
dataformats='NHWC',
34+
)
35+
36+
with torch.no_grad(), module_eval(model) as eval_model:
37+
std_samples = [
38+
eval_model.generated(16, prior_std)
39+
for prior_std in np.linspace(0.4, 1.1, num=8)
40+
]
41+
42+
logger.writer.add_images(
43+
f'{description}/samples',
44+
np.stack([np.concatenate([
45+
np.concatenate([
46+
np.array(sample.representation())
47+
for sample in samples
48+
], axis=1)
49+
for samples in std_samples
50+
], axis=0)]) / 255,
51+
trainer.state.epoch,
52+
dataformats='NHWC',
53+
)
54+
55+
with torch.no_grad(), module_eval(model) as eval_model:
56+
partial_samples = [
57+
eval_model.partially_generated(
58+
architecture.FeaturesBatch.from_examples(
59+
[
60+
engine.state.output['examples'][index]
61+
for index in indices
62+
]
63+
).image_batch,
64+
sample=[
65+
index == sample_index
66+
for index in range(model.levels)
67+
],
68+
prior_std=0.7,
69+
)
70+
for sample_index in range(model.levels)
71+
]
72+
73+
logger.writer.add_images(
74+
f'{description}/partially_sampled',
75+
np.concatenate([
76+
np.stack([
77+
np.array(sample.representation())
78+
for sample in samples
79+
])
80+
for samples in partial_samples
81+
], axis=1) / 255,
82+
trainer.state.epoch,
83+
dataformats='NHWC',
84+
)
85+
86+
return log_examples_

vae/problem/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
11
from vae.problem.example import Example
2-
from vae.problem.evaluate_datasets import (
3-
evaluate_datasets
4-
)
2+
from vae.problem.datasets import datasets
Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
from datastream import Dataset
55

6-
from vae import problem, splits
6+
from vae import problem
77

88

99
def dataframe():
@@ -30,25 +30,12 @@ def dataset(dataframe):
3030
)
3131

3232

33-
def evaluate_datasets(frozen=True):
34-
datasets = (
33+
def datasets(frozen=True):
34+
return (
3535
dataset(dataframe())
3636
.split(
3737
key_column='key',
3838
proportions=dict(train=0.8, compare=0.2),
39-
# filepath=splits.compare,
40-
# frozen=frozen,
41-
# don't save split until we have a solution for remote training
42-
# guild is not saving .json as sourcecode to remote
4339
seed=177,
4440
)
4541
)
46-
# TODO: temporary
47-
return dict(
48-
train=datasets['train'],
49-
compare=datasets['compare'].split(
50-
key_column='key',
51-
proportions=dict(keep=0.05, throw=0.95),
52-
seed=523,
53-
)['keep'],
54-
)

vae/splits/.gitignore

Whitespace-only changes.

vae/splits/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)