-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFlowers.py
More file actions
65 lines (45 loc) · 1.73 KB
/
Flowers.py
File metadata and controls
65 lines (45 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_hub as hub
import tensorflow_datasets as tfds
from tensorflow.keras import layers
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
# download dataset
(training_set, validation_set), dataset_info = tfds.load(
'tf_flowers', # name of dataset to download
split=['train[:70%]', 'train[70%:]'],
with_info=True,
as_supervised=True
)
num_classes = dataset_info.features['label'].num_classes
num_training_examples = 0
num_validation_examples = 0
for example in training_set:
num_training_examples += 1
for example in validation_set:
num_validation_examples += 1
print('Total Number of Classes: {}'.format(num_classes))
print('Total Number of Training Images: {}'.format(num_training_examples))
print('Total Number of Validation Images: {} \n'.format(num_validation_examples))
for i, example in enumerate(training_set.take(5)):
print('Image {} shape: {} label: {}'.format(i + 1, example[0].shape, example[1]))
IMAGE_RES = 224
def format_image(image, label):
image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES)) / 255.0
return image, label
BATCH_SIZE = 32
train_batches = training_set.shuffle(num_training_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)
validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)
# Feature extractor creation
URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
feature_extractor = hub.KerasLayer(URL,
input_shape=(IMAGE_RES, IMAGE_RES, 3))
feature_extractor.trainable = False
model = tf.keras.Sequential([
feature_extractor,
layers.Dense(num_classes)
])
model.summary()