Skip to content

Commit 01a1168

Browse files
committed
training: support image sequence filenames, add exposure arg for infer
1 parent e050ac8 commit 01a1168

6 files changed

Lines changed: 117 additions & 76 deletions

File tree

training/compare_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def main():
2525
error('cannot compare different features')
2626

2727
# Load metadata for the images if it exists
28-
tonemap_exposure = cfg.exposure
28+
tonemap_exposure = cfg.exposure if cfg.exposure is not None else 1.
2929
if os.path.dirname(cfg.input[0]) == os.path.dirname(cfg.input[1]):
3030
metadata = load_image_metadata(os.path.commonprefix(cfg.input))
3131
if metadata:

training/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def get_default_device():
182182
parser.add_argument('output', type=str,
183183
help='output image')
184184

185-
if cmd in {'convert_image', 'compare_image'}:
186-
parser.add_argument('--exposure', '-E', type=float, default=1.,
185+
if cmd in {'infer', 'convert_image', 'compare_image'}:
186+
parser.add_argument('--exposure', '-E', type=float, default=None,
187187
help='linear exposure scale for HDR image')
188188

189189
if cmd in {'split_exr'}:

training/convert_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def main():
1818
image, _ = load_image(cfg.input)
1919

2020
# Load metadata for the image if it exists
21-
tonemap_exposure = cfg.exposure
21+
tonemap_exposure = cfg.exposure if cfg.exposure is not None else 1.
2222
metadata = load_image_metadata(cfg.input)
2323
if metadata:
2424
tonemap_exposure = metadata['exposure']

training/dataset.py

Lines changed: 79 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -57,55 +57,62 @@ def shuffle_channels(channels, first_channel, order, num_channels=None):
5757
if num_channels is not None:
5858
del channels[first+num_channels:first+len(new_channels)]
5959

60-
# Checks whether the image with specified features exists
61-
def image_exists(name, features):
62-
suffixes = features.copy()
63-
if 'sh1' in suffixes:
64-
suffixes.remove('sh1')
65-
suffixes += ['sh1x', 'sh1y', 'sh1z']
66-
return all([os.path.isfile(name + '.' + s + '.exr') for s in suffixes])
60+
# Checks whether the specified image feature files exist (name is a format string, e.g. 'image%s.01')
61+
def image_features_exist(name, features):
62+
name = (name % '.%s') + '.exr'
63+
feature_exts = features.copy()
64+
if 'sh1' in feature_exts:
65+
feature_exts.remove('sh1')
66+
feature_exts += ['sh1x', 'sh1y', 'sh1z']
67+
return all([os.path.isfile(name % feature_ext) for feature_ext in feature_exts])
6768

6869
# Returns the feature an image represents given its filename
6970
def get_image_feature(filename):
70-
filename_split = filename.rsplit('.', 2)
71-
if len(filename_split) < 2:
72-
return 'srgb' # no extension, assume sRGB
71+
filename_split = filename.rsplit('.', 3)
72+
if len(filename_split) == 1:
73+
return 'srgb' # no format extension, assume sRGB
7374
else:
74-
ext = filename_split[-1].lower()
75-
if ext in {'exr', 'pfm', 'phm', 'hdr'}:
76-
if len(filename_split) == 3:
75+
format_ext = filename_split[-1].lower()
76+
if format_ext in {'exr', 'pfm', 'phm', 'hdr'}:
77+
# Identify the feature extension, ignore the optional frame number
78+
if len(filename_split) >= 3 and not filename_split[-2].isdecimal():
7779
feature = filename_split[-2]
78-
if feature in {'sh1x', 'sh1y', 'sh1z'}:
79-
feature = 'sh1'
80-
return feature
80+
elif len(filename_split) >= 4 and not filename_split[-3].isdecimal():
81+
feature = filename_split[-3]
8182
else:
82-
return 'hdr' # assume HDR
83+
return 'hdr' # no feature extension, assume HDR
84+
85+
if feature in {'sh1x', 'sh1y', 'sh1z'}:
86+
feature = 'sh1'
87+
return feature
8388
else:
8489
return 'srgb' # assume sRGB
8590

86-
# Loads image features in EXR format with given filename prefix, and returns the concatenated image
87-
# and the number of channels loaded for each feature
91+
# Loads image features in EXR format, where name is a format string returned by
92+
# get_image_sample_groups() (e.g. 'image%s.01'), and returns the concatenated image and the number
93+
# of channels loaded per feature
8894
def load_image_features(name, features):
95+
name = (name % '.%s') + '.exr'
8996
images = []
9097
num_channels = {}
9198

9299
# HDR color
93100
if 'hdr' in features:
94-
hdr, num_channels['hdr'] = load_image(name + '.hdr.exr', num_channels=3)
101+
hdr, num_channels['hdr'] = load_image(name % 'hdr', num_channels=3)
95102
hdr = np.maximum(hdr, 0.)
96103
images.append(hdr)
97104

98105
# LDR color
99106
if 'ldr' in features:
100-
ldr, num_channels['ldr'] = load_image(name + '.ldr.exr', num_channels=3)
107+
ldr, num_channels['ldr'] = load_image(name % 'ldr', num_channels=3)
101108
ldr = np.clip(ldr, 0., 1.)
102109
images.append(ldr)
103110

104111
# SH L1 color coefficients
105112
if 'sh1' in features:
106-
sh1x, num_channels['sh1x'] = load_image(name + '.sh1x.exr', num_channels=3)
107-
sh1y, num_channels['sh1y'] = load_image(name + '.sh1y.exr', num_channels=3)
108-
sh1z, num_channels['sh1z'] = load_image(name + '.sh1z.exr', num_channels=3)
113+
sh1x, num_channels['sh1x'] = load_image(name % 'sh1x', num_channels=3)
114+
sh1y, num_channels['sh1y'] = load_image(name % 'sh1y', num_channels=3)
115+
sh1z, num_channels['sh1z'] = load_image(name % 'sh1z', num_channels=3)
109116

110117
for sh1 in [sh1x, sh1y, sh1z]:
111118
# Clip to [-1..1] range (coefficients are assumed to be normalized)
@@ -118,13 +125,13 @@ def load_image_features(name, features):
118125

119126
# Albedo
120127
if 'alb' in features:
121-
albedo, num_channels['alb'] = load_image(name + '.alb.exr', num_channels=3)
128+
albedo, num_channels['alb'] = load_image(name % 'alb', num_channels=3)
122129
albedo = np.clip(albedo, 0., 1.)
123130
images.append(albedo)
124131

125132
# Normal
126133
if 'nrm' in features:
127-
normal, num_channels['nrm'] = load_image(name + '.nrm.exr', num_channels=3)
134+
normal, num_channels['nrm'] = load_image(name % 'nrm', num_channels=3)
128135
normal = np.clip(normal, -1., 1.)
129136

130137
# Transform to [0..1] range
@@ -134,7 +141,7 @@ def load_image_features(name, features):
134141

135142
# Depth
136143
if 'z' in features:
137-
depth, num_channels['z'] = load_image(name + '.z.exr', num_channels=1)
144+
depth, num_channels['z'] = load_image(name % 'z', num_channels=1)
138145
depth = np.maximum(depth, 0.)
139146

140147
# If this is an auxiliary feature, transform to [0..1] range
@@ -149,19 +156,16 @@ def load_image_features(name, features):
149156
# Concatenate all feature images into one image
150157
return np.concatenate(images, axis=2), num_channels
151158

152-
# Tries to load metadata for an image with given filename/prefix, returns None if it fails
159+
# Loads image metadata in JSON format if it exists, where name is optionally a format string
160+
# returned by get_image_sample_groups() (e.g. 'image%s.01'); returns an empty dict if the file does not exist
153161
def load_image_metadata(name):
154-
dirname, basename = os.path.split(name)
155-
basename = basename.split('.')[0] # remove all extensions
156-
while basename:
157-
metadata_filename = os.path.join(dirname, basename) + '.json'
158-
if os.path.isfile(metadata_filename):
159-
return load_json(metadata_filename)
160-
if '_' in basename:
161-
basename = basename.rsplit('_', 1)[0]
162-
else:
163-
break
164-
return None
162+
if '%s' in name:
163+
name = name % ''
164+
name += '.json'
165+
if os.path.isfile(name):
166+
return load_json(name)
167+
else:
168+
return {}
165169

166170
# Saves image metadata to a file with given prefix
167171
def save_image_metadata(name, metadata):
@@ -180,19 +184,42 @@ def get_image_sample_groups(dir, input_features, target_features=None):
180184

181185
# Make image groups
182186
image_groups = defaultdict(set)
187+
183188
for filename in image_filenames:
184-
image_name = os.path.relpath(filename, dir) # remove dir path
185-
image_name, _, _ = image_name.rsplit('.', 2) # remove extensions
186-
group = image_name
187-
if '_' in image_name:
188-
prefix, suffix = image_name.rsplit('_', 1)
189-
suffix = suffix.lower()
189+
# Process the image filename
190+
filename = os.path.relpath(filename, dir) # remove directory path
191+
filename = filename.replace('%', '%%') # escape for string formatting
192+
filename_split = filename.rsplit('.', 3)
193+
group_name = filename_split[0]
194+
195+
# Identify the feature extension and the optional frame number
196+
feature_ext_idx = None
197+
frame_str = None
198+
for i in range(1, len(filename_split)-1):
199+
if filename_split[i].isdecimal():
200+
frame_str = filename_split[i]
201+
else:
202+
feature_ext_idx = i
203+
if feature_ext_idx is None:
204+
continue # skip image with missing feature extension
205+
206+
# The image name is a format string which can be passed to load_image_features()
207+
# Replace the feature extension with a placeholder and remove the format extension ('exr')
208+
image_name = filename_split[0]
209+
for i in range(1, len(filename_split)-1):
210+
image_name += '%s' if i == feature_ext_idx else ('.' + filename_split[i])
211+
212+
# Identify the samples per pixel
213+
if '_' in group_name:
214+
prefix, suffix = group_name.rsplit('_', 1)
190215
if (suffix.isdecimal() or
191-
(suffix.endswith('spp') and suffix[:-3].isdecimal()) or
192-
suffix == 'ref' or suffix == 'reference' or
193-
suffix == 'gt' or suffix == 'target'):
194-
group = prefix
195-
image_groups[group].add(image_name)
216+
(suffix.lower().endswith('spp') and suffix[:-3].isdecimal()) or
217+
suffix.lower() in {'ref', 'reference', 'gt', 'target'}):
218+
group_name = prefix
219+
220+
if frame_str is not None:
221+
group_name += '.' + frame_str
222+
image_groups[group_name].add(image_name)
196223

