Skip to content

Commit 23787e3

Browse files
committed
ts data test fix
1 parent 86a71fc commit 23787e3

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

tests/test_TestConvertXyToTimeSeries.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,19 @@ def setUp(self):
2828
# Assuming max_seq_length is defined
2929
self.max_seq_length = 10
3030

31+
# Get the actual number of unique clients in the training set
32+
self.num_unique_clients_train = self.X_train['client_idcode'].nunique()
33+
3134
# Converting train data into time series format
3235
self.X_train_ts, self.y_train_ts = convert_Xy_to_time_series(
3336
self.X_train, self.y_train, self.max_seq_length
3437
)
3538

3639
def test_X_train_ts_shape(self):
37-
self.assertEqual(self.X_train_ts.shape, (100, 10, 10))
40+
self.assertEqual(self.X_train_ts.shape, (self.num_unique_clients_train, 10, 10))
3841

3942
def test_y_train_ts_shape(self):
40-
self.assertEqual(self.y_train_ts.shape, (100,))
43+
self.assertEqual(self.y_train_ts.shape, (self.num_unique_clients_train,))
4144

4245
def test_returns_tuple(self):
4346
result = convert_Xy_to_time_series(

0 commit comments

Comments
 (0)