Skip to content

Commit 21b9a62

Browse files
committed
minor fix due to numpy bahavior diff across versions
1 parent ded3d6a commit 21b9a62

2 files changed

Lines changed: 1 addition & 4 deletions

File tree

tpcav/concepts.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def _construct_motif_concept_dataloader_from_control(
6262
wds.RandomMix(datasets),
6363
batch_size=batch_size,
6464
num_workers=num_workers,
65-
pin_memory=True,
6665
drop_last=False,
6766
)
6867
return mixed_dl

tpcav/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,7 @@ def __iter__(self):
192192
if worker_info is None:
193193
chunk = self.seq_df
194194
else:
195-
chunk = np.array_split(self.seq_df, worker_info.num_workers)[
196-
worker_info.id
197-
]
195+
chunk = [self.seq_df.iloc[idx] for idx in np.array_split(np.arange(len(self.seq_df)), worker_info.num_workers)][worker_info.id]
198196
yield from iterate_seq_df_chunk(
199197
chunk,
200198
genome=self.genome,

0 commit comments

Comments
 (0)