1313from modeling .translation import TranslationDiffusion
1414
1515
16- def copy_parameters (from_parameters , to_parameters ):
17- to_parameters = list (to_parameters )
18- assert len (from_parameters ) == len (to_parameters )
19- for s_param , param in zip (from_parameters , to_parameters ):
20- param .data .copy_ (s_param .to (param .device ).data )
21-
22-
2316def parse_args ():
2417 parser = argparse .ArgumentParser ()
2518 parser .add_argument ("--config" , default = None , type = str )
@@ -46,46 +39,42 @@ def generate_image(
4639):
4740 torch .cuda .set_device (device )
4841 model = build_model (cfg ).to (device )
49- if cfg .MODEL .PRETRAINED :
50- logger .info (f"Loading pretrained model from { cfg .MODEL .PRETRAINED } " )
51- weight = torch .load (cfg .MODEL .PRETRAINED , map_location = "cpu" )
52- copy_parameters (weight ["ema_state_dict" ]["shadow_params" ], model .parameters ())
53- del weight
54- torch .cuda .empty_cache ()
5542 model .eval ()
5643
5744 diffuser = TranslationDiffusion (cfg , device )
5845 os .makedirs (args .save_folder , exist_ok = True )
5946
60- progress_bar = tqdm (total = len (source_list ), position = int (device .split (":" )[- 1 ]))
6147 count_error = 0
62-
63- for idx , (source_image , source_mask ) in enumerate (source_list ):
64- save_image_name = os .path .join (save_folder , f"pred_{ idx + offset } .png" )
65- if os .path .exists (save_image_name ):
66- progress_bar .update (1 )
67- continue
68- if source_mask .endswith ("jpg" ):
69- source_mask = source_mask .replace ("jpg" , "png" )
70- try :
71- transfer_result = diffuser .domain_translation (
72- source_model = model ,
73- target_model = model ,
74- source_image = source_image ,
75- source_class_label = source_label ,
76- target_class_label = target_label ,
77- parsing_mask = source_mask ,
78- start_from_step = num_of_step ,
48+ with tqdm (
49+ total = len (source_list ), position = int (device .split (":" )[- 1 ])
50+ ) as progress_bar :
51+ for idx , (source_image , source_mask ) in enumerate (source_list ):
52+ save_image_name = os .path .join (save_folder , f"pred_{ idx + offset } .png" )
53+ if os .path .exists (save_image_name ):
54+ progress_bar .update (1 )
55+ continue
56+ if source_mask .endswith ("jpg" ):
57+ source_mask = source_mask .replace ("jpg" , "png" )
58+ try :
59+ transfer_result = diffuser .domain_translation (
60+ source_model = model ,
61+ target_model = model ,
62+ source_image = source_image ,
63+ source_class_label = source_label ,
64+ target_class_label = target_label ,
65+ parsing_mask = source_mask ,
66+ start_from_step = num_of_step ,
67+ )
68+ except Exception as e :
69+ logger .error (str (e ))
70+ count_error += 1
71+ continue
72+ save_image = Image .fromarray (
73+ (transfer_result [0 ].permute (1 , 2 , 0 ).cpu ().numpy () * 255 ).astype (
74+ "uint8"
75+ )
7976 )
80- except Exception as e :
81- logger .error (str (e ))
82- count_error += 1
83- continue
84- save_image = Image .fromarray (
85- (transfer_result [0 ].permute (1 , 2 , 0 ).cpu ().numpy () * 255 ).astype ("uint8" )
86- )
87- save_image .save (save_image_name )
88- progress_bar .close ()
77+ save_image .save (save_image_name )
8978
9079 if count_error != 0 :
9180 print (f"Error in { device } : { count_error } " )
@@ -114,8 +103,7 @@ def generate_image(
114103 cfg ,
115104 args .save_folder ,
116105 source_list = source_list [
117- gpu_idx
118- * task_per_process : (
106+ gpu_idx * task_per_process : (
119107 ((gpu_idx + 1 ) * task_per_process )
120108 if gpu_idx < args .num_process - 1
121109 else len (source_list )
@@ -129,3 +117,4 @@ def generate_image(
129117 )
130118 for gpu_idx in range (args .num_process )
131119 )
120+
0 commit comments