Skip to content

Commit 9d0d73b

Browse files
authored
Add text_color support for handwritten
1 parent 720a216 commit 9d0d73b

3 files changed

Lines changed: 17 additions & 7 deletions

File tree

TextRecognitionDataGenerator/computer_text_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def __generate_horizontal_text(cls, text, font, text_color, font_size, space_wid
3030
c1, c2 = colors[0], colors[-1]
3131

3232
fill = (
33-
random.randint(c1[0], c2[0]),
34-
random.randint(c1[1], c2[1]),
35-
random.randint(c1[2], c2[2])
33+
random.randint(min(c1[0], c2[0]), max(c1[0], c2[0])),
34+
random.randint(min(c1[1], c2[1]), max(c1[1], c2[1])),
35+
random.randint(min(c1[2], c2[2]), max(c1[2], c2[2]))
3636
)
3737

3838
for i, w in enumerate(words):

TextRecognitionDataGenerator/data_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def generate(cls, index, text, font, out_dir, size, extension, skewing_angle, ra
3030
if is_handwritten:
3131
if orientation == 1:
3232
raise ValueError("Vertical handwritten text is unavailable")
33-
image = HandwrittenTextGenerator.generate(text)
33+
image = HandwrittenTextGenerator.generate(text, text_color)
3434
else:
3535
image = ComputerTextGenerator.generate(text, font, text_color, size, orientation, space_width)
3636

TextRecognitionDataGenerator/handwritten_text_generator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os
22
import pickle
33
import numpy as np
4+
import random
45
import tensorflow as tf
56
import matplotlib.pyplot as plt
67
import matplotlib.cm as cm
78
import matplotlib.mlab as mlab
89
import seaborn
9-
from PIL import Image
10+
from PIL import Image, ImageColor
1011
from collections import namedtuple
1112

1213
class HandwrittenTextGenerator(object):
@@ -115,7 +116,7 @@ def __join_images(cls, images):
115116
return compound_image
116117

117118
@classmethod
118-
def generate(cls, text):
119+
def generate(cls, text, text_color):
119120
with open(os.path.join('handwritten_model', 'translation.pkl'), 'rb') as file:
120121
translation = pickle.load(file)
121122

@@ -127,6 +128,15 @@ def generate(cls, text):
127128
saver = tf.train.import_meta_graph('handwritten_model/model-29.meta')
128129
saver.restore(sess, 'handwritten_model/model-29')
129130
images = []
131+
colors = [ImageColor.getrgb(c) for c in text_color.split(',')]
132+
c1, c2 = colors[0], colors[-1]
133+
134+
color = '#{:02x}{:02x}{:02x}'.format(
135+
random.randint(min(c1[0], c2[0]), max(c1[0], c2[0])),
136+
random.randint(min(c1[1], c2[1]), max(c1[1], c2[1])),
137+
random.randint(min(c1[2], c2[2]), max(c1[2], c2[2]))
138+
)
139+
130140
for word in text.split(' '):
131141
_, window_data, kappa_data, stroke_data, coords = cls.__sample_text(sess, word, translation)
132142

@@ -140,7 +150,7 @@ def generate(cls, text):
140150
ax.axis('off')
141151

142152
for stroke in cls.__split_strokes(cls.__cumsum(np.array(coords))):
143-
plt.plot(stroke[:, 0], -stroke[:, 1], color='#080808')
153+
plt.plot(stroke[:, 0], -stroke[:, 1], color=color)
144154

145155
fig.patch.set_alpha(0)
146156
fig.patch.set_facecolor('none')

0 commit comments

Comments
 (0)