|
10 | 10 | from PIL import Image, ImageColor |
11 | 11 | from collections import namedtuple |
12 | 12 |
|
| 13 | +def download_model_weights(): |
| 14 | + from pathlib import Path |
| 15 | + import urllib.request |
| 16 | + cwd = os.path.dirname(os.path.abspath(__file__)) |
| 17 | + for k in ['model-29.data-00000-of-00001','model-29.index','model-29.meta','translation.pkl']: |
| 18 | + download_dir = Path(cwd)/'handwritten_model/' |
| 19 | + download_dir.mkdir(exist_ok=True,parents=True) |
| 20 | + if (download_dir/f'{k}').exists(): continue |
| 21 | + print(f'file {k} not found, downloading from git repo..') |
| 22 | + urllib.request.urlretrieve( |
| 23 | + f'https://raw.github.com/Belval/TextRecognitionDataGenerator/master/trdg/handwritten_model/{k}', |
| 24 | + download_dir/f'{k}') |
| 25 | + print(f'file {k} saved to disk') |
| 26 | + return cwd |
13 | 27 |
|
14 | 28 | def _sample(e, mu1, mu2, std1, std2, rho): |
15 | 29 | cov = np.array([[std1 * std1, std1 * std2 * rho], [std1 * std2 * rho, std2 * std2]]) |
@@ -57,7 +71,7 @@ def _sample_text(sess, args_text, translation): |
57 | 71 | "finish", |
58 | 72 | "zero_states", |
59 | 73 | ] |
60 | | - vs = namedtuple("Params", fields)(*[tf.get_collection(name)[0] for name in fields]) |
| 74 | + vs = namedtuple("Params", fields)(*[tf.compat.v1.get_collection(name)[0] for name in fields]) |
61 | 75 |
|
62 | 76 | text = np.array([translation.get(c, 0) for c in args_text]) |
63 | 77 | sequence = np.eye(len(translation), dtype=np.float32)[text] |
@@ -148,14 +162,15 @@ def _join_images(images): |
148 | 162 |
|
149 | 163 |
|
150 | 164 | def generate(text, text_color): |
151 | | - with open(os.path.join("handwritten_model", "translation.pkl"), "rb") as file: |
| 165 | + cd = download_model_weights() |
| 166 | + with open(os.path.join(cd, os.path.join("handwritten_model", "translation.pkl")), "rb") as file: |
152 | 167 | translation = pickle.load(file) |
153 | 168 |
|
154 | | - config = tf.ConfigProto(device_count={"GPU": 0}) |
155 | | - tf.reset_default_graph() |
156 | | - with tf.Session(config=config) as sess: |
157 | | - saver = tf.train.import_meta_graph("handwritten_model/model-29.meta") |
158 | | - saver.restore(sess, "handwritten_model/model-29") |
| 169 | + config = tf.compat.v1.ConfigProto(device_count={"GPU": 0}) |
| 170 | + tf.compat.v1.reset_default_graph() |
| 171 | + with tf.compat.v1.Session(config=config) as sess: |
| 172 | + saver = tf.compat.v1.train.import_meta_graph(os.path.join(cd,"handwritten_model/model-29.meta")) |
| 173 | + saver.restore(sess,os.path.join(cd,os.path.join("handwritten_model/model-29"))) |
159 | 174 | images = [] |
160 | 175 | colors = [ImageColor.getrgb(c) for c in text_color.split(",")] |
161 | 176 | c1, c2 = colors[0], colors[-1] |
@@ -188,12 +203,15 @@ def generate(text, text_color): |
188 | 203 |
|
189 | 204 | canvas = plt.get_current_fig_manager().canvas |
190 | 205 | canvas.draw() |
191 | | - |
| 206 | + |
| 207 | + s, (width, height) = canvas.print_to_buffer() |
192 | 208 | image = Image.frombytes( |
193 | | - "RGBA", canvas.get_width_height(), canvas.buffer_rgba() |
| 209 | + "RGBA", (width, height), s |
194 | 210 | ) |
| 211 | + mask = Image.new("RGB", (width, height), (0, 0, 0)) |
| 212 | + |
195 | 213 | images.append(_crop_white_borders(image)) |
196 | 214 |
|
197 | 215 | plt.close() |
198 | 216 |
|
199 | | - return _join_images(images) |
| 217 | + return _join_images(images), mask |
0 commit comments