|
9 | 9 | import seaborn |
10 | 10 | from PIL import Image, ImageColor |
11 | 11 | from collections import namedtuple |
| 12 | +import warnings |
| 13 | + |
| 14 | +warnings.filterwarnings("ignore") |
| 15 | + |
12 | 16 |
|
13 | 17 | def download_model_weights(): |
14 | 18 | from pathlib import Path |
15 | | - import urllib.request |
| 19 | + import urllib.request |
| 20 | + |
16 | 21 | cwd = os.path.dirname(os.path.abspath(__file__)) |
17 | | - for k in ['model-29.data-00000-of-00001','model-29.index','model-29.meta','translation.pkl']: |
18 | | - download_dir = Path(cwd)/'handwritten_model/' |
19 | | - download_dir.mkdir(exist_ok=True,parents=True) |
20 | | - if (download_dir/f'{k}').exists(): continue |
21 | | - print(f'file {k} not found, downloading from git repo..') |
| 22 | + for k in [ |
| 23 | + "model-29.data-00000-of-00001", |
| 24 | + "model-29.index", |
| 25 | + "model-29.meta", |
| 26 | + "translation.pkl", |
| 27 | + ]: |
| 28 | + download_dir = Path(cwd) / "handwritten_model/" |
| 29 | + download_dir.mkdir(exist_ok=True, parents=True) |
| 30 | + if (download_dir / f"{k}").exists(): |
| 31 | + continue |
| 32 | + print(f"file {k} not found, downloading from git repo..") |
22 | 33 | urllib.request.urlretrieve( |
23 | | - f'https://raw.github.com/Belval/TextRecognitionDataGenerator/master/trdg/handwritten_model/{k}', |
24 | | - download_dir/f'{k}') |
25 | | - print(f'file {k} saved to disk') |
| 34 | + f"https://raw.github.com/Belval/TextRecognitionDataGenerator/master/trdg/handwritten_model/{k}", |
| 35 | + download_dir / f"{k}", |
| 36 | + ) |
| 37 | + print(f"file {k} saved to disk") |
26 | 38 | return cwd |
27 | 39 |
|
| 40 | + |
28 | 41 | def _sample(e, mu1, mu2, std1, std2, rho): |
29 | 42 | cov = np.array([[std1 * std1, std1 * std2 * rho], [std1 * std2 * rho, std2 * std2]]) |
30 | 43 | mean = np.array([mu1, mu2]) |
@@ -71,7 +84,9 @@ def _sample_text(sess, args_text, translation): |
71 | 84 | "finish", |
72 | 85 | "zero_states", |
73 | 86 | ] |
74 | | - vs = namedtuple("Params", fields)(*[tf.compat.v1.get_collection(name)[0] for name in fields]) |
| 87 | + vs = namedtuple("Params", fields)( |
| 88 | + *[tf.compat.v1.get_collection(name)[0] for name in fields] |
| 89 | + ) |
75 | 90 |
|
76 | 91 | text = np.array([translation.get(c, 0) for c in args_text]) |
77 | 92 | sequence = np.eye(len(translation), dtype=np.float32)[text] |
@@ -163,14 +178,20 @@ def _join_images(images): |
163 | 178 |
|
164 | 179 | def generate(text, text_color): |
165 | 180 | cd = download_model_weights() |
166 | | - with open(os.path.join(cd, os.path.join("handwritten_model", "translation.pkl")), "rb") as file: |
| 181 | + with open( |
| 182 | + os.path.join(cd, os.path.join("handwritten_model", "translation.pkl")), "rb" |
| 183 | + ) as file: |
167 | 184 | translation = pickle.load(file) |
168 | 185 |
|
169 | 186 | config = tf.compat.v1.ConfigProto(device_count={"GPU": 0}) |
170 | 187 | tf.compat.v1.reset_default_graph() |
171 | 188 | with tf.compat.v1.Session(config=config) as sess: |
172 | | - saver = tf.compat.v1.train.import_meta_graph(os.path.join(cd,"handwritten_model/model-29.meta")) |
173 | | - saver.restore(sess,os.path.join(cd,os.path.join("handwritten_model/model-29"))) |
| 189 | + saver = tf.compat.v1.train.import_meta_graph( |
| 190 | + os.path.join(cd, "handwritten_model/model-29.meta") |
| 191 | + ) |
| 192 | + saver.restore( |
| 193 | + sess, os.path.join(cd, os.path.join("handwritten_model/model-29")) |
| 194 | + ) |
174 | 195 | images = [] |
175 | 196 | colors = [ImageColor.getrgb(c) for c in text_color.split(",")] |
176 | 197 | c1, c2 = colors[0], colors[-1] |
@@ -203,13 +224,11 @@ def generate(text, text_color): |
203 | 224 |
|
204 | 225 | canvas = plt.get_current_fig_manager().canvas |
205 | 226 | canvas.draw() |
206 | | - |
| 227 | + |
207 | 228 | s, (width, height) = canvas.print_to_buffer() |
208 | | - image = Image.frombytes( |
209 | | - "RGBA", (width, height), s |
210 | | - ) |
| 229 | + image = Image.frombytes("RGBA", (width, height), s) |
211 | 230 | mask = Image.new("RGB", (width, height), (0, 0, 0)) |
212 | | - |
| 231 | + |
213 | 232 | images.append(_crop_white_borders(image)) |
214 | 233 |
|
215 | 234 | plt.close() |
|
0 commit comments