Skip to content

Commit 5aa1472

Browse files
Fix data type (#26)
Fixes # da ta type error ### Description File "/localhome/local-mingxueg/VISTA/scripts/utils/trans_utils.py", line 352, in __call__ pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) AttributeError: 'int' object has no attribute 'to' ### Types of changes pred[pred == frac] = torch.tensor(data["label_prompt"][i - 1]).to(pred.dtype) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6448762 commit 5aa1472

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

scripts/utils/trans_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,9 @@ def __call__(
349349
pred += 0.5 # inplace mapping to avoid cloning pred
350350
for i in range(1, object_num + 1):
351351
frac = i + 0.5
352-
pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype)
352+
pred[pred == frac] = torch.tensor(
353+
data["label_prompt"][i - 1]
354+
).to(pred.dtype)
353355
pred[pred == 0.5] = 0.0
354356
data[keys] = pred
355357
return data

0 commit comments

Comments
 (0)