-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
132 lines (96 loc) · 4.45 KB
/
train.py
File metadata and controls
132 lines (96 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import csv
import torch
import torch.nn.functional as F
import torch.distributed as dist
from tqdm import tqdm
from losses import pixel_reprojection_loss
from utils import create_mask, adjust_learning_rate
def save_checkpoint(model, epoch, args, local_rank):
    """Write the model weights for *epoch* into ``args.save_ckpt_path``.

    In distributed mode only rank 0 writes (unwrapping the DDP ``.module``
    first); in single-process mode the model is saved directly.
    """
    ckpt_file = os.path.join(args.save_ckpt_path, f'train_ckpt_{epoch}.tar')
    if not args.is_distributed:
        torch.save({'state_dict': model.state_dict()}, ckpt_file)
    elif args.rank == 0:
        torch.save({'state_dict': model.module.state_dict()}, ckpt_file)
def log_results(epoch, train_loss, valid_loss, args, fieldnames):
    """Append one epoch's train/valid losses as a CSV row to ``args.save_csv_file_path``."""
    row = {'epoch': epoch, 'train_loss': train_loss, 'valid_loss': valid_loss}
    with open(args.save_csv_file_path, 'a', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=fieldnames).writerow(row)
def train_one_epoch(epoch, model, optimizer, train_loader, local_rank, args, loss_r, scaler):
    """Run one AMP training epoch and return the epoch-mean loss.

    Args:
        epoch: current epoch index (used for LR schedule and sampler shuffling).
        model: stereo network returning 4 disparity predictions (coarse→fine).
        optimizer: optimizer wrapped by the AMP ``scaler``.
        train_loader: loader yielding (left, right, (disp_16x, disp_8x, disp_4x, disp)).
        local_rank: device index to move batches onto.
        args: namespace with maxdisp/mindisp, is_distributed, rank, world_size.
        loss_r: combines the per-scale smooth-L1 losses with the reprojection loss.
        scaler: torch.cuda.amp.GradScaler instance.

    Returns:
        Mean (all-reduce-averaged) loss over the loader on rank 0; 0.0 on
        other ranks (only rank 0 accumulates, matching the logging caller).
    """
    model.train()
    total_train_loss = 0.0
    adjust_learning_rate(optimizer, epoch)
    if args.is_distributed:
        # Reshuffle per-epoch so each rank sees a fresh data ordering.
        train_loader.sampler.set_epoch(epoch)

    for batch_idx, (left, right, disps) in tqdm(enumerate(train_loader), total=len(train_loader)):
        disp_16x, disp_8x, disp_4x, disp = disps
        left, right, disp, disp_4x, disp_8x, disp_16x = [
            tensor.to(local_rank).float() for tensor in
            [left, right, disp, disp_4x, disp_8x, disp_16x]
        ]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            disp4, disp3, disp2, disp1 = model(left, right)
            # Valid-pixel masks at full, 1/4, 1/8, 1/16 resolution; the
            # disparity bounds shrink with the scale (divisors 1, 4, 8, 16).
            # create_mask presumably returns (gt, mask) — used as such below.
            masks = [
                create_mask(d, args.maxdisp // (1 if i == 0 else 2**(i+1)), args.mindisp // (1 if i == 0 else 2**(i+1)))
                for i, d in enumerate([disp, disp_4x, disp_8x, disp_16x])
            ]
            # Per-scale supervised losses: finest prediction vs. full-res GT,
            # coarser predictions vs. correspondingly downsampled GT.
            losses = [
                F.smooth_l1_loss(pred[mask], gt[mask], reduction="mean")
                for pred, (gt, mask) in zip([disp1, disp2, disp3, disp4], masks)
            ]
            # Photometric consistency term on the finest prediction.
            losses.append(pixel_reprojection_loss(left, right, disp1, masks[0][1]))
            loss = loss_r(args, tuple(losses[:-1]), losses[-1])

        # BUG FIX: the original all-reduced `loss` in place *before*
        # backward(). dist.all_reduce is not autograd-aware, so the collective
        # only mutated the tensor's value while gradients still flowed from
        # the local graph — and the extra `/ world_size` rescaled gradients
        # that DDP already averages. Backprop the local loss; all-reduce a
        # detached copy purely for logging.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_metric = loss.detach()
        if args.is_distributed:
            dist.all_reduce(loss_metric, op=dist.ReduceOp.SUM)
            loss_metric = loss_metric / args.world_size
        if args.rank == 0:
            total_train_loss += loss_metric.cpu().numpy()
    return total_train_loss / len(train_loader)
def validate_one_epoch(epoch, model, valid_loader, local_rank, args, loss_r):
    """Evaluate the model over one pass of ``valid_loader``.

    Mirrors the training loss (per-scale smooth-L1 plus reprojection, combined
    by ``loss_r``) under no-grad + autocast. Returns the mean per-batch loss
    on rank 0; 0.0 on other ranks (only rank 0 accumulates).
    """
    model.eval()
    running_loss = 0.0
    if args.is_distributed:
        valid_loader.sampler.set_epoch(epoch)

    # Disparity bounds shrink with the scale: full, 1/4, 1/8, 1/16 resolution.
    scale_divisors = (1, 4, 8, 16)

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
        for step, (left_img, right_img, gt_pack) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            gt_16x, gt_8x, gt_4x, gt_full = gt_pack
            left_img, right_img, gt_full, gt_4x, gt_8x, gt_16x = [
                t.to(local_rank).float() for t in
                [left_img, right_img, gt_full, gt_4x, gt_8x, gt_16x]
            ]
            d4, d3, d2, d1 = model(left_img, right_img)

            masks = [
                create_mask(gt, args.maxdisp // div, args.mindisp // div)
                for gt, div in zip([gt_full, gt_4x, gt_8x, gt_16x], scale_divisors)
            ]
            # Finest prediction against full-res GT, coarser ones against
            # the matching downsampled GT.
            scale_losses = []
            for pred, (gt, mask) in zip((d1, d2, d3, d4), masks):
                scale_losses.append(F.smooth_l1_loss(pred[mask], gt[mask], reduction="mean"))
            reproj = pixel_reprojection_loss(left_img, right_img, d1, masks[0][1])
            batch_loss = loss_r(args, tuple(scale_losses), reproj)

            if args.is_distributed:
                dist.all_reduce(batch_loss, op=dist.ReduceOp.SUM)
                batch_loss = batch_loss / args.world_size
            if args.rank == 0:
                running_loss += batch_loss.detach().cpu().numpy()
    return running_loss / len(valid_loader)