Skip to content

Commit 7217ba2

Browse files
authored
Add output mask output format
1 parent 3e06792 commit 7217ba2

13 files changed

Lines changed: 144 additions & 46 deletions

tests.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def test_generate_data_with_format(self):
150150
0,
151151
(5, 5, 5, 5),
152152
0,
153+
0,
153154
)
154155

155156
self.assertTrue(
@@ -184,6 +185,7 @@ def test_generate_data_with_extension(self):
184185
0,
185186
(5, 5, 5, 5),
186187
0,
188+
0,
187189
)
188190

189191
self.assertTrue(
@@ -218,6 +220,7 @@ def test_generate_data_with_skew_angle(self):
218220
0,
219221
(5, 5, 5, 5),
220222
0,
223+
0,
221224
)
222225

223226
self.assertTrue(
@@ -252,6 +255,7 @@ def test_generate_data_with_blur(self):
252255
0,
253256
(5, 5, 5, 5),
254257
0,
258+
0,
255259
)
256260

257261
self.assertTrue(
@@ -286,6 +290,7 @@ def test_generate_data_with_sine_distorsion(self):
286290
0,
287291
(5, 5, 5, 5),
288292
0,
293+
0,
289294
)
290295

291296
self.assertTrue(
@@ -320,6 +325,7 @@ def test_generate_data_with_cosine_distorsion(self):
320325
0,
321326
(5, 5, 5, 5),
322327
0,
328+
0,
323329
)
324330

325331
self.assertTrue(
@@ -354,6 +360,7 @@ def test_generate_data_with_left_alignment(self):
354360
0,
355361
(5, 5, 5, 5),
356362
0,
363+
0,
357364
)
358365

359366
self.assertTrue(
@@ -388,6 +395,7 @@ def test_generate_data_with_center_alignment(self):
388395
0,
389396
(5, 5, 5, 5),
390397
0,
398+
0,
391399
)
392400

393401
self.assertTrue(
@@ -422,6 +430,7 @@ def test_generate_data_with_right_alignment(self):
422430
0,
423431
(5, 5, 5, 5),
424432
0,
433+
0,
425434
)
426435

427436
self.assertTrue(
@@ -457,6 +466,7 @@ def test_raise_if_handwritten_and_vertical(self):
457466
0,
458467
(5, 5, 5, 5),
459468
0,
469+
0,
460470
)
461471
raise Exception("Vertical handwritten did not throw")
462472
except ValueError:
@@ -487,6 +497,7 @@ def test_generate_vertical_text(self):
487497
0,
488498
(5, 5, 5, 5),
489499
0,
500+
0,
490501
)
491502

492503
self.assertTrue(
@@ -521,6 +532,7 @@ def test_generate_horizontal_text_with_variable_space(self):
521532
0,
522533
(5, 5, 5, 5),
523534
0,
535+
0,
524536
)
525537

526538
self.assertTrue(
@@ -555,6 +567,7 @@ def test_generate_vertical_text_with_variable_space(self):
555567
0,
556568
(5, 5, 5, 5),
557569
0,
570+
0,
558571
)
559572

560573
self.assertTrue(
@@ -590,6 +603,7 @@ def test_generate_text_with_unknown_orientation(self):
590603
0,
591604
(5, 5, 5, 5),
592605
0,
606+
0,
593607
)
594608
raise Exception("Unknown orientation did not throw")
595609
except ValueError:
@@ -620,6 +634,7 @@ def test_generate_data_with_fit(self):
620634
0,
621635
(0, 0, 0, 0),
622636
1,
637+
0,
623638
)
624639

