Skip to content

Commit 2832aab

Browse files
committed
Add option of structure loss and link to single image inference on colab to README.
1 parent bda8a1a commit 2832aab

7 files changed

Lines changed: 31 additions & 10 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ Choose the one you like to try with clicks instead of codes:
4949
+ Thanks [**viperyl/ComfyUI-BiRefNet**](https://github.com/viperyl/ComfyUI-BiRefNet): this project packs BiRefNet as **ComfyUI nodes**, and makes this SOTA model easier to use for everyone.
5050
<p align="center"><img src="https://drive.google.com/thumbnail?id=1KfxCQUUa2y9T-aysEaeVVjCUt3Z0zSkL&sz=w1620" /></p>
5151

52+
+ Thanks [**Rishabh**](https://github.com/rishabh063) for offering a demo for the [easier single image inference on colab](https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba?usp=drive_link).
53+
5254
2. **More Visual Comparisons**
5355
+ Thanks [**twitter.com/ZHOZHO672070**](https://twitter.com/ZHOZHO672070) for the comparison with more background-removal methods in images:
5456

config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self) -> None:
1313
self.training_set = {
1414
'DIS5K': 'DIS-TR',
1515
'COD': 'TR-COD10K+TR-CAMO',
16-
'HRSOD': ['TR-DUTS', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][1],
16+
'HRSOD': ['TR-DUTS', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][3],
1717
'DIS5K+HRSOD+HRS10K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD', # leave DIS-VD for evaluation.
1818
'P3M-10k': 'TR-P3M-10k',
1919
}[self.task]
@@ -102,6 +102,7 @@ def __init__(self) -> None:
102102
'reg': 100 * 0,
103103
'ssim': 10 * 1, # help contours,
104104
'cnt': 5 * 0, # help contours
105+
'structure': 5 * 0, # structure loss
105106
}
106107
self.lambdas_cls = {
107108
'ce': 5.0

eval_existingOnes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def do_eval(opt):
121121
opt = parser.parse_args()
122122

123123
os.makedirs(opt.save_dir, exist_ok=True)
124-
opt.model_lst = [m for m in sorted(os.listdir(opt.pred_root), key=lambda x: int(x.split('ep')[-1]), reverse=True) if int(m.split('ep')[-1]) % 1 == 0]
124+
opt.model_lst = [m for m in sorted(os.listdir(opt.pred_root), key=lambda x: int(x.split('epoch_')[-1]), reverse=True) if int(m.split('epoch_')[-1]) % 1 == 0]
125125

126126
# check the integrity of each candidates
127127
if opt.check_integrity:

inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def main(args):
5454
model = BiRefNet(bb_pretrained=False)
5555
weights_lst = sorted(
5656
glob(os.path.join(args.ckpt_folder, '*.pth')) if args.ckpt_folder else [args.ckpt],
57-
key=lambda x: int(x.split('ep')[-1].split('.pth')[0]),
57+
key=lambda x: int(x.split('epoch_')[-1].split('.pth')[0]),
5858
reverse=True
5959
)
6060
for testset in args.testsets.split('+'):
@@ -64,7 +64,7 @@ def main(args):
6464
batch_size=config.batch_size_valid, shuffle=False, num_workers=config.num_workers, pin_memory=True
6565
)
6666
for weights in weights_lst:
67-
if int(weights.strip('.pth').split('ep')[-1]) % 1 != 0:
67+
if int(weights.strip('.pth').split('epoch_')[-1]) % 1 != 0:
6868
continue
6969
print('\tInferencing {}...'.format(weights))
7070
# model.load_state_dict(torch.load(weights, map_location='cpu'))

loss.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,23 @@ def forward(self, pred, target):
8686
return IoU
8787

8888

89+
class StructureLoss(torch.nn.Module):
90+
def __init__(self):
91+
super(StructureLoss, self).__init__()
92+
93+
def forward(self, pred, target):
94+
weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target)
95+
wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
96+
wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
97+
98+
pred = torch.sigmoid(pred)
99+
inter = ((pred * target) * weit).sum(dim=(2, 3))
100+
union = ((pred + target) * weit).sum(dim=(2, 3))
101+
wiou = 1-(inter+1)/(union-inter+1)
102+
103+
return (wbce+wiou).mean()
104+
105+
89106
class PatchIoULoss(torch.nn.Module):
90107
def __init__(self):
91108
super(PatchIoULoss, self).__init__()
@@ -158,15 +175,16 @@ def __init__(self):
158175
self.criterions_last['reg'] = ThrReg_loss()
159176
if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']:
160177
self.criterions_last['cnt'] = ContourLoss()
178+
if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']:
179+
self.criterions_last['structure'] = StructureLoss()
161180

162181
def forward(self, scaled_preds, gt):
163182
loss = 0.
164183
for _, pred_lvl in enumerate(scaled_preds):
165184
if pred_lvl.shape != gt.shape:
166185
pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True)
167-
pred_lvl = pred_lvl.sigmoid()
168186
for criterion_name, criterion in self.criterions_last.items():
169-
_loss = criterion(pred_lvl, gt) * self.lambdas_pix_last[criterion_name]
187+
_loss = criterion(pred_lvl.sigmoid() if criterion_name not in ('structure',) else pred_lvl, gt) * self.lambdas_pix_last[criterion_name]
170188
loss += _loss
171189
# print(criterion_name, _loss.item())
172190
return loss

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def init_models_optimizers(epochs, to_be_distributed):
101101
state_dict = torch.load(args.resume, map_location='cpu')
102102
state_dict = check_state_dict(state_dict)
103103
model.load_state_dict(state_dict)
104-
epoch_st = int(args.resume.rstrip('.pth').split('ep')[-1]) + 1
104+
epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1
105105
else:
106106
logger.info("=> no checkpoint found at '{}'".format(args.resume))
107107
if to_be_distributed:
@@ -301,7 +301,7 @@ def main():
301301
if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
302302
torch.save(
303303
trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict(),
304-
os.path.join(args.ckpt_dir, 'ep{}.pth'.format(epoch))
304+
os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
305305
)
306306
if config.val_step and epoch >= args.epochs - config.save_last and (args.epochs - epoch) % config.val_step == 0:
307307
if to_be_distributed:

waiting4eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def main():
111111
models_detected = [
112112
m for idx_m, m in enumerate(sorted(
113113
glob(os.path.join(ckpt_dir, '*.pth')),
114-
key=lambda x: int(x.rstrip('.pth').split('ep')[-1]), reverse=True
114+
key=lambda x: int(x.rstrip('.pth').split('epoch_')[-1]), reverse=True
115115
)) if idx_m % args_eval.val_step == args_eval.program_id and m not in models_evaluated + models_evaluated_global
116116
]
117117
if models_detected:
@@ -127,7 +127,7 @@ def main():
127127
# evaluate the current model
128128
state_dict = torch.load(model_not_evaluated_latest, map_location=device)
129129
model.load_state_dict(state_dict, strict=False)
130-
validate_model(model, test_loaders, int(model_not_evaluated_latest.rstrip('.pth').split('ep')[-1]))
130+
validate_model(model, test_loaders, int(model_not_evaluated_latest.rstrip('.pth').split('epoch_')[-1]))
131131
continous_sleep_time = 0
132132
print('Duration of this evaluation:', time() - time_st)
133133
else:

0 commit comments

Comments
 (0)