Skip to content

Commit d4cc602

Browse files
committed
✅ ensure the weight loading approach matches HF
1 parent b2d21db commit d4cc602

4 files changed

Lines changed: 120 additions & 111 deletions

File tree

generate_transfer.py

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,6 @@
1313
from modeling.translation import TranslationDiffusion
1414

1515

16-
def copy_parameters(from_parameters, to_parameters):
17-
to_parameters = list(to_parameters)
18-
assert len(from_parameters) == len(to_parameters)
19-
for s_param, param in zip(from_parameters, to_parameters):
20-
param.data.copy_(s_param.to(param.device).data)
21-
22-
2316
def parse_args():
2417
parser = argparse.ArgumentParser()
2518
parser.add_argument("--config", default=None, type=str)
@@ -51,20 +44,12 @@ def generate_image(
5144
):
5245
torch.cuda.set_device(device)
5346
model = build_model(cfg).to(device)
54-
if cfg.MODEL.PRETRAINED:
55-
logger.info(f"Loading pretrained model from {cfg.MODEL.PRETRAINED}")
56-
weight = torch.load(cfg.MODEL.PRETRAINED, map_location=device)
57-
copy_parameters(weight["ema_state_dict"]["shadow_params"], model.parameters())
58-
del weight
59-
torch.cuda.empty_cache()
6047
model.eval()
6148

6249
diffuser = TranslationDiffusion(cfg, device)
6350
os.makedirs(args.save_folder, exist_ok=True)
6451

65-
progress_bar = tqdm(total=len(source_list), position=int(device.split(":")[-1]))
6652
count_error = 0
67-
6853
with tqdm(total=len(source_list), position=int(device.split(":")[-1])) as progress_bar:
6954
for idx, ((source_image, source_mask), (target_image, target_mask)) in enumerate(
7055
zip(source_list, target_list)

generate_translation.py

Lines changed: 31 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,6 @@
1313
from modeling.translation import TranslationDiffusion
1414

1515

16-
def copy_parameters(from_parameters, to_parameters):
17-
to_parameters = list(to_parameters)
18-
assert len(from_parameters) == len(to_parameters)
19-
for s_param, param in zip(from_parameters, to_parameters):
20-
param.data.copy_(s_param.to(param.device).data)
21-
22-
2316
def parse_args():
2417
parser = argparse.ArgumentParser()
2518
parser.add_argument("--config", default=None, type=str)
@@ -46,46 +39,42 @@ def generate_image(
4639
):
4740
torch.cuda.set_device(device)
4841
model = build_model(cfg).to(device)
49-
if cfg.MODEL.PRETRAINED:
50-
logger.info(f"Loading pretrained model from {cfg.MODEL.PRETRAINED}")
51-
weight = torch.load(cfg.MODEL.PRETRAINED, map_location="cpu")
52-
copy_parameters(weight["ema_state_dict"]["shadow_params"], model.parameters())
53-
del weight
54-
torch.cuda.empty_cache()
5542
model.eval()
5643

5744
diffuser = TranslationDiffusion(cfg, device)
5845
os.makedirs(args.save_folder, exist_ok=True)
5946

60-
progress_bar = tqdm(total=len(source_list), position=int(device.split(":")[-1]))
6147
count_error = 0
62-
63-
for idx, (source_image, source_mask) in enumerate(source_list):
64-
save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png")
65-
if os.path.exists(save_image_name):
66-
progress_bar.update(1)
67-
continue
68-
if source_mask.endswith("jpg"):
69-
source_mask = source_mask.replace("jpg", "png")
70-
try:
71-
transfer_result = diffuser.domain_translation(
72-
source_model=model,
73-
target_model=model,
74-
source_image=source_image,
75-
source_class_label=source_label,
76-
target_class_label=target_label,
77-
parsing_mask=source_mask,
78-
start_from_step=num_of_step,
48+
with tqdm(
49+
total=len(source_list), position=int(device.split(":")[-1])
50+
) as progress_bar:
51+
for idx, (source_image, source_mask) in enumerate(source_list):
52+
save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png")
53+
if os.path.exists(save_image_name):
54+
progress_bar.update(1)
55+
continue
56+
if source_mask.endswith("jpg"):
57+
source_mask = source_mask.replace("jpg", "png")
58+
try:
59+
transfer_result = diffuser.domain_translation(
60+
source_model=model,
61+
target_model=model,
62+
source_image=source_image,
63+
source_class_label=source_label,
64+
target_class_label=target_label,
65+
parsing_mask=source_mask,
66+
start_from_step=num_of_step,
67+
)
68+
except Exception as e:
69+
logger.error(str(e))
70+
count_error += 1
71+
continue
72+
save_image = Image.fromarray(
73+
(transfer_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype(
74+
"uint8"
75+
)
7976
)
80-
except Exception as e:
81-
logger.error(str(e))
82-
count_error += 1
83-
continue
84-
save_image = Image.fromarray(
85-
(transfer_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
86-
)
87-
save_image.save(save_image_name)
88-
progress_bar.close()
77+
save_image.save(save_image_name)
8978

9079
if count_error != 0:
9180
print(f"Error in {device}: {count_error}")
@@ -114,8 +103,7 @@ def generate_image(
114103
cfg,
115104
args.save_folder,
116105
source_list=source_list[
117-
gpu_idx
118-
* task_per_process : (
106+
gpu_idx * task_per_process : (
119107
((gpu_idx + 1) * task_per_process)
120108
if gpu_idx < args.num_process - 1
121109
else len(source_list)
@@ -129,3 +117,4 @@ def generate_image(
129117
)
130118
for gpu_idx in range(args.num_process)
131119
)
120+

0 commit comments

Comments (0)