Skip to content

Commit a0cf992

Browse files
committed
Add switch in shell scripts to run the process adaptively. COD/HRSOD needs lower LR.
1 parent 1cb9f3f commit a0cf992

4 files changed

Lines changed: 46 additions & 35 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Our BiRefNet has achieved SOTA on many similar HR tasks:
3030

3131
+ **Inference and evaluation** of your given weights: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MaEiBfJ4xIaZZn0DqKrhydHB8X97hNXl#scrollTo=DJ4meUYjia6S)
3232
+ **Online Inference with GUI** with adjustable resolutions: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo)
33+
+ Online **Single Image Inference** on Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba?usp=drive_link)
3334
<img src="https://drive.google.com/thumbnail?id=12XmDhKtO1o2fEvBu4OE4ULVB2BK0ecWi&sz=w1620" />
3435

3536
## Third-Party Creations

config.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import math
3-
import torch
43

54

65
class Config():
@@ -13,7 +12,7 @@ def __init__(self) -> None:
1312
self.training_set = {
1413
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
1514
'COD': 'TR-COD10K+TR-CAMO',
16-
'HRSOD': ['TR-DUTS', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][3],
15+
'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
1716
'DIS5K+HRSOD+HRS10K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD', # leave DIS-VD for evaluation.
1817
'P3M-10k': 'TR-P3M-10k',
1918
}[self.task]
@@ -39,14 +38,14 @@ def __init__(self) -> None:
3938
self.IoU_finetune_last_epochs = [
4039
0,
4140
{
42-
'DIS5K': -100,
43-
'COD': -30,
44-
'HRSOD': -30,
45-
'DIS5K+HRSOD+HRS10K': -50,
46-
'P3M-10k': -30,
41+
'DIS5K': -50,
42+
'COD': -20,
43+
'HRSOD': -20,
44+
'DIS5K+HRSOD+HRS10K': -20,
45+
'P3M-10k': -20,
4746
}[self.task]
4847
][1] # choose 0 to skip
49-
self.lr = 1e-4 * math.sqrt(self.batch_size / 4) # adapt the lr linearly
48+
self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly
5049
self.size = 1024
5150
self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader
5251

@@ -76,7 +75,7 @@ def __init__(self) -> None:
7675
self.progressive_ref = self.refine and True
7776
self.ender = self.progressive_ref and False
7877
self.scale = self.progressive_ref and 2
79-
self.auxiliary_classification = False
78+
self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`.
8079
self.refine_iteration = 1
8180
self.freeze_bb = False
8281
self.model = [
@@ -131,13 +130,22 @@ def __init__(self) -> None:
131130
self.SDPA_enabled = False # Bugs. Slower and errors occur in multi-GPUs
132131

133132
# others
134-
self.device = [0, 'cpu'][0 if torch.cuda.is_available() else 1] # .to(0) == .to('cuda:0')
133+
self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0')
135134

136135
self.batch_size_valid = 1
137136
self.rand_seed = 7
138137
run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
139138
with open(run_sh_file[0], 'r') as f:
140139
lines = f.readlines()
141-
self.save_last = int([l.strip() for l in lines if 'val_last=' in l][0].split('=')[-1])
142-
self.save_step = int([l.strip() for l in lines if 'step=' in l][0].split('=')[-1])
140+
self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
141+
self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])
143142
self.val_step = [0, self.save_step][0]
143+
144+
def print_task(self) -> None:
145+
# Return task for choosing settings in shell scripts.
146+
print(self.task)
147+
148+
if __name__ == '__main__':
149+
config = Config()
150+
config.print_task()
151+

test.sh

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,21 @@ CUDA_VISIBLE_DEVICES=${devices} python inference.py --pred_root ${pred_root}
88
echo Inference finished at $(date)
99

1010
# Evaluation
11-
log_dir=e_logs
12-
mkdir ${log_dir}
13-
14-
testsets=DIS-VD && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
15-
testsets=DIS-TE1 && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
16-
testsets=DIS-TE2 && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
17-
testsets=DIS-TE3 && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
18-
testsets=DIS-TE4 && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
19-
20-
# testsets=CHAMELEON && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
21-
# testsets=NC4K && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
22-
# testsets=TE-CAMO && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
23-
# testsets=TE-COD10K && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
24-
25-
# testsets=DAVIS-S && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
26-
# testsets=TE-HRSOD && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
27-
# testsets=TE-UHRSD && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
28-
# testsets=DUT-OMRON && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
29-
# testsets=TE-DUTS && nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testsets} > ${log_dir}/eval_${testsets}.out 2>&1 &
11+
log_dir=e_logs && mkdir ${log_dir}
12+
13+
task=$(python3 config.py)
14+
case "${task}" in
15+
"DIS5K") testsets='DIS-VD,DIS-TE1,DIS-TE2,DIS-TE3,DIS-TE4' ;;
16+
"COD") testsets='CHAMELEON,NC4K,TE-CAMO,TE-COD10K' ;;
17+
"HRSOD") testsets='DAVIS-S,TE-HRSOD,TE-UHRSD,DUT-OMRON,TE-DUTS' ;;
18+
"DIS5K+HRSOD+HRS10K") testsets='DIS-VD' ;;
19+
"P3M-10k") testsets='TE-P3M-500-P,TE-P3M-500-NP' ;;
20+
esac
21+
testsets=(`echo ${testsets} | tr ',' ' '`) && testsets=${testsets[@]}
22+
23+
for testset in ${testsets}; do
24+
nohup python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} > ${log_dir}/eval_${testset}.out 2>&1 &
25+
done
26+
3027

3128
echo Evaluation started at $(date)

train.sh

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
#!/bin/bash
22
# Run script
3-
# DIS/COD/HRSOD/massive/P3M-10k: epochs,val_last,step:[600,200,10]/[150,50,10]/[150,50,10]/[300,100,10]/[150,50,10]
3+
# Settings of training & test for different tasks.
44
method="$1"
5-
epochs=600
6-
val_last=200
7-
step=10
5+
task=$(python3 config.py)
6+
case "${task}" in
7+
"DIS5K") epochs=600 && val_last=200 && step=10 ;;
8+
"COD") epochs=150 && val_last=50 && step=5 ;;
9+
"HRSOD") epochs=150 && val_last=50 && step=5 ;;
10+
"DIS5K+HRSOD+HRS10K") epochs=300 && val_last=50 && step=5 ;;
11+
"P3M-10k") epochs=150 && val_last=50 && step=5 ;;
12+
esac
813
testsets=NO # Non-existing folder to skip.
914
# testsets=TE-COD10K # for COD
1015

0 commit comments

Comments
 (0)