197224
# Make sorted image sample (inputs + target) groups
198225
image_sample_groups = []
@@ -205,8 +232,8 @@ def get_image_sample_groups(dir, input_features, target_features=None):
205232
input_names, target_name = image_names, None
206233

207234
# Check whether all required features exist
208-
if all([image_exists(os.path.join(dir, input_name), input_features) for input_name in input_names]):
209-
if target_name and not image_exists(os.path.join(dir, target_name), target_features):
235+
if all([image_features_exist(os.path.join(dir, input_name), input_features) for input_name in input_names]):
236+
if target_name and not image_features_exist(os.path.join(dir, target_name), target_features):
210237
target_name = None # discard target due to missing features
211238

212239
# Add sample

training/infer.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,24 @@ def main():
138138
metric_count = 0
139139

140140
# Saves an image in different formats
141-
def save_images(path, image, image_srgb, num_channels, feature_ext=infer.main_feature):
141+
def save_images(name, image, image_srgb, num_channels, feature_ext=infer.main_feature):
142142
if feature_ext == 'sh1':
143143
# Iterate over x, y, z
144144
for i, axis in [(0, 'x'), (3, 'y'), (6, 'z')]:
145-
save_images(path, image[:, i:i+3, ...], image_srgb[:, i:i+3, ...], num_channels, 'sh1' + axis)
145+
save_images(name, image[:, i:i+3], image_srgb[:, i:i+3], num_channels, 'sh1' + axis)
146146
return
147147

