11import argparse
22import json
33import os
4- from typing import List , Tuple
4+ from typing import List , Optional , Tuple
55
6- import torch
76from joblib import Parallel , delayed
87from loguru import logger
9- from PIL import Image
108from tqdm import tqdm
119
12- from config import create_cfg , merge_possible_with_base , show_config
13- from modeling import build_model
1410from modeling .text_translation import TextTranslationDiffusion
1511
1612
@@ -23,34 +19,36 @@ def copy_parameters(from_parameters, to_parameters):
2319
2420def parse_args ():
2521 parser = argparse .ArgumentParser ()
26- parser .add_argument ("--config" , default = None , type = str )
2722 parser .add_argument ("--save-folder" , default = "batch_images" , type = str )
2823 parser .add_argument ("--source-root" , required = True , type = str )
2924 parser .add_argument ("--source-list" , required = True , type = str )
30- parser .add_argument ("--source-label" , required = True , type = int )
3125 parser .add_argument ("--num-process" , default = 1 , type = int )
3226 parser .add_argument ("--num-of-step" , default = 180 , type = int )
33- parser .add_argument ("--opts" , nargs = argparse .REMAINDER , default = None , type = str )
27+ parser .add_argument ("--img-size" , default = 512 , type = int )
28+ parser .add_argument ("--model-path" , default = None , type = str )
29+ parser .add_argument ("--scheduler" , default = "ddpm" , type = str )
30+ parser .add_argument ("--sample-steps" , default = 1000 , type = int )
3431 return parser .parse_args ()
3532
3633
3734def generate_image (
38- cfg ,
35+ img_size : int ,
3936 save_folder : str ,
4037 source_list : List [Tuple [str , str ]],
41- source_label : int ,
4238 offset : int ,
4339 device : str ,
4440 num_of_step : int ,
41+ scheduler : str ,
42+ sample_steps : int ,
43+ model_path : Optional [str ] = None ,
4544):
46- model = build_model (cfg ).to (device )
47- if cfg .MODEL .PRETRAINED :
48- logger .info (f"Loading pretrained model from { cfg .MODEL .PRETRAINED } " )
49- weight = torch .load (cfg .MODEL .PRETRAINED , map_location = device )
50- copy_parameters (weight ["ema_state_dict" ]["shadow_params" ], model .parameters ())
51- del weight
52- torch .cuda .empty_cache ()
53- diffuser = TextTranslationDiffusion (cfg , device = device )
45+ diffuser = TextTranslationDiffusion (
46+ img_size = img_size ,
47+ scheduler = scheduler ,
48+ device = device ,
49+ model_path = model_path ,
50+ sample_steps = sample_steps ,
51+ )
5452 os .makedirs (args .save_folder , exist_ok = True )
5553
5654 progress_bar = tqdm (total = len (source_list ), position = int (device .split (":" )[- 1 ]))
@@ -65,8 +63,6 @@ def generate_image(
6563 source_mask = source_mask .replace ("jpg" , "png" )
6664 try :
6765 editing_result = diffuser .modify_with_text (
68- model = model ,
69- source_label = source_label ,
7066 image = source_image ,
7167 mask = source_mask ,
7268 prompt = [editing_prompt ],
@@ -77,9 +73,7 @@ def generate_image(
7773 logger .error (str (e ))
7874 count_error += 1
7975 continue
80- save_image = Image .fromarray (
81- (editing_result [0 ].permute (1 , 2 , 0 ).cpu ().numpy () * 255 ).astype ("uint8" )
82- )
76+ save_image = editing_result [0 ]
8377 save_image .save (save_image_name )
8478 progress_bar .update (1 )
8579 progress_bar .close ()
@@ -90,12 +84,6 @@ def generate_image(
9084
9185if __name__ == "__main__" :
9286 args = parse_args ()
93- cfg = create_cfg ()
94- if args .config :
95- merge_possible_with_base (cfg , args .config )
96- if args .opts :
97- cfg .merge_from_list (args .opts )
98- show_config (cfg )
9987
10088 with open (args .source_list , "r" ) as f :
10189 data = json .load (f )
@@ -118,7 +106,6 @@ def generate_image(
118106 else len (source_list )
119107 )
120108 ],
121- source_label = args .source_label ,
122109 offset = gpu_idx * task_per_process ,
123110 device = f"cuda:{ gpu_idx } " ,
124111 num_of_step = args .num_of_step ,
0 commit comments