@@ -58,10 +58,6 @@ def shuffle_tensor(self, *tensors):
5858 def fit (self , train_avs : np .ndarray , train_ls : np .ndarray ,
5959 val_avs : np .ndarray , val_ls : np .ndarray ,
6060 patience = 10 , lr = 1e-2 , weight_decay = 1e-2 , max_epochs = 1000 ):
61- train_avs = torch .from_numpy (train_avs )
62- train_ls = torch .from_numpy (train_ls )
63- val_avs = torch .from_numpy (val_avs )
64- val_ls = torch .from_numpy (val_ls )
6561
6662 optimizer = torch .optim .AdamW (self .parameters (), lr = lr , weight_decay = weight_decay )
6763
@@ -80,8 +76,8 @@ def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
8076 for i in range (0 , len (train_avs ), 32 ):
8177 optimizer .zero_grad ()
8278
83- avs = train_avs [i : (i + 32 )]
84- l = train_ls [i : (i + 32 )]
79+ avs = torch . from_numpy ( train_avs [i : (i + 32 )])
80+ l = torch . from_numpy ( train_ls [i : (i + 32 )])
8581
8682 y_hat = self (avs .to (self .device ))
8783 loss = torch .mean (torch .clamp (1 - l .to (self .device ) * y_hat , min = 0 ))
@@ -94,8 +90,8 @@ def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
9490 self .eval ()
9591 val_loss_all = []
9692 for i in range (0 , len (val_avs ), 32 ):
97- avs = val_avs [i : (i + 32 )]
98- l = val_ls [i : (i + 32 )]
93+ avs = torch . from_numpy ( val_avs [i : (i + 32 )])
94+ l = torch . from_numpy ( val_ls [i : (i + 32 )])
9995
10096 y_hat = self (avs .to (self .device ))
10197 val_loss = torch .mean (torch .clamp (1 - l .to (self .device ) * y_hat , min = 0 ))
0 commit comments