148148
image = tensor_to_image(image)
149149
image_srgb = tensor_to_image(image_srgb)
150-
filename_prefix = path + '.' + feature_ext + '.'
151150
for format in cfg.format:
151+
filename = (name % ('.' + feature_ext)) + '.' + format
152152
if format in {'exr', 'pfm', 'phm', 'hdr'}:
153153
# Transform to original range
154154
if infer.main_feature in {'sh1', 'nrm'}:
155155
image = image * 2. - 1. # [0..1] -> [-1..1]
156-
save_image(filename_prefix + format, image, num_channels=num_channels[feature_ext])
156+
save_image(filename, image, num_channels=num_channels[feature_ext])
157157
else:
158-
save_image(filename_prefix + format, image_srgb, num_channels=num_channels[feature_ext])
158+
save_image(filename, image_srgb, num_channels=num_channels[feature_ext])
159159

160160
with torch.inference_mode():
161161
for group, input_names, target_name in image_sample_groups:
@@ -164,25 +164,37 @@ def save_images(path, image, image_srgb, num_channels, feature_ext=infer.main_fe
164164
if not os.path.isdir(output_group_dir):
165165
os.makedirs(output_group_dir)
166166

167-
# Load metadata for the images if it exists
168-
tonemap_exposure = 1.
169-
metadata = load_image_metadata(os.path.join(data_dir, group))
170-
if metadata:
171-
tonemap_exposure = metadata['exposure']
172-
save_image_metadata(os.path.join(output_dir, group), metadata)
173-
174167
# Load the target image if it exists
168+
target_tonemap_exposure = 1.
169+
175170
if target_name:
176171
target, target_num_channels = load_image_features(os.path.join(data_dir, target_name), infer.main_feature)
177172
target = image_to_tensor(target, batch=True).to(device)
178-
target_srgb = transform_feature(target, infer.main_feature, 'srgb', tonemap_exposure)
173+
target_metadata = load_image_metadata(os.path.join(data_dir, target_name))
174+
175+
# Determine the tonemap exposure for the target image
176+
if cfg.exposure is not None:
177+
target_tonemap_exposure = cfg.exposure
178+
elif target_metadata and 'exposure' in target_metadata:
179+
target_tonemap_exposure = target_metadata['exposure']
180+
181+
target_srgb = transform_feature(target, infer.main_feature, 'srgb', target_tonemap_exposure)
179182

180183
# Iterate over the input images
181184
for input_name in input_names:
182-
print(input_name, '...', end='', flush=True)
185+
print(input_name % '', '...', end='', flush=True)
183186

184187
# Load the input image
185188
input, input_num_channels = load_image_features(os.path.join(data_dir, input_name), infer.features)
189+
input_metadata = load_image_metadata(os.path.join(data_dir, input_name))
190+
191+
# Determine the tonemap exposure for the input image
192+
if cfg.exposure is not None:
193+
tonemap_exposure = cfg.exposure
194+
elif input_metadata and 'exposure' in input_metadata:
195+
tonemap_exposure = input_metadata['exposure']
196+
else:
197+
tonemap_exposure = target_tonemap_exposure
186198

187199
# Compute the autoexposure value
188200
exposure = autoexposure(input) if infer.main_feature == 'hdr' else 1.
@@ -208,7 +220,7 @@ def save_images(path, image, image_srgb, num_channels, feature_ext=infer.main_fe
208220

209221
# Save the input and output images
210222
output_suffix = cfg.result if cfg.output_suffix is None else cfg.output_suffix
211-
output_name = input_name + '.' + output_suffix
223+
output_name = input_name % f'.{output_suffix}%s'
212224
if cfg.num_epochs:
213225
output_name += f'_{epoch}'
214226
if cfg.save_all:

training/preprocess.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def preprocess_sample_group(input_dir, output_tza, input_names, target_name):
7777
samples = []
7878

7979
# Load the target image
80-
print(target_name)
80+
target_id = target_name % ''
81+
print(target_id)
8182
target_image, _ = load_image_features(os.path.join(input_dir, target_name), target_features)
8283

8384
# Compute the autoexposure value
@@ -87,12 +88,13 @@ def preprocess_sample_group(input_dir, output_tza, input_names, target_name):
8788
target_image = preprocess_image(target_image, exposure)
8889

8990
# Save the target image
90-
output_tza.write(target_name, target_image, 'hwc')
91+
output_tza.write(target_id, target_image, 'hwc')
9192

9293
# Process the input images
9394
for input_name in input_names:
9495
# Load the image
95-
print(input_name)
96+
input_id = input_name % ''
97+
print(input_id)
9698
input_image, _ = load_image_features(os.path.join(input_dir, input_name), input_features)
9799

98100
if input_image.shape[0:2] != target_image.shape[0:2]:
@@ -102,10 +104,10 @@ def preprocess_sample_group(input_dir, output_tza, input_names, target_name):
102104
input_image = preprocess_image(input_image, exposure, prefilter=True)
103105

104106
# Save the image
105-
output_tza.write(input_name, input_image, 'hwc')
107+
output_tza.write(input_id, input_image, 'hwc')
106108

107109
# Add sample
108-
samples.append((input_name, target_name))
110+
samples.append((input_id, target_id))
109111

110112
return samples
111113

0 commit comments

Comments
 (0)