11import os
22import pickle
33import numpy as np
4+ import random
45import tensorflow as tf
56import matplotlib .pyplot as plt
67import matplotlib .cm as cm
78import matplotlib .mlab as mlab
89import seaborn
9- from PIL import Image
10+ from PIL import Image , ImageColor
1011from collections import namedtuple
1112
1213class HandwrittenTextGenerator (object ):
@@ -115,7 +116,7 @@ def __join_images(cls, images):
115116 return compound_image
116117
117118 @classmethod
118- def generate (cls , text ):
119+ def generate (cls , text , text_color ):
119120 with open (os .path .join ('handwritten_model' , 'translation.pkl' ), 'rb' ) as file :
120121 translation = pickle .load (file )
121122
@@ -127,6 +128,15 @@ def generate(cls, text):
127128 saver = tf .train .import_meta_graph ('handwritten_model/model-29.meta' )
128129 saver .restore (sess , 'handwritten_model/model-29' )
129130 images = []
131+ colors = [ImageColor .getrgb (c ) for c in text_color .split (',' )]
132+ c1 , c2 = colors [0 ], colors [- 1 ]
133+
134+ color = '#{:02x}{:02x}{:02x}' .format (
135+ random .randint (min (c1 [0 ], c2 [0 ]), max (c1 [0 ], c2 [0 ])),
136+ random .randint (min (c1 [1 ], c2 [1 ]), max (c1 [1 ], c2 [1 ])),
137+ random .randint (min (c1 [2 ], c2 [2 ]), max (c1 [2 ], c2 [2 ]))
138+ )
139+
130140 for word in text .split (' ' ):
131141 _ , window_data , kappa_data , stroke_data , coords = cls .__sample_text (sess , word , translation )
132142
@@ -140,7 +150,7 @@ def generate(cls, text):
140150 ax .axis ('off' )
141151
142152 for stroke in cls .__split_strokes (cls .__cumsum (np .array (coords ))):
143- plt .plot (stroke [:, 0 ], - stroke [:, 1 ], color = '#080808' )
153+ plt .plot (stroke [:, 0 ], - stroke [:, 1 ], color = color )
144154
145155 fig .patch .set_alpha (0 )
146156 fig .patch .set_facecolor ('none' )
0 commit comments