625640
self.assertTrue(
@@ -663,18 +678,6 @@ def test_generate_data_with_white_background(self):
663678

664679
os.remove("tests/out/white_background.jpg")
665680

666-
def test_generate_data_with_gaussian_background(self):
667-
background_generator.gaussian_noise(64, 128).convert("RGB").save(
668-
"tests/out/gaussian_background.jpg"
669-
)
670-
671-
self.assertTrue(
672-
md5("tests/out/gaussian_background.jpg")
673-
== md5("tests/expected_results/gaussian_background.jpg")
674-
)
675-
676-
os.remove("tests/out/gaussian_background.jpg")
677-
678681
def test_generate_data_with_quasicrystal_background(self):
679682
bkgd = background_generator.quasicrystal(64, 128)
680683

200 Bytes
Loading
206 Bytes
Loading
-2.09 KB
Binary file not shown.

trdg/computer_text_generator.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ def _generate_horizontal_text(
3030
text_height = max([image_font.getsize(c)[1] for c in text])
3131

3232
txt_img = Image.new("RGBA", (text_width, text_height), (0, 0, 0, 0))
33+
txt_mask = Image.new("RGB", (text_width, text_height), (0, 0, 0))
3334

34-
txt_draw = ImageDraw.Draw(txt_img)
35+
txt_img_draw = ImageDraw.Draw(txt_img)
36+
txt_mask_draw = ImageDraw.Draw(txt_mask, mode="RGB")
37+
txt_mask_draw.fontmode = "1"
3538

3639
colors = [ImageColor.getrgb(c) for c in text_color.split(",")]
3740
c1, c2 = colors[0], colors[-1]
@@ -43,17 +46,23 @@ def _generate_horizontal_text(
4346
)
4447

4548
for i, c in enumerate(text):
46-
txt_draw.text(
49+
txt_img_draw.text(
4750
(sum(char_widths[0:i]) + i * character_spacing, 0),
4851
c,
4952
fill=fill,
5053
font=image_font,
5154
)
55+
txt_mask_draw.text(
56+
(sum(char_widths[0:i]) + i * character_spacing, 0),
57+
c,
58+
fill=((i + 1) // (255 * 255), (i + 1) // 255, (i + 1) % 255),
59+
font=image_font,
60+
)
5261

5362
if fit:
54-
return txt_img.crop(txt_img.getbbox())
63+
return txt_img.crop(txt_img.getbbox()), txt_mask.crop(txt_img.getbbox())
5564
else:
56-
return txt_img
65+
return txt_img, txt_mask
5766

5867

5968
def _generate_vertical_text(
@@ -70,8 +79,10 @@ def _generate_vertical_text(
7079
text_height = sum(char_heights) + character_spacing * len(text)
7180

7281
txt_img = Image.new("RGBA", (text_width, text_height), (0, 0, 0, 0))
82+
txt_mask = Image.new("RGBA", (text_width, text_height), (0, 0, 0, 0))
7383

74-
txt_draw = ImageDraw.Draw(txt_img)
84+
txt_img_draw = ImageDraw.Draw(txt_img)
85+
txt_mask_draw = ImageDraw.Draw(txt_img)
7586

7687
colors = [ImageColor.getrgb(c) for c in text_color.split(",")]
7788
c1, c2 = colors[0], colors[-1]
@@ -83,14 +94,20 @@ def _generate_vertical_text(
8394
)
8495

8596
for i, c in enumerate(text):
86-
txt_draw.text(
97+
txt_img_draw.text(
8798
(0, sum(char_heights[0:i]) + i * character_spacing),
8899
c,
89100
fill=fill,
90101
font=image_font,
91102
)
103+
txt_mask_draw.text(
104+
(0, sum(char_heights[0:i]) + i * character_spacing),
105+
c,
106+
fill=(i // (255 * 255), i // 255, i % 255),
107+
font=image_font,
108+
)
92109

93110
if fit:
94-
return txt_img.crop(txt_img.getbbox())
111+
return txt_img.crop(txt_img.getbbox()), txt_mask.crop(txt_img.getbbox())
95112
else:
96-
return txt_img
113+
return txt_img, txt_mask

trdg/data_generator.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def generate(
4646
character_spacing,
4747
margins,
4848
fit,
49+
output_mask,
4950
):
5051
image = None
5152

@@ -61,7 +62,7 @@ def generate(
6162
raise ValueError("Vertical handwritten text is unavailable")
6263
image = handwritten_text_generator.generate(text, text_color)
6364
else:
64-
image = computer_text_generator.generate(
65+
image, mask = computer_text_generator.generate(
6566
text,
6667
font,
6768
text_color,
@@ -71,33 +72,40 @@ def generate(
7172
character_spacing,
7273
fit,
7374
)
74-
7575
random_angle = rnd.randint(0 - skewing_angle, skewing_angle)
7676

7777
rotated_img = image.rotate(
7878
skewing_angle if not random_skew else random_angle, expand=1
7979
)
8080

81+
rotated_mask = mask.rotate(
82+
skewing_angle if not random_skew else random_angle, expand=1
83+
)
84+
8185
#############################
8286
# Apply distorsion to image #
8387
#############################
8488
if distorsion_type == 0:
8589
distorted_img = rotated_img # Mind = blown
90+
distorted_mask = rotated_mask
8691
elif distorsion_type == 1:
87-
distorted_img = distorsion_generator.sin(
92+
distorted_img, distorted_mask = distorsion_generator.sin(
8893
rotated_img,
94+
rotated_mask,
8995
vertical=(distorsion_orientation == 0 or distorsion_orientation == 2),
9096
horizontal=(distorsion_orientation == 1 or distorsion_orientation == 2),
9197
)
9298
elif distorsion_type == 2:
93-
distorted_img = distorsion_generator.cos(
99+
distorted_img, distorted_mask = distorsion_generator.cos(
94100
rotated_img,
101+
rotated_mask,
95102
vertical=(distorsion_orientation == 0 or distorsion_orientation == 2),
96103
horizontal=(distorsion_orientation == 1 or distorsion_orientation == 2),
97104
)
98105
else:
99-
distorted_img = distorsion_generator.random(
106+
distorted_img, distorted_mask = distorsion_generator.random(
100107
rotated_img,
108+
rotated_mask,
101109
vertical=(distorsion_orientation == 0 or distorsion_orientation == 2),
102110
horizontal=(distorsion_orientation == 1 or distorsion_orientation == 2),
103111
)
@@ -115,6 +123,7 @@ def generate(
115123
resized_img = distorted_img.resize(
116124
(new_width, size - vertical_margin), Image.ANTIALIAS
117125
)
126+
resized_mask = distorted_mask.resize((new_width, size - vertical_margin))
118127
background_width = width if width > 0 else new_width + horizontal_margin
119128
background_height = size
120129
# Vertical text
@@ -126,6 +135,9 @@ def generate(
126135
resized_img = distorted_img.resize(
127136
(size - horizontal_margin, new_height), Image.ANTIALIAS
128137
)
138+
resized_mask = distorted_mask.resize(
139+
(size - horizontal_margin, new_height), Image.ANTIALIAS
140+
)
129141
background_width = size
130142
background_height = new_height + vertical_margin
131143
else:
@@ -135,21 +147,22 @@ def generate(
135147
# Generate background image #
136148
#############################
137149
if background_type == 0:
138-
background = background_generator.gaussian_noise(
150+
background_img = background_generator.gaussian_noise(
139151
background_height, background_width
140152
)
141153
elif background_type == 1:
142-
background = background_generator.plain_white(
154+
background_img = background_generator.plain_white(
143155
background_height, background_width
144156
)
145157
elif background_type == 2:
146-
background = background_generator.quasicrystal(
158+
background_img = background_generator.quasicrystal(
147159
background_height, background_width
148160
)
149161
else:
150-
background = background_generator.picture(
162+
background_img = background_generator.picture(
151163
background_height, background_width
152164
)
165+
background_mask = Image.new("RGB", (background_width, background_height), (0, 0, 0))
153166

154167
#############################
155168
# Place text with alignment #
@@ -158,45 +171,62 @@ def generate(
158171
new_text_width, _ = resized_img.size
159172

160173
if alignment == 0 or width == -1:
161-
background.paste(resized_img, (margin_left, margin_top), resized_img)
174+
background_img.paste(resized_img, (margin_left, margin_top), resized_img)
175+
background_mask.paste(resized_mask, (margin_left, margin_top))
162176
elif alignment == 1:
163-
background.paste(
177+
background_img.paste(
164178
resized_img,
165179
(int(background_width / 2 - new_text_width / 2), margin_top),
166180
resized_img,
167181
)
182+
background_mask.paste(
183+
resized_mask,
184+
(int(background_width / 2 - new_text_width / 2), margin_top),
185+
)
168186
else:
169-
background.paste(
187+
background_img.paste(
170188
resized_img,
171189
(background_width - new_text_width - margin_right, margin_top),
172190
resized_img,
173191
)
192+
background_mask.paste(
193+
resized_mask,
194+
(background_width - new_text_width - margin_right, margin_top),
195+
)
174196

175197
##################################
176198
# Apply gaussian blur #
177199
##################################
178200

179-
final_image = background.filter(
180-
ImageFilter.GaussianBlur(
181-
radius=(blur if not random_blur else rnd.randint(0, blur))
182-
)
201+
gaussian_filter = ImageFilter.GaussianBlur(
202+
radius=blur if not random_blur else rnd.randint(0, blur)
183203
)
204+
final_image = background_img.filter(gaussian_filter)
205+
final_mask = background_mask.filter(gaussian_filter)
184206

185207
#####################################
186208
# Generate name for resulting image #
187209
#####################################
188210
if name_format == 0:
189211
image_name = "{}_{}.{}".format(text, str(index), extension)
212+
mask_name = "{}_{}_mask.png".format(text, str(index))
190213
elif name_format == 1:
191214
image_name = "{}_{}.{}".format(str(index), text, extension)
215+
mask_name = "{}_{}_mask.png".format(str(index), text)
192216
elif name_format == 2:
193217
image_name = "{}.{}".format(str(index), extension)
218+
mask_name = "{}_mask.png".format(str(index))
194219
else:
195220
print("{} is not a valid name format. Using default.".format(name_format))
196221
image_name = "{}_{}.{}".format(text, str(index), extension)
222+
mask_name = "{}_{}_mask.png".format(text, str(index))
197223

198224
# Save the image
199225
if out_dir is not None:
200226
final_image.convert("RGB").save(os.path.join(out_dir, image_name))
227+
if output_mask == 1:
228+
final_mask.convert("RGB").save(os.path.join(out_dir, mask_name))
201229
else:
230+
if output_mask == 1:
231+
return final_image.convert("RGB"), final_mask.convert("RGB")
202232
return final_image.convert("RGB")

trdg/dicts/TEST TEST TEST_4.jpg

2.29 KB
Loading

0 commit comments

Comments
 (0)