Skip to content

Commit 972d5d3

Browse files
committed
minor improve logic
1 parent 5fb6431 commit 972d5d3

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

tpcav/cavs.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,6 @@ def shuffle_tensor(self, *tensors):
5858
def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
5959
val_avs: np.ndarray, val_ls: np.ndarray,
6060
patience=10, lr=1e-2, weight_decay=1e-2, max_epochs=1000):
61-
train_avs = torch.from_numpy(train_avs)
62-
train_ls = torch.from_numpy(train_ls)
63-
val_avs = torch.from_numpy(val_avs)
64-
val_ls = torch.from_numpy(val_ls)
6561

6662
optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
6763

@@ -80,8 +76,8 @@ def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
8076
for i in range(0, len(train_avs), 32):
8177
optimizer.zero_grad()
8278

83-
avs = train_avs[i: (i+32)]
84-
l = train_ls[i: (i+32)]
79+
avs = torch.from_numpy(train_avs[i: (i+32)])
80+
l = torch.from_numpy(train_ls[i: (i+32)])
8581

8682
y_hat = self(avs.to(self.device))
8783
loss = torch.mean(torch.clamp(1 - l.to(self.device) * y_hat, min=0))
@@ -94,8 +90,8 @@ def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
9490
self.eval()
9591
val_loss_all = []
9692
for i in range(0, len(val_avs), 32):
97-
avs = val_avs[i: (i+32)]
98-
l = val_ls[i: (i+32)]
93+
avs = torch.from_numpy(val_avs[i: (i+32)])
94+
l = torch.from_numpy(val_ls[i: (i+32)])
9995

10096
y_hat = self(avs.to(self.device))
10197
val_loss = torch.mean(torch.clamp(1 - l.to(self.device) * y_hat, min=0))

0 commit comments

Comments
 (0)