Skip to content

Commit 4582945

Browse files
Add download model weights + fix for mask
1 parent 7cf75c1 commit 4582945

1 file changed

Lines changed: 28 additions & 10 deletions

File tree

trdg/handwritten_text_generator.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010
from PIL import Image, ImageColor
1111
from collections import namedtuple
1212

13+
def download_model_weights():
14+
from pathlib import Path
15+
import urllib.request
16+
cwd = os.path.dirname(os.path.abspath(__file__))
17+
for k in ['model-29.data-00000-of-00001','model-29.index','model-29.meta','translation.pkl']:
18+
download_dir = Path(cwd)/'handwritten_model/'
19+
download_dir.mkdir(exist_ok=True,parents=True)
20+
if (download_dir/f'{k}').exists(): continue
21+
print(f'file {k} not found, downloading from git repo..')
22+
urllib.request.urlretrieve(
23+
f'https://raw.github.com/Belval/TextRecognitionDataGenerator/master/trdg/handwritten_model/{k}',
24+
download_dir/f'{k}')
25+
print(f'file {k} saved to disk')
26+
return cwd
1327

1428
def _sample(e, mu1, mu2, std1, std2, rho):
1529
cov = np.array([[std1 * std1, std1 * std2 * rho], [std1 * std2 * rho, std2 * std2]])
@@ -57,7 +71,7 @@ def _sample_text(sess, args_text, translation):
5771
"finish",
5872
"zero_states",
5973
]
60-
vs = namedtuple("Params", fields)(*[tf.get_collection(name)[0] for name in fields])
74+
vs = namedtuple("Params", fields)(*[tf.compat.v1.get_collection(name)[0] for name in fields])
6175

6276
text = np.array([translation.get(c, 0) for c in args_text])
6377
sequence = np.eye(len(translation), dtype=np.float32)[text]
@@ -148,14 +162,15 @@ def _join_images(images):
148162

149163

150164
def generate(text, text_color):
151-
with open(os.path.join("handwritten_model", "translation.pkl"), "rb") as file:
165+
cd = download_model_weights()
166+
with open(os.path.join(cd, os.path.join("handwritten_model", "translation.pkl")), "rb") as file:
152167
translation = pickle.load(file)
153168

154-
config = tf.ConfigProto(device_count={"GPU": 0})
155-
tf.reset_default_graph()
156-
with tf.Session(config=config) as sess:
157-
saver = tf.train.import_meta_graph("handwritten_model/model-29.meta")
158-
saver.restore(sess, "handwritten_model/model-29")
169+
config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
170+
tf.compat.v1.reset_default_graph()
171+
with tf.compat.v1.Session(config=config) as sess:
172+
saver = tf.compat.v1.train.import_meta_graph(os.path.join(cd,"handwritten_model/model-29.meta"))
173+
saver.restore(sess,os.path.join(cd,os.path.join("handwritten_model/model-29")))
159174
images = []
160175
colors = [ImageColor.getrgb(c) for c in text_color.split(",")]
161176
c1, c2 = colors[0], colors[-1]
@@ -188,12 +203,15 @@ def generate(text, text_color):
188203

189204
canvas = plt.get_current_fig_manager().canvas
190205
canvas.draw()
191-
206+
207+
s, (width, height) = canvas.print_to_buffer()
192208
image = Image.frombytes(
193-
"RGBA", canvas.get_width_height(), canvas.buffer_rgba()
209+
"RGBA", (width, height), s
194210
)
211+
mask = Image.new("RGB", (width, height), (0, 0, 0))
212+
195213
images.append(_crop_white_borders(image))
196214

197215
plt.close()
198216

199-
return _join_images(images)
217+
return _join_images(images), mask

0 commit comments

Comments
 (0)