-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
21 lines (18 loc) · 763 Bytes
/
main.py
File metadata and controls
21 lines (18 loc) · 763 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import tensorflow as tf
from dcgan import DCGAN
# --- Image preprocessing settings (read by process_path) ---
# Side length in pixels that every image is resized to (square output).
img_dim: int = 28
# Number of colour channels requested from the PNG decoder (3 = RGB).
num_channels: int = 3

# --- Training settings (passed straight through to DCGAN) ---
# Forwarded to the DCGAN constructor as its batch size.
batch_size: int = 64
# Forwarded to DCGAN.train as ``epochs``.
num_epochs: int = 2000
# Forwarded to DCGAN.train as ``save_interval``; presumably the number of
# epochs between model/image saves -- confirm against the dcgan module.
save_interval: int = 200
def process_path(file_path):
    """Load one image file as a normalized, fixed-size float tensor.

    Reads the file at ``file_path`` and decodes it as PNG with
    ``num_channels`` channels. It then resizes the image to
    ``img_dim`` x ``img_dim`` and maps pixel values from [0, 255] onto
    [-1, 1].

    Args:
        file_path: scalar string tensor holding a file path, as yielded by
            ``tf.data.Dataset.list_files``.

    Returns:
        float32 tensor of shape (img_dim, img_dim, num_channels).
    """
    raw_bytes = tf.io.read_file(file_path)

    # NOTE(review): channels=num_channels (3) yields RGB only; if the source
    # PNGs carry an alpha channel it is discarded, and transparent pixels keep
    # whatever RGB values are stored under them -- confirm this is intended.
    pixels = tf.image.decode_png(raw_bytes, channels=num_channels)

    # Force a square size; non-square inputs get stretched, so the original
    # aspect ratio is not preserved.
    pixels = tf.image.resize(pixels, size=(img_dim, img_dim))

    # Centre on the midpoint of the 8-bit range and scale, so 0 -> -1 and
    # 255 -> 1.
    half_range = 127.5
    return (tf.cast(pixels, tf.float32) - half_range) / half_range
if __name__ == "__main__":
    # Train only when executed as a script. Without this guard, importing the
    # module (e.g. to reuse process_path) would start a full training run.

    # list_files shuffles the matched PNG paths (shuffle=True). Each path is
    # then decoded and normalized by process_path. num_parallel_calls lets
    # tf.data run those decodes concurrently. Output order stays deterministic
    # unless the dataset's `deterministic` option is turned off.
    ds = tf.data.Dataset.list_files('pokemon/*.png', shuffle=True).map(
        process_path,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

    # Checkpoints and generated samples are written under ./dcgan28/.
    model = DCGAN(
        ds,
        img_dim,
        num_channels,
        batch_size,
        model_dir='./dcgan28/models/',
        img_dir='./dcgan28/images/',
    )
    model.train(epochs=num_epochs, save_interval=save_